gpt4 book ai didi

python - 形状为 T1 = N*D、T2 = M*D 的两个 tensorflow 张量; M < N。T1 在 T2 中有行。查找 T1 张量中 T2 中每一行的行索引

转载 作者:太空宇宙 更新时间:2023-11-03 21:18:46 26 4
gpt4 key购买 nike

我有两个张量T1(N * D维度)和T2(M * D维度)(M小于N)。 T2 行保证位于 T1 中。对于 T2 中的每一行,有没有办法找到该行匹配的 T1 索引?我能够使用急切执行来解决问题。

import tensorflow as tf
import numpy as np
tf.enable_eager_execution()
x = tf.random_normal([15,3])
y = x[:2] # first two entries
y= tf.concat([y,x[8:9]], 0)
output = []
for row in x:
if row.numpy() in y.numpy():
output.append(True)
else:
output.append(False)

有人可以在不急于执行的情况下提供执行帮助吗?如果T1和T2是批处理,我们如何执行相同的操作?即 T1 - B * N * D 和 T2 - B * M * D

附注我们如何在 Tensorflow 中搜索行?

最佳答案

以下是您可以执行此操作的方法:

import tensorflow as tf

def find_row_indices(t1, t2):
# Compare every pair of rows
eq = tf.equal(tf.expand_dims(t1, -3), tf.expand_dims(t2, -2))
# Find where all the elements in two rows match
matches = tf.reduce_all(eq, axis=-1)
# Find indices where match occurs
idx = tf.argmax(tf.cast(matches, tf.uint8), axis=-1)
# Find where there has been no match
has_match = tf.reduce_any(matches, axis=-1)
# Return match index of -1 if no match found
return tf.where(has_match, idx, -tf.ones_like(idx))

# Test
with tf.Graph().as_default():
tf.set_random_seed(100)
x = tf.random_normal([15, 3])
y = x[:2]
y = tf.concat([y, x[8:9]], 0)
idx = find_row_indices(x, y)
with tf.Session() as sess:
print(sess.run(idx))
# [0 1 8]

它具有二次方的空间和内存成本,因为它将每对行相互比较,因此在某些情况下拥有两个非常大的输入可能会出现问题。另外,如果有多个匹配索引,该方法不保证返回哪一个。

编辑:上面的函数也可以用于具有更多初始维度的数组。例如:

import tensorflow as tf

with tf.Graph().as_default():
tf.set_random_seed(100)
x = tf.random_normal([4, 15, 3])
y = tf.gather_nd(x, [[[ 0, 3], [ 0, 6], [ 0, 2]],
[[ 1, 10], [ 1, 5], [ 1, 12]],
[[ 2, 8], [ 2, 1], [ 2, 0]],
[[ 3, 9], [ 3, 14], [ 3, 4]]])
idx = find_row_indices(x, y)
with tf.Session() as sess:
print(sess.run(idx))
# [[ 3 6 2]
# [10 5 12]
# [ 8 1 0]
# [ 9 14 4]]

关于python - 形状为 T1 = N*D、T2 = M*D 的两个 tensorflow 张量; M < N。T1 在 T2 中有行。查找 T1 张量中 T2 中每一行的行索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54481099/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com