gpt4 book ai didi

tensorflow - 如何在 Call() 方法中使用位置参数保存 keras 子类模型?

转载 作者:行者123 更新时间:2023-12-03 17:23:25 24 4
gpt4 key购买 nike

import tensorflow as tf

class MyModel(tf.keras.Model):

def __init__(self):
super(MyModel, self).__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

@tf.function
def call(self, enc_input, dec_input, training, mask1, mask2, mask3):
x = self.dense1(enc_input)
return self.dense2(x)

x = tf.random.normal((10,20))

model = MyModel()

y = model(x, x, False, None, None, None)

tf.keras.models.save_model(model, '/saved')
当我尝试保存模型时,即使我传递了所有参数,也会引发错误。 tf__call() missing 4 required positional arguments: 'training', 'mask1', 'mask2', and 'mask3'如何保存整个模型而不仅仅是保存权重?

最佳答案

我认为进行以下更改会起作用

 #def call(self, enc_input, dec_input, training, mask1, mask2, mask3):
def call(self, enc_input, dec_input, training=False, mask1=None, mask2=None, mask3=None):
挖了之后,我想有一个 sanity check这发生在函数参数上,如果未指定位置参数,则 x 之后的参数将被视为 **kwargs 参数(我对此不太确定)。
但是为此,如果您不想设置参数的默认映射,您可以解压缩它们,以便每个参数都进入其相应的位置,如下所示: y = model(*[x,x,False,None,None,None])

关于tensorflow - 如何在 Call() 方法中使用位置参数保存 keras 子类模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64425861/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com