gpt4 book ai didi

python - 在 TensorFlow 中,如何断言列表的值在某个集合中?

转载 作者:太空宇宙 更新时间:2023-11-04 02:14:00 26 4
gpt4 key购买 nike

我有一个一维 tf.uint8 张量 x 并且想断言该张量内的所有值都在集合 s 中定义。 s 在图形定义时是固定的,因此它不是动态计算的张量。

在普通的 Python 中,我想做某事。像下面这样:

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}
assert all(el in s for el in x), "This should fail, as 5 is not in s"

我知道我可以使用 tf.Assert对于断言部分,但我正在努力定义条件部分(el in s)。执行此操作的最简单/最规范的方法是什么?

较旧的答案 Determining if A Value is in a Set in TensorFlow对我来说还不够:首先,写下来和理解起来很复杂,其次,它使用的是广播的tf.equal,这在计算方面比适当的基于集合的检查更昂贵.

最佳答案

一个简单的方法可能是这样的:

import tensorflow as tf

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}

x_t = tf.constant(x, dtype=tf.uint8)
s_t = tf.constant(list(s), dtype=tf.uint8)
# Check every value in x against every value in s
xs_eq = tf.equal(x_t[:, tf.newaxis], s_t)
# Check every element in x is equal to at least one element in s
assert_op = tf.Assert(tf.reduce_all(tf.reduce_any(xs_eq, axis=1)), [x_t])
with tf.control_dependencies([assert_op]):
# Use x_t...

这将创建一个大小为 (len(x), len(s)) 的中间张量。如果这是有问题的,您还可以将问题拆分为独立的张量,例如:

import tensorflow as tf

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}

x_t = tf.constant(x, dtype=tf.uint8)
# Count where each x matches each s
x_in_s = [tf.cast(tf.equal(x_t, si), tf.int32) for si in s]
# Add matches and check there is at least one match per x
assert_op = tf.Assert(tf.reduce_all(tf.add_n(x_in_s) > 0), [x_t])

编辑:

实际上,既然你说你的值是 tf.uint8,你可以使用 bool 数组让事情变得更好:

import tensorflow as tf

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}

x_t = tf.constant(x, dtype=tf.uint8)
s_t = tf.constant(list(s), dtype=tf.uint8)
# One-hot vectors of values included in x and s
x_bool = tf.scatter_nd(tf.cast(x_t[:, tf.newaxis], tf.int32),
tf.ones_like(x_t, dtype=tf.bool), [256])
s_bool = tf.scatter_nd(tf.cast(s_t[:, tf.newaxis], tf.int32),
tf.ones_like(s_t, dtype=tf.bool), [256])
# Check that all values in x are in s
assert_op = tf.Assert(tf.reduce_all(tf.equal(x_bool, x_bool & s_bool)), [x_t])

这需要线性时间和常量内存。

编辑 2:虽然最后一种方法理论上在这种情况下是最好的,但做几个快速基准测试时,我只能看到当我达到数十万个元素时性能上的显着差异,无论如何这三个使用 tf.uint8 仍然非常快。

关于python - 在 TensorFlow 中,如何断言列表的值在某个集合中?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53061284/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com