gpt4 book ai didi

debugging - 如何在 TensorFlow 和 Keras 的损失函数中打印中间变量?

转载 作者:行者123 更新时间:2023-12-04 01:05:17 27 4
gpt4 key购买 nike

我正在编写一个自定义目标来训练 Keras(带有 TensorFlow 后端)模型,但我需要调试一些中间计算。为简单起见,假设我有:

def custom_loss(y_pred, y_true):
diff = y_pred - y_true
return K.square(diff)

我找不到一种简单的方法来访问,例如,在训练期间访问中间变量 diff 或其形状。在这个简单的例子中,我知道我可以返回 diff 来打印它的值,但我的实际损失更复杂,我不能在没有编译错误的情况下返回中间值。

有没有一种简单的方法可以在 Keras 中调试中间变量?

最佳答案

据我所知,这在 Keras 中无法解决,因此您必须求助于后端特定的功能。两者 TheanoTensorFlowPrint作为身份节点的节点(即,它们返回输入节点)并具有打印输入(或输入的某些张量)的副作用。

Theano 示例:

diff = y_pred - y_true
diff = theano.printing.Print('shape of diff', attrs=['shape'])(diff)
return K.square(diff)

TensorFlow 示例:
diff = y_pred - y_true
diff = tf.Print(diff, [tf.shape(diff)])
return K.square(diff)

请注意,这仅适用于中间值。 Keras 期望传递给其他层的张量具有特定属性,例如 _keras_shape .后端处理的值,即通过 Print ,通常没有那个属性。要解决这个问题,您可以将调试语句包装在 Lambda 中。层为例。

关于debugging - 如何在 TensorFlow 和 Keras 的损失函数中打印中间变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42233963/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com