gpt4 book ai didi

python - 广播 np.dot vs tf.matmul 以进行张量矩阵乘法(形状必须为 2 阶但为 3 阶错误)

转载 作者:太空宇宙 更新时间:2023-11-04 09:50:33 26 4
gpt4 key购买 nike

假设我有以下张量:

X = np.zeros((3,201, 340))
Y = np.zeros((340, 28))

使用 numpy 成功生成 X、Y 的点积,并生成形状为 (3, 201, 28) 的张量。但是,对于 tensorflow,我收到以下错误:Shape must be rank 2 but is rank 3 error ...

最小代码示例:

X = np.zeros((3,201, 340))
Y = np.zeros((340, 28))
print(np.dot(X,Y).shape) # successful (3, 201, 28)
tf.matmul(X, Y) # errornous

知道如何使用 tensorflow 实现相同的结果吗?

最佳答案

由于您使用的是 tensors,因此使用 tensordot 比使用 np.dot 会更好(为了提高性能)。 NumPy 允许它 (numpy.dot) 通过降低性能在 tensors 上工作,而 tensorflow 似乎根本不允许它。

因此,对于 NumPy,我们将使用 np.tensordot -

np.tensordot(X, Y, axes=((2,),(0,)))

对于tensorflow,它将是tf.tensordot -

tf.tensordot(X, Y, axes=((2,),(0,)))

Related post to understand tensordot .

关于python - 广播 np.dot vs tf.matmul 以进行张量矩阵乘法(形状必须为 2 阶但为 3 阶错误),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47969305/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com