gpt4 book ai didi

tensorflow - 自定义指标访问 X 输入数据

转载 作者:行者123 更新时间:2023-12-04 13:00:41 25 4
gpt4 key购买 nike

我想为拼写校正模型编写一个自定义指标,该模型可以计算以前不正确的正确替换字母。并且应该计算错误替换的先前正确的字母。

这就是我需要访问 x_input 数据的原因。不幸的是,默认情况下只有 y_true 和 y_pred 可访问。是否有解决方法来获得匹配的 x_input?

是:

def custom_metric(y_true, y_pred):

通缉:
def custom_metric(x_input, y_true, y_pred):

最佳答案

def custom_loss(x_input):
def loss_fn(y_true, y_pred):
# Use your x_input here directly
return #Your loss value
return loss_fn

model = # Define your model
model.compile(loss=custom_loss(x_input))
# Values of y_true and y_pred will be passed implicitly by Keras

请记住 x_input在训练模型时,所有批次的输入都将具有相同的值。

编辑 :

既然您需要 x_input数据 只有用于在损失函数期间进行估计的每个批次中,并且您有自己的自定义损失函数,为什么不通过 x_input作为标签。像这样的东西:
model.fit(x=x_input, y=x_input)
model.compile(loss=custom_loss())

def custom_loss(y_true, y_pred):
# y_true corresponds to x_input data

如果你需要 x_input 并且你需要传递一些其他数据,你可以这样做:
model.fit(x=x_input, y=[x_input, other_data])
model.compile(loss=custom_loss())

你只需要解耦 y_true中的数据现在。

关于tensorflow - 自定义指标访问 X 输入数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57935189/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com