gpt4 book ai didi

python - Keras 自定义损失函数 huber

转载 作者:行者123 更新时间:2023-12-04 03:57:12 24 4
gpt4 key购买 nike

<分区>

我使用 Keras 后端函数编写了 huber loss,效果很好:

def huber_loss(y_true, y_pred, clip_delta=1.0):
error = y_true - y_pred
cond = K.abs(error) < clip_delta

squared_loss = 0.5 * K.square(error)
linear_loss = clip_delta * (K.abs(error) - 0.5 * clip_delta)

return tf_where(cond, squared_loss, linear_loss)

但我需要一个更复杂的损失函数:

  1. 如果error <= A , 使用 squared_loss
  2. 如果A <= error < B ,使用linear_loss
  3. 如果error >= B ,使用了sqrt_loss

我是这样写的:

def best_loss(y_true, y_pred, A, B):
error = K.abs(y_true - y_pred)
cond = error <= A
cond2 = tf_logical_and(A < error, error <= B)

squared_loss = 0.5 * K.square(error)
linear_loss = A * (error - 0.5 * A)
sqrt_loss = A * np.sqrt(B) * K.sqrt(error) - 0.5 * A**2

return tf_where(cond, squared_loss, tf_where(cond2, linear_loss, sqrt_loss))

但是不行,有这个损失函数的模型不收敛,bug是什么?

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com