gpt4 book ai didi

python - "LookupError: Function ` __class__ ` does not exist."使用 tf.function 时

转载 作者:行者123 更新时间:2023-11-30 09:42:11 26 4
gpt4 key购买 nike

我正在尝试使用 WebAssembly 二进制文件中抽象的自定义损失函数来覆盖 keras/tf2.0 损失函数。这是相关代码。

@tf.function
def custom_loss(y_true, y_pred):
return tf.constant(instance.exports.loss(y_true, y_pred))

我就是这样使用的

model.compile(optimizer_obj, loss=custom_loss)
# and then model.fit(...)

我不完全确定 tf2.0 eager execution 是如何工作的,因此任何有关此的见解都会有用。

我认为 instance.exports.loss 函数与该错误无关,但是如果您确定其他一切正常,请告诉我,我将分享更多详细信息。

这是堆栈跟踪和实际错误: https://pastebin.com/6YtH75ay

最佳答案

首先,您不需要使用@tf.function来定义自定义损失。

我们可以愉快地(如果毫无意义的话)做这样的事情:

def custom_loss(y_true, y_pred):
return tf.reduce_mean(y_pred)

model.compile(optimizer='adam',
loss=custom_loss,
metrics=['accuracy'])

只要我们在 custom_loss 中使用的所有操作都可以通过 tensorflow 微分

因此,您可以删除 @tf.function 装饰器,但我怀疑您会遇到如下错误消息:

[Some trace info] - Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

因为 Tensorflow 无法在 WebAssembly 二进制文件中找到函数的梯度。损失函数内部的所有内容都必须是 tensorflow 可以理解并计算梯度的东西,否则它将无法优化以获得更低的损失值。

也许最好的方法是使用 tensorflow 可以计算梯度的操作来复制 instance.exports.loss 内部的功能,而不是尝试直接引用它?

关于python - "LookupError: Function ` __class__ ` does not exist."使用 tf.function 时,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57288243/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com