gpt4 book ai didi

tensorflow - 如何使用 tensorflow 纠正 keras 的这个自定义损失函数?

转载 作者:行者123 更新时间:2023-11-30 08:39:08 29 4
gpt4 key购买 nike

我想编写一个自定义损失函数,该函数会惩罚低估权重的正目标值。它的工作方式类似于均方误差,唯一的区别是在所述情况下均方误差将乘以大于 1 的权重。

我是这样写的:

def wmse(ground_truth, predictions):
square_errors = np.square(np.subtract(ground_truth, predictions))
weights = np.ones_like(square_errors)
weights[np.logical_and(predictions < ground_truth, np.sign(ground_truth) > 0)] = 100
weighted_mse = np.mean(np.multiply(square_errors, weights))
return weighted_mse

但是,当我将其提供给 keras 中的顺序模型(以 tensorflow 作为后端)时:

model.compile(loss=wmse,optimizer='rmsprop')

我收到以下错误:

 raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. 
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

回溯指向 wmse 中的这一行:

weights[np.logical_and(predictions < ground_truth, np.sign(ground_truth) > 0)] =  100

到目前为止,我从未使用过 kerastensorflow,因此,如果有人帮助我将此损失函数调整为 keras,我将不胜感激>/tensorflow 框架。我尝试将np.ological_and替换为tensorflow.logic_and,但无济于事,错误仍然存​​在。

最佳答案

正如 @nuric 提到的,您必须仅使用带有导数的 Keras/Tensorflow 操作来实现损失,因为这些框架无法通过其他操作(例如 numpy 操作)进行反向传播。

仅 Keras 的实现可能如下所示:

from keras import backend as K

def wmse(ground_truth, predictions):
square_errors = (ground_truth - predictions) ** 2
weights = K.ones_like(square_errors)
mask = K.less(predictions, ground_truth) & K.greater(K.sign(ground_truth), 0)
weights = K.switch(mask, weights * 100, weights)
weighted_mse = K.mean(square_errors * weights)
return weighted_mse

gt = K.constant([-2, 2, 1, -1, 3], dtype="int32")
pred = K.constant([-2, 1, 1, -1, 1], dtype="int32")
weights, loss = wmse(gt, pred)

sess = K.get_session()
print(loss.eval(session=sess))
# 100

关于tensorflow - 如何使用 tensorflow 纠正 keras 的这个自定义损失函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50877618/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com