gpt4 book ai didi

python - 广播具有动态形状的 tf.matmul

转载 作者:太空宇宙 更新时间:2023-11-03 21:24:54 24 4
gpt4 key购买 nike

我想在等级 2 和等级 3 的两个张量之间广播 tf.matmul 运算,其中一个包含“未知”形状的维度(基本上是特定维度中的“无”值) )。

问题是动态尺寸 tf.reshapetf.broadcast_to 似乎不起作用。

x = tf.placeholder(shape=[None, 5, 10], dtype=tf.float32)
w = tf.ones([10, 20])
y = x @ w
with tf.Session() as sess:
r1 = sess.run(y, feed_dict={x: np.ones([3, 5, 10])})
r2 = sess.run(y, feed_dict={x: np.ones([7, 5, 10])})

以上面的代码为例。在这种情况下,我喂食两批不同的饲料,每批分别含有 3 个和 7 个元素。我希望 r1r2 是这些批处理中的 3 个或 7 个元素中的每一个矩阵乘以 w 的结果。因此,r1r2 的结果形状分别为 (3, 5, 20) 和 (7, 5, 20),但我得到的是:

ValueError: Shape must be rank 2 but is rank 3 for 'matmul' (op: 'MatMul') with input shapes: [?,5,10], [10,20].

最佳答案

w 可以扩展为 rank-3 张量,其批量大小等于输入的批量大小。然后就可以进行matmul运算了

x = tf.placeholder(shape=[None, 5, 10], dtype=tf.float32)
w = tf.ones([10, 20])

number_batches = tf.shape(x)[0]
w = tf.tile(tf.expand_dims(w, 0), [number_batches, 1, 1])
y = x @ w
with tf.Session() as sess:
print(sess.run(y, feed_dict={x: np.ones([2, 5, 10])}))
print(sess.run(y, feed_dict={x: np.ones([3, 5, 10])}))

实时代码here

关于python - 广播具有动态形状的 tf.matmul,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53923996/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com