gpt4 book ai didi

numpy - `out` 中的 `numpy.einsum` 参数无法按预期工作

转载 作者:行者123 更新时间:2023-12-04 18:01:25 24 4
gpt4 key购买 nike

我有两段代码。第一个是:

A = np.arange(3*4*3).reshape(3, 4, 3)
P = np.arange(1, 4)
A[:, 1:, :] = np.einsum('j, ijk->ijk', P, A[:, 1:, :])

结果 A是 :
array([[[  0,   1,   2],
[ 6, 8, 10],
[ 18, 21, 24],
[ 36, 40, 44]],

[[ 12, 13, 14],
[ 30, 32, 34],
[ 54, 57, 60],
[ 84, 88, 92]],

[[ 24, 25, 26],
[ 54, 56, 58],
[ 90, 93, 96],
[132, 136, 140]]])

第二个是:
A = np.arange(3*4*3).reshape(3, 4, 3)
P = np.arange(1, 4)
np.einsum('j, ijk->ijk', P, A[:, 1:, :], out=A[:,1:,:])

结果 A是 :
array([[[ 0,  1,  2],
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],

[[12, 13, 14],
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],

[[24, 25, 26],
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]]])

所以结果是不同的。这里我想用 out以节省内存。是否是 numpy.einsum 中的错误? ?还是我错过了什么?

顺便说一句,我的 numpy版本是 1.13.3。

最佳答案

我没用过这个新out之前的参数,但曾使用过 einsum过去,并对它的工作原理有一个大致的了解(或至少习惯了)。

在我看来它初始化了 out在迭代开始前将数组归零。这将解释 A[:,1:,:] 中的所有 0堵塞。如果相反我最初分开 out数组,插入所需的值

In [471]: B = np.ones((3,4,3),int)
In [472]: np.einsum('j, ijk->ijk', P, A[:, 1:, :], out=B[:,1:,:])
Out[472]:
array([[[ 3, 4, 5],
[ 12, 14, 16],
[ 27, 30, 33]],

[[ 15, 16, 17],
[ 36, 38, 40],
[ 63, 66, 69]],

[[ 27, 28, 29],
[ 60, 62, 64],
[ 99, 102, 105]]])
In [473]: B
Out[473]:
array([[[ 1, 1, 1],
[ 3, 4, 5],
[ 12, 14, 16],
[ 27, 30, 33]],

[[ 1, 1, 1],
[ 15, 16, 17],
[ 36, 38, 40],
[ 63, 66, 69]],

[[ 1, 1, 1],
[ 27, 28, 29],
[ 60, 62, 64],
[ 99, 102, 105]]])
einsum 的 Python 部分没有告诉我太多,除了它决定如何通过 out数组到 c部分,(作为 tmp_operands 的列表之一):

c_einsum(einsum_str, *tmp_operands, **einsum_kwargs)

我知道它设置了一个 c-api相当于 np.nditer ,使用 str定义轴和迭代。

它在迭代教程中迭代了类似本节的内容:

https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.nditer.html#reduction-iteration

特别注意 it.reset()步。这设置了 out在迭代之前缓冲到 0。然后迭代输入数组和输出数组的元素,将计算值写入输出元素。由于它正在做产品的总和(例如 out[:] += ... ),它必须从一个干净的石板开始。

我有点猜测实际发生了什么,但对我来说应该将输出缓冲区清零开始是合乎逻辑的。如果该数组与输入之一相同,则最终会干扰计算。

所以我认为这种方法不会奏效并节省您的内存。它需要一个干净的缓冲区来累积结果。一旦完成,或者您可以将值写回 A .但鉴于 dot 的性质像产品一样,您不能将相同的数组用于输入和输出。
In [476]: A[:,1:,:] = np.einsum('j, ijk->ijk', P, A[:, 1:, :])
In [477]: A
Out[477]:
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 12, 14, 16],
[ 27, 30, 33]],
....)

关于numpy - `out` 中的 `numpy.einsum` 参数无法按预期工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47492316/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com