gpt4 book ai didi

python - 如何检查 Tensor 的值是否包含在其他张量中?

转载 作者:行者123 更新时间:2023-12-01 01:38:01 25 4
gpt4 key购买 nike

我在从其他张量中查找值时遇到问题

它类似于以下问题:(URL:How to find a value in tensor from other tensor in Tensorflow)

上一题是询问输入张量x[i],y[i]是否包含在输入张量label_x,<强>label_y

这是上一个问题的示例:

Input Tensor
s_idx = (1, 3, 5, 7)
e_idx = (3, 4, 5, 8)

label_s_idx = (2, 2, 3, 6)
label_e_idx = (2, 3, 4, 8)

问题是给output[i]赋值1如果对于某些 j 满足 s_idx[i] == label_s_idx[j] 且 e_idx[i] == label_s_idx[j] 对于某些 j 满足。

因此,在上面的示例中,输出张量为

output = (0, 1, 0, 0)

因为 (s_idx[1] = 3, e_idx[1] = 4) 与 (label_s_idx[2] = 3) 相同, label_e_idx[2] = 4)

(s_idx, e_idx) 没有重复值,而 (label_s_idx, label_e_idx) 则有重复值。

因此,假设以下输入示例是不可能的:

s_idx = (2, 2, 3, 3)
e_idx = (2, 3, 3, 3)

因为 (s_idx[2] = 3, e_idx[2] = 3) 与 (s_idx[3] = 3、e_idx[3] = 3)。

在这个问题中我想稍微改变一下,就是向输入张量添加另一个值:

Input Tensor
s_idx = (1, 3, 5, 7)
e_idx = (3, 4, 5, 8)

label_s_idx = (2, 2, 3, 6)
label_e_idx = (2, 3, 4, 8)
label_score = (1, 3, 2, 3)

*label_score 张量中没有 0 值

更改后的问题中的任务定义如下:

问题是如果s_idx[i] == label_s_idx[j]e_idx[i] == label_s_idx[j,则为output_2[i]提供label_score[j]的值]对某些 j 感到满意。

因此,output_2应该是这样的:

output = (0, 1, 0, 0)  // It is same as previous problem
output_2 = (0, 2, 0, 0)

如何在 Python 中的 Tensorflow 上进行这样的编码?

最佳答案

这是一个可能的解决方案:

import tensorflow as tf

s_idx = tf.placeholder(tf.int32, [None])
e_idx = tf.placeholder(tf.int32, [None])
label_s_idx = tf.placeholder(tf.int32, [None])
label_e_idx = tf.placeholder(tf.int32, [None])
label_score = tf.placeholder(tf.int32, [None])

# Stack inputs for comparison
se_idx = tf.stack([s_idx, e_idx], axis=1)
label_se_idx = tf.stack([label_s_idx, label_e_idx], axis=1)
# Compare every pair to each other and find matches
cmp = tf.equal(se_idx[:, tf.newaxis, :], label_se_idx[tf.newaxis, :, :])
matches = tf.reduce_all(cmp, axis=2)
# Find the position of the matches
match_pos = tf.argmax(tf.cast(matches, tf.int8), axis=1)
# For those positions where a match was found take the corresponding score
output = tf.where(tf.reduce_any(matches, axis=1),
tf.gather(label_score, match_pos),
tf.zeros_like(label_score))

# Test
with tf.Session() as sess:
print(sess.run(output, feed_dict={s_idx: [1, 3, 5, 7],
e_idx: [3, 4, 5, 8],
label_s_idx: [2, 2, 3, 6],
label_e_idx: [2, 3, 4, 8],
label_score: [1, 3, 2, 3]}))
# >>> [0 2 0 0]

它将每对值相互比较,因此成本是输入大小的二次方。另外,tf.argmax用于查找匹配位置的索引,如果存在多个可能的索引,则可能会不确定地返回其中任何一个。

关于python - 如何检查 Tensor 的值是否包含在其他张量中?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52199843/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com