gpt4 book ai didi

tensorflow - 如何实现标记嵌入的中心损失和其他运行平均值

转载 作者:行者123 更新时间:2023-12-03 00:37:06 24 4
gpt4 key购买 nike

最近的一篇论文 (here) 引入了一种称为中心损失的二次损失函数。它基于批处理中嵌入之间的距离以及每个相应类的运行平均嵌入。 TF Google 群组 (here) 中已经就如何计算和更新此类嵌入中心进行了一些讨论。我在下面的答案中整理了一些代码来生成类平均嵌入。

这是最好的方法吗?

最佳答案

对于像中心损失这样的情况来说,之前发布的方法太简单了,随着模型变得更加精细,嵌入的预期值会随着时间的推移而变化。这是因为之前的中心查找例程对自启动以来的所有实例进行平均,因此跟踪预期值的变化非常缓慢。相反,移动窗口平均值是首选。指数移动窗口变体如下:

def get_embed_centers(embed_batch, label_batch):
''' Exponential moving window average. Increase decay for longer windows [0.0 1.0]
'''
decay = 0.95
with tf.variable_scope('embed', reuse=True):
embed_ctrs = tf.get_variable("ctrs")

label_batch = tf.reshape(label_batch, [-1])
old_embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
dif = (1 - decay) * (old_embed_ctrs_batch - embed_batch)
embed_ctrs = tf.scatter_sub(embed_ctrs, label_batch, dif)
embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
return embed_ctrs_batch


with tf.Session() as sess:
with tf.variable_scope('embed'):
embed_ctrs = tf.get_variable("ctrs", [nclass, ndims], dtype=tf.float32,
initializer=tf.constant_initializer(0), trainable=False)
label_batch_ph = tf.placeholder(tf.int32)
embed_batch_ph = tf.placeholder(tf.float32)
embed_ctrs_batch = get_embed_centers(embed_batch_ph, label_batch_ph)
sess.run(tf.initialize_all_variables())
tf.get_default_graph().finalize()

关于tensorflow - 如何实现标记嵌入的中心损失和其他运行平均值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40162361/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com