gpt4 book ai didi

tensorflow - 在自定义 keras 损失中使用 keras 模型

转载 作者:行者123 更新时间:2023-12-01 14:12:25 25 4
gpt4 key购买 nike

我有一个名为 e 的常规 keras 模型,我想在我的自定义损失函数中比较 y_predy_true 的输出.

from keras import backend as K
def custom_loss(y_true, y_pred):
return K.mean(K.square(e.predict(y_pred)-e.predict(y_true)), axis=-1)

我收到错误:AttributeError: 'Tensor' object has no attribute 'ndim'这是因为 y_truey_pred 都是张量对象,keras.model.predict 希望传递一个 numpy.array.

知道如何在我的自定义损失函数中成功使用我的 keras.model 吗?

如果需要,我愿意获取指定层的输出,或者将我的 keras.model 转换为 tf.estimator 对象(或其他任何对象)。

最佳答案

首先,让我们尝试理解您收到的错误消息:

AttributeError: 'Tensor' object has no attribute 'ndim'

让我们看一下 Keras 文档并找到 predict Keras模型的方法。我们可以看到函数参数的说明:

x: the input data, as a Numpy array.

因此,该模型正在尝试获取 numpy 数组ndims 属性,因为它需要一个数组作为输入。另一方面,Keras 框架的自定义损失函数将 tensors 作为输入。所以,不要在其中编写任何 python 代码——它永远不会在评估期间执行。该函数只是在构建计算图时被调用。


好的,现在我们已经了解了该错误消息背后的含义,我们如何在自定义损失函数中使用 Keras 模型?简单的!我们只需要得到模型的评估图。

更新

global 关键字的使用是一种糟糕的编码习惯。此外,现在在 2020 年我们有更好的 functional API在 Keras 中,这使得对层的修改变得不必要。最好使用这样的东西:

from keras import backend as K

def make_custom_loss(model):
"""Creates a loss function that uses `model` for evaluation
"""
def custom_loss(y_true, y_pred):
return K.mean(K.square(model(y_pred) - model(y_true)), axis=-1)
return custom_loss

custom_loss = make_custom_loss(e)

已弃用

尝试这样的事情(仅适用于 Sequential 模型和非常旧的 API):

def custom_loss(y_true, y_pred):
# Your model exists in global scope
global e

# Get the layers of your model
layers = [l for l in e.layers]

# Construct a graph to evaluate your other model on y_pred
eval_pred = y_pred
for i in range(len(layers)):
eval_pred = layers[i](eval_pred)

# Construct a graph to evaluate your other model on y_true
eval_true = y_true
for i in range(len(layers)):
eval_true = layers[i](eval_true)

# Now do what you wanted to do with outputs.
# Note that we are not returning the values, but a tensor.
return K.mean(K.square(eval_pred - eval_true), axis=-1)

请注意,上面的代码未经测试。但是,无论实现如何,总体思路都将保持不变:您需要构建一个图,其中 y_truey_pred 将通过它流向最终操作。


关于tensorflow - 在自定义 keras 损失中使用 keras 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48726338/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com