gpt4 book ai didi

python - Tensorflow:计算标签的逆计数

转载 作者:太空宇宙 更新时间:2023-11-04 00:07:42 27 4
gpt4 key购买 nike

我有一个形状为 (b,n) 的张量 labels,其值在 [0,1,2,3,4,5] 中。我想创建一个形状为 (b,n) 的张量 weights,它在位置 (i,j) 处携带整数 labels[i,j]< 的次数的倒数 出现在标签中。

工作示例代码:

import tensorflow as tf
import numpy as np
tf.InteractiveSession()
labels=tf.convert_to_tensor(np.array([[1,0,0,1,2,4],[2,2,2,4,2,1]]), dtype=tf.int32)
weights=tf.ones_like(labels, dtype=tf.float32)
bc=tf.bincount(labels, minlength=6, maxlength=6)
for i in range(6):
cur_count = 1.0/(1e-10+tf.cast(bc[i], tf.float32))
count_tensor = tf.ones_like(labels, dtype=tf.float32)*cur_count
weights = tf.where(tf.equal(labels,i), count_tensor, weights)
weights.eval()
# array([[0.3333, 0.5, 0.5, 0.3333, 0.2, 0.5],
# [0.2, 0.2, 0.2, 0.5, 0.2, 0.3333]], dtype=float32)

因此,例如,标签 1labels 张量中出现了三次,因此在 weights 中,值 1/3 出现在每个位置,其中 1labels 中。

现在我不喜欢此代码的是 tf.bincount 在我的 tensorflow 版本 (1.4.0) 中无法在 GPU 上运行,而且我无法更新。另外,我不确定 tensorflow 如何处理 for 循环以及由此产生了多少开销。

我想我的问题有更优雅的解决方案。有什么想法吗?

最佳答案

关于循环,您可以将其替换为对 tf.gather 的调用:

import tensorflow as tf
import numpy as np

tf.InteractiveSession()
labels = tf.convert_to_tensor(
np.array([[1, 0, 0, 1, 2, 4], [2, 2, 2, 4, 2, 1]]), dtype=tf.int32)
bc = tf.bincount(labels, minlength=6, maxlength=6)
weights = tf.gather(1.0 / (1e-10 + tf.cast(bc, tf.float32)), labels)
print(weights.eval())

输出:

[[0.33333334 0.5        0.5        0.33333334 0.2        0.5       ]
[0.2 0.2 0.2 0.5 0.2 0.33333334]]

关于 tf.bincount 是 CPU-only,目前似乎并非如此。事实上,GPU 实现似乎已经可用 since v1.5.0 .

如果你想要一个替代的实现,你可以这样做:

import tensorflow as tf
import numpy as np

tf.InteractiveSession()
labels = tf.convert_to_tensor(
np.array([[1, 0, 0, 1, 2, 4], [2, 2, 2, 4, 2, 1]]), dtype=tf.int32)
eq = tf.equal(labels[:, :, tf.newaxis], tf.range(6, dtype=labels.dtype))
bc = tf.reduce_sum(tf.cast(eq, tf.float32), axis=[0, 1])
weights = tf.gather(1.0 / (1e-10 + tf.cast(bc, tf.float32)), labels)
print(weights.eval())
# Same output

但是 tf.bincount 可能比这更有效。

关于python - Tensorflow:计算标签的逆计数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53497357/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com