gpt4 book ai didi

python - 将 2-d 矩阵的每一列乘以 3-d 矩阵的每个切片的更有效方法

转载 作者:太空宇宙 更新时间:2023-11-03 14:04:44 25 4
gpt4 key购买 nike

我有一个 8x8x25000 的数组 W 和一个 8 x 25000 的数组 r。我想将 W 的每个 8x8 切片乘以 r 的每一列 (8x1),并将结果保存在 Wres 中,最终将成为一个 8x25000 矩阵。

我正在使用这样的 for 循环来完成此操作:

for i in range(0,25000):
Wres[:,i] = np.matmul(W[:,:,i],res[:,i])

但这很慢,我希望有更快的方法来完成此任务。

有什么想法吗?

最佳答案

只要 2 个数组共享相同的 1 轴长度,Matmul 就可以传播。来自文档:

If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.

因此,您必须在 matmul 之前执行 2 个操作:

import numpy as np
a = np.random.rand(8,8,100)
b = np.random.rand(8, 100)
  1. 调换 ab 使第一个轴是 100 个切片
  2. b 添加一个额外的维度,以便 b.shape = (100, 8, 1)

然后:

 at = a.transpose(2, 0, 1) # swap to shape 100, 8, 8
bt = b.T[..., None] # swap to shape 100, 8, 1
c = np.matmul(at, bt)

c 现在是 100, 8, 1, reshape 回 8, 100:

 c = np.squeeze(c).swapaxes(0, 1)

 c = np.squeeze(c).T

最后,为了方便起见,单线:

c = np.squeeze(np.matmul(a.transpose(2, 0, 1), b.T[..., None])).T

关于python - 将 2-d 矩阵的每一列乘以 3-d 矩阵的每个切片的更有效方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44980238/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com