gpt4 book ai didi

python - 在 Tensorflow 中有效地计算成对排序损失函数

转载 作者:太空宇宙 更新时间:2023-11-04 03:03:09 25 4
gpt4 key购买 nike

我目前正在实现 http://www.aclweb.org/anthology/P15-1061在 tensorflow 中。

我已经实现了成对排序损失函数(论文的第 2.5 节)如下:

s_theta_y = tf.gather(tf.reshape(s_theta, [-1]), y_true_index)
s_theta_c_temp = tf.reshape(tf.gather(tf.reshape(s_theta, [-1]), y_neg_index), [-1, classes_size])
s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1])

我不得不使用 tf.gather 而不是 tf.gather_nd,因为后者尚未使用梯度下降实现。我还必须将所有索引转换为正确的展平矩阵。

如果 tf.gather_nd 是用梯度下降实现的,我的代码将如下所示:

s_theta_y = tf.gather_nd(s_theta, y_t_index)
s_theta_c_temp = tf.gather_nd(s_theta, y_neg_index)
s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1])

s_theta 是每个类标签的计算分数,如论文中所述。y_true_index 包含真实类别的索引,以便计算 s_theta_y。 y_neg_index 是所有负类的索引,它的维度是#class-1 或#class 是关系被分类为其他。

然而,有几个句子被归类为Other等,s_theta_y不存在,我们不应该在计算中考虑它。为了处理这种情况,我有一个常数因子 0,它取消了项,并且负类有相同的维度向量,我只是复制索引的一个随机值,因为最后,我们只对所有负类(而不是索引)中的最大值。

是否有更有效的方法来计算损失函数中的这些项?我的印象是使用 tf.gather 进行如此多的 reshape 非常慢

最佳答案

当然,这听起来像是 gather_nd 是您想要的,但在此处实现渐变之前,我会毫不犹豫地使用您的 reshape() 解决方案,因为 reshape() 实际上是免费的。

C++ implementation of the reshape() op看起来它做了很多工作,但这只是对形状信息的快速错误检查。 “工作”发生在第 90 行的 CopyFrom 中,这听起来可能很昂贵,但实际上只是一个指针副本(CopyFrom 调用 CopyFromInternal 来复制指针)。

这是完全有道理的:底层缓冲区只是 row-major order 中的数字平面数组,并且该排序不依赖于形状信息。出于同样的原因,像 tf.transpose() 这样的东西通常需要复制。

关于python - 在 Tensorflow 中有效地计算成对排序损失函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40330775/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com