gpt4 book ai didi

python - Keras 一个操作在 train_on_batch 时对梯度有 None

转载 作者:行者123 更新时间:2023-11-28 19:03:04 31 4
gpt4 key购买 nike

Google Colab 重现错误 None_for_gradient.ipynb

我需要一个自定义损失函数,其中根据模型输入计算值,这些输入不是默认值 (y_true, y_pred)。预测方法适用于生成的架构,但是当我尝试使用 train_on_batch 时,出现以下错误。

ValueError: An operation has None for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

我的自定义损失函数(下方)基于此示例 image_ocr.py#L475 ,在 Colab 链接中有另一个基于此解决方案的示例 Custom loss function y_true y_pred shape mismatch #4781 ,它也会产生同样的错误:

from keras import backend as K
from keras import losses
import keras
from keras.models import TimeDistributed, Dense, Dropout, LSTM

def my_loss(args):
input_y, input_y_pred, y_pred = args
return keras.losses.binary_crossentropy(input_y, input_y_pred)

def generator2():
input_noise = keras.Input(name='input_noise', shape=(40, 38), dtype='float32')
input_y = keras.Input(name='input_y', shape=(1,), dtype='float32')
input_y_pred = keras.Input(name='input_y_pred', shape=(1,), dtype='float32')
lstm1 = LSTM(256, return_sequences=True)(input_noise)
drop = Dropout(0.2)(lstm1)
lstm2 = LSTM(256, return_sequences=True)(drop)
y_pred = TimeDistributed(Dense(38, activation='softmax'))(lstm2)

loss_out = keras.layers.Lambda(my_loss, output_shape=(1,), name='my_loss')([input_y, input_y_pred, y_pred])

model = keras.models.Model(inputs=[input_noise, input_y, input_y_pred], outputs=[y_pred, loss_out])
model.compile(loss={'my_loss': lambda y_true, y_pred: y_pred}, optimizer='adam')

return model

g2 = generator2()
noise = np.random.uniform(0,1,size=[10,40,38])
g2.train_on_batch([noise, np.ones(10), np.zeros(10)], noise)

我需要帮助来验证是哪个操作产生了这个错误,因为据我所知,keras.losses.binary_crossentropy 是可微分的。

最佳答案

我认为原因是input_y和input_y_pred都是keras Input,你的损失函数是用这两个张量计算的,它们没有与模型参数绑定(bind),所以损失函数没有给你的模型梯度

关于python - Keras 一个操作在 train_on_batch 时对梯度有 None,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50434710/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com