gpt4 book ai didi

python - 使用索引提取张量的行和列

转载 作者:太空宇宙 更新时间:2023-11-04 11:12:43 26 4
gpt4 key购买 nike

我有一个二维张量 x,它是一个占位符 x = tf.placeholder(tf.int32, shape=[None, 100])。现在在图中,我想随机提取 x 的 50% 行。这是我尝试过的:

idxs = tf.range(tf.shape(x)[0])
ridxs = tf.random_shuffle(idxs)[: int(tf.shape(x)[0] * 0.5)]
x = tf.gather(x, ridxs, axis=0)

它抛出以下错误:

Traceback (most recent call last):
File "main_gcn_random_sampling.py", line 76, in <module>
model = GCN_RANDOM_SAMPLING(FLAGS.learning_rate, num_input, num_classes, hidden_dimensions=hidden_sizes, act=tf.nn.relu)
File "/home/tiendh/Projects/EPooling/gcn_random_sampling.py", line 63, in __init__
ridxs = tf.random_shuffle(idxs)[: int(tf.shape(x)[0] * 0.5)]
TypeError: unsupported operand type(s) for *: 'Tensor' and 'float'

我可以为预定义的索引声明另一个占位符,但它看起来很难看。还有其他建议吗?

最佳答案

这几乎是正确的,但是:

  1. 在乘以 0.5 之前,您必须转换为 tf.shape(x)[0] 才能 float 。
  2. 您不能使用 int 来转换张量数据类型,您需要使用 tf.cast .
idxs = tf.range(tf.shape(x)[0])
num_half_idx = tf.cast(tf.cast(tf.shape(x)[0], tf.float32) * 0.5, tf.int32)
ridxs = tf.random_shuffle(idxs)[: num_half_idx]
x = tf.gather(x, ridxs, axis=0)

请注意,如果您总是获得 50% 的行,您也可以简单地执行以下操作:

idxs = tf.range(tf.shape(x)[0])
ridxs = tf.random_shuffle(idxs)[: tf.shape(x)[0] // 2]
x = tf.gather(x, ridxs, axis=0)

关于python - 使用索引提取张量的行和列,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57819633/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com