gpt4 book ai didi

tensorflow - 在 tensorflow 中展平批处理

转载 作者:行者123 更新时间:2023-12-03 10:10:12 25 4
gpt4 key购买 nike

我有一个形状为 [None, 9, 2] 的 tensorflow 的输入(其中 None 是批处理)。

要对其执行进一步的操作(例如 matmul),我需要将其转换为 [None, 18]形状。怎么做?

最佳答案

您可以使用 tf.reshape() 轻松完成,而无需知道批量大小。

x = tf.placeholder(tf.float32, shape=[None, 9,2])
shape = x.get_shape().as_list() # a list: [None, 9, 2]
dim = numpy.prod(shape[1:]) # dim = prod(9,2) = 18
x2 = tf.reshape(x, [-1, dim]) # -1 means "all"
-1最后一行表示整列,无论运行时的批处理大小如何。您可以在 tf.reshape() 中看到它.

更新:形状 = [无,3,无]

谢谢@kbrose。对于超过 1 个维度未定义的情况,我们可以使用 tf.shape()tf.reduce_prod()或者。
x = tf.placeholder(tf.float32, shape=[None, 3, None])
dim = tf.reduce_prod(tf.shape(x)[1:])
x2 = tf.reshape(x, [-1, dim])

tf.shape() 返回一个可以在运行时评估的形状张量。 tf.get_shape() 和 tf.shape() 的区别可见 in the doc .

我也在另一个 .contrib.layers.flatten() 中尝试过。第一种情况最简单,但不能处理第二种情况。

关于tensorflow - 在 tensorflow 中展平批处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36668542/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com