gpt4 book ai didi

python - numpy Python 中 (n,n,M) 和 (n,n) 矩阵的矩阵乘法

转载 作者:太空宇宙 更新时间:2023-11-04 07:55:31 24 4
gpt4 key购买 nike

我正在使用 Python。

A.shape=(n,n,M)B.shape=(n,n) 我想做以下事情:

AB = np.array_like(A)
for m in range(M):
AB[:,:,m]=A[:,:,m] @ B

然而,这段代码似乎并不是执行此操作的最有效方法?

最佳答案

我们可以使用np.tensordot -

np.tensordot(A,B,axes=(1,0)).swapaxes(1,2)

Related post to understand tensordot .

在引擎盖下,它执行 reshape ,通过置换对齐轴,然后将基于 BLAS 的矩阵乘法与 np.dot 结合使用。那些肮脏的工作看起来像这样——

A.swapaxes(1,2).reshape(-1,n).dot(B).reshape(n,-1,n).swapaxes(1,2)

B 开始,它会是这样的 -

B.T.dot(A.swapaxes(0,1).reshape(n,-1)).reshape(n,n,-1).swapaxes(0,1)

基准测试

设置 -

np.random.seed(0)
n,M = 50,50
A = np.random.rand(n,n,M)
B = np.random.rand(n,n)

时间 -

# @Psidom's soln-1
In [18]: %timeit np.einsum('ijk,jl->ilk', A, B)
100 loops, best of 3: 10.2 ms per loop

# @Psidom's soln-2
In [19]: %timeit (A.transpose(2,0,1) @ B).transpose(1,2,0)
100 loops, best of 3: 10.7 ms per loop

# @Psidom's einsum soln-1 with optimize set as True
In [20]: %timeit np.einsum('ijk,jl->ilk', A, B,optimize=True)
1000 loops, best of 3: 1.17 ms per loop

In [21]: %timeit np.tensordot(A,B,axes=(1,0)).swapaxes(1,2)
1000 loops, best of 3: 1.09 ms per loop

In [22]: %timeit A.swapaxes(1,2).reshape(-1,n).dot(B).reshape(n,-1,n).swapaxes(1,2)
1000 loops, best of 3: 1.03 ms per loop

In [23]: %timeit B.T.dot(A.swapaxes(0,1).reshape(n,-1)).reshape(n,n,-1).swapaxes(0,1)
1000 loops, best of 3: 951 µs per loop

关于python - numpy Python 中 (n,n,M) 和 (n,n) 矩阵的矩阵乘法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49479897/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com