gpt4 book ai didi

python - 如何将损失函数中的变量存储到实例变量中

转载 作者:行者123 更新时间:2023-11-30 09:34:53 25 4
gpt4 key购买 nike

我正在将 Keras 与 Tensorflow 结合使用。因为我想创建LSTM-CRF model ,我使用 tf.contrib.crf.crf_log_likelihood 定义了自己的损失函数:

def loss(self, y_true, y_pred):
sequence_lengths = ... # calc from y_true
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(y_pred, y_true, sequence_lengths)
loss = tf.reduce_mean(-log_likelihood)
self.transition_params = transition_params

return loss

如您所知,CRF 在预测阶段需要转换参数。因此,我将 transition_params 存储到实例变量 self.transition_params 中。

问题是 self.transition_params 在小批量期间从未更新过。据我观察,编译模型时似乎只存储一次。

有没有办法将损失函数中的变量存储到Keras的实例变量中?

最佳答案

问题是函数签名错误 tf.contrib.crf.crf_log_likelihood ,您需要将 transition_params 与当前的转换参数一起传递。以下更改将解决相同问题。

log_likelihood, transition_params = 
tf.contrib.crf.crf_log_likelihood(y_pred, y_true, sequence_lengths,
transition_params=self.transition_params)

关于python - 如何将损失函数中的变量存储到实例变量中,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45666991/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com