gpt4 book ai didi

keras - Keras 的中心损失

转载 作者:行者123 更新时间:2023-12-02 09:11:16 25 4
gpt4 key购买 nike

我想实现[http://ydwen.github.io/papers/WenECCV16.pdf]中解释的中心损失]在喀拉斯

我开始创建一个具有 2 个输出的网络,例如:

inputs = Input(shape=(100,100,3))
...
fc = Dense(100)(#previousLayer#)
softmax = Softmax(fc)
model = Model(input, output=[softmax, fc])
model.compile(optimizer='sgd',
loss=['categorical_crossentropy', 'center_loss'],
metrics=['accuracy'], loss_weights=[1., 0.2])

首先,这样做是一个好的方法吗?

其次,我不知道如何在keras中实现center_loss。 Center_loss 看起来像均方误差,但不是将值与固定标签进行比较,它将值与每次迭代时更新的数据进行比较。

感谢您的帮助

最佳答案

对我来说,您可以按照以下步骤实现该层:

  1. 编写一个自定义层 ComputeCenter

    • 接受两个输入:i)。 groudtruth 标签 y_true (不是单热编码,只是整数)和 ii)。预测成员资格 y_pred

    • 包含一个大小为 num_classes x num_feats 数组的查找表 W 作为可训练权重(请参阅 BatchNormalization Layer),W[j] 是第 j 类特征的移动平均值的占位符。

    • 计算论文中指定的中心损耗。

    • 输出结果距离数组D
  2. 要计算中心损失,您需要

    • 我)。根据 y_true[k]=j 使用 y_pred[k] 更新 W[j]
    • ii)。检索 y_true[k]=j
    • 的样本 y_pred[k] 的中心特征 c_true[k]=W[j] >
    • iii) 计算 y_predc_true 之间的距离。
    • 这里c_true[k] = W[j]k是样本索引,j是y_pred[k]。
  3. 使用model.add_loss()来计算此损失。请注意,请勿在 model.compile( loss = ... ) 中添加此损失。

最后,如果需要,您可以在 center-loss 中添加一些损失系数。

关于keras - Keras 的中心损失,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40173892/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com