gpt4 book ai didi

python - 在 Keras 中使用 Earth Mover Loss 方法和输入参数数据类型

转载 作者:行者123 更新时间:2023-12-03 23:47:33 25 4
gpt4 key购买 nike

我在 Keras/Tensrflow 中找到了 Earth Mover Loss 的代码。我想计算给图像的分数的损失,但在我了解下面给出的 Earth Mover Loss 的工作之前,我无法做到这一点。有人可以描述一下代码中发生了什么。

模型或输出层的最后一层是这样的:

out = Dense(10,activation='softmax')(x)

这个方法的输入类型应该是什么。我有我的 y_labels1.2,4.9 的形式等等等等。我想将它与 Keras/Tensorflow 一起使用
def earth_mover_loss(y_true, y_pred):
cdf_true = K.cumsum(y_true, axis=-1)
cdf_pred = K.cumsum(y_pred, axis=-1)
emd = K.sqrt(K.mean(K.square(cdf_true - cdf_pred), axis=-1))
return K.mean(emd)

最佳答案

您可以将 EML 视为 CDF 概率函数的一种 RMSE

给定 N 个类别,您所需要的只是每个样本的归一化概率分数。在神经网络领域,这是通过 softmax 激活函数作为输出层来实现的

EML 只是比较预测与现实的 CDF

在一个有 10 个类的分类问题中,对于单个样本,我们可以有这些数组

y_true = [0,0,0,1,0,0,0,0,0,0] #样本属于第4类

y_pred = [0.1,0,0,0.9,0,0,0,0,0,0] # softmax 层的概率输出

我们对它们计算 CDF 并得到以下分数:

CDF_y_true = [0,0,0,1,1,1,1,1,1,1]

CDF_y_pred = [0.1,0.1,0.1,1,1,1,1,1,1,1]

如上所述,EML 计算此 CDF 上的 RMSE

y_true = np.asarray([0.,0.,0.,1.,0.,0.,0.,0.,0.,0.])
y_pred = np.asarray([0.1,0.,0.,0.9,0.,0.,0.,0.,0.,0.])

cdf_true = K.cumsum(y_true, axis=-1)
cdf_pred = K.cumsum(y_pred, axis=-1)
emd = K.sqrt(K.mean(K.square(cdf_true - cdf_pred), axis=-1))

在TID2013上Google的NIMA Paper的具体案例中,N=10,标签以浮点分数的形式表示。为了使用 EML 训练网络,需要遵循以下步骤:
  • 数字化 10 个区间的 float 分数
  • 对标签进行单热编码以获得 softmax 概率并最小化 EML

  • 在训练结束时,我们的神经网络能够在给定的图像上生成每个类别的概率分数。
    我们必须将此分数转换为具有论文中定义的相关标准偏差的平均质量分数。
    为此,我们遵循论文中定义的程序

    bins = [1,2,3,4,5,6,7,8,9,10]

    y_pred = [0.1,0,0,0.9,0,0,0,0,0,0] # softmax 层的概率输出

    mu_score = sum(bins*y_pred) = 1*0.1 + 2*0 + 3*0 + 4*0.9 + ... + 10*0

    sigma_score = sum(((bins - mu_score)**2)*y_pred)**0.5
    bins = np.arange(1,11)
    y_pred = np.asarray([0.1,0.,0.,0.9,0.,0.,0.,0.,0.,0.])

    mu_score = np.sum(bins*y_pred)
    std_score = np.sum(((bins - mu_score)**2)*y_pred)**0.5

    关于python - 在 Keras 中使用 Earth Mover Loss 方法和输入参数数据类型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61673551/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com