gpt4 book ai didi

python - Numpy 减去两个 ndim 相等但形状不同的数组

转载 作者:行者123 更新时间:2023-12-01 01:08:47 27 4
gpt4 key购买 nike

所以我有两个 ndarray:

形状为 (N,a,a) 的 A,基本上是 N 个形状为 (a,a) 的数组的堆栈

B 形状为 (8,M,a,a),由 8 x M 形状 (a,a) 数组组成的矩阵

我需要从 A (A-B) 中减去 B,这样得到的数组的形状为 (8,M*N,a,a)。更详细地说,B 的 8 个数组中的每一个(总共 M 个)都需要从 A 中的每个数组中减去,从而导致 (a,a) 形状数组之间进行 8*M*N 次减法。

如何在没有循环的情况下以矢量化方式执行此操作?这个thread做了类似的事情,但在较低的维度,我不知道如何扩展它。

最佳答案

A = np.arange(8).reshape(2,2,2)
B = np.ones(shape=(8,4,2,2))

如果维度相同或一个维度为 1,则一般广播有效,因此我们这样做;

a = A[np.newaxis, :, np.newaxis, :, :]
b = B[:, np.newaxis, :, :, :]

a.shape # <- (1,2,1,2,2)
b.shape # <- (8,1,4,2,2)

现在你可以进行广播了

c = a - b
c.shape # <- (8,2,4,2,2)

当您 reshape (2x4=8) 组件时,它们就会对齐。

c.reshape(8,-1,2,2) 

新轴的顺序决定了 reshape ,所以要小心。

关于python - Numpy 减去两个 ndim 相等但形状不同的数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55070897/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com