gpt4 book ai didi

python - 根据 TensorFlow 中的 python 列表过滤张量

转载 作者:行者123 更新时间:2023-12-01 00:45:35 25 4
gpt4 key购买 nike

我有一个tf.int64类型的张量a。我想根据给定的 python 列表过滤掉这个张量。
例如 -

l = [1,2,3]
a = tf.constant([1,2,3,4], dtype=tf.int64)

需要一个值为 1,2,3(4 除外)的张量。也就是在l的基础上过滤掉a。我如何在 TensorFlow 中执行此操作?

最佳答案

您可以使用tf.sets.set_intersection:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
l = tf.constant([1, 2, 3], dtype=tf.int64)
a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
# tf.sets.intersection in more recent versions
b = tf.sets.set_intersection(tf.expand_dims(a, 0), tf.expand_dims(l, 0))
b = tf.squeeze(tf.sparse.to_dense(b), 0)
print(sess.run(b))
# [1 2 3]

但是在很多情况下这可能不会达到您想要的效果。如果存在重复元素,它将丢弃它们,并且也会对输出进行排序。更一般地说,您可以这样做:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
l = tf.constant([1, 2, 3], dtype=tf.int64)
a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
m = tf.reduce_any(tf.equal(tf.expand_dims(a, 1), l), axis=1)
b = tf.boolean_mask(a, m)
print(sess.run(b))
# [1 2 3]

这是一个二次比较,但我认为没有比 np.isin 更好的了。在 TensorFlow 中。

关于python - 根据 TensorFlow 中的 python 列表过滤张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57006393/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com