gpt4 book ai didi

python - 一次获取 Tensorflow 中多个元素的索引

转载 作者:太空宇宙 更新时间:2023-11-04 04:00:21 24 4
gpt4 key购买 nike

我是 Tensorflow 的新手。

我有一个问题。

这里是一维数组。

 values = [101,103,105,109,107]

target_values = [105, 103]

我想立即从 values 中获取关于 target_values 的索引。

从上面的示例中提取的索引将显示在下面。

indices = [2, 1]

当我使用 tf.map_fn 函数时。这个问题很容易解决。

# if you do not change data type from int64 to int32. TypeError will riase
values = tf.cast(tf.constant([100, 101, 102, 103, 104]), tf.int64)
target_values = tf.cast(tf.constant([100, 101]), tf.int64)
indices = tf.map_fn(lambda x: tf.where(tf.equal(values, x)), target_values)

谢谢!

最佳答案

假设 target_values 中的所有值都在 values 中,这是实现此目的的一种简单方法(TF 2.x,但该函数对于 1.x 应该同样有效。 x):

import tensorflow as tf

values = [101, 103, 105, 109, 107]
target_values = [105, 103]

# Assumes all values in target_values are in values
def find_in_array(values, target_values):
values = tf.convert_to_tensor(values)
target_values = tf.convert_to_tensor(target_values)
# stable=True if there may be repeated elements in values
# and you want always first occurrence
idx_s = tf.argsort(values, stable=True)
values_s = tf.gather(values, idx_s)
idx_search = tf.searchsorted(values_s, target_values)
return tf.gather(idx_s, idx_search)

print(find_in_array(values, target_values).numpy())
# [2 1]

关于python - 一次获取 Tensorflow 中多个元素的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58481332/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com