gpt4 book ai didi

python - 高维度的 Einsum

转载 作者:行者123 更新时间:2023-12-01 01:32:52 24 4
gpt4 key购买 nike

考虑下面的 3 个数组:

np.random.seed(0)

X = np.random.randint(10, size=(4,5))
W = np.random.randint(10, size=(3,4))
y = np.random.randint(3, size=(5,1))

我想将矩阵 X 的每一列与 W 的行相加并求和,以 y 作为索引。因此,例如,如果 y 中的第一个元素是 3 ,我会将 X 的第一列添加到 W 的第四行(python 中的索引 3)并求和。我会一遍又一遍地这样做,直到 X 的所有列都添加到 W 的特定行并求和。我可以用不同的方式做到这一点:1-使用for循环:

for i,j in enumerate(y):
W[j]+=X[:,i]

2-使用add.at函数

np.add.at(W,(y.ravel()),X.T)

3-但我不明白如何使用 einsum 来做到这一点。我得到了一个解决方案,但真的无法理解。

N = y.max()+1
W[:N] += np.einsum('ijk,lk->il',(np.arange(N)[:,None,None] == y.ravel()),X)

谁能解释一下这个结构?1 - (np.arange(N)[:,无,无] == y.ravel(),X)。我想这部分是指根据 y 将 X 的列与 W 的特定行相加。但W在哪里?为什么在这种情况下我们必须将 W 变换为 4 维?2- 'ijk,lk->il' - 我也不明白这一点。

i - 指行,j - 列,k- 每个元素,l - “l”也指什么?如果有人能理解这一点并向我解释,我将非常感激。提前致谢。

最佳答案

让我们通过删除一个维度并使用易于手动验证的值来简化问题:

W = np.zeros(3, np.int)
y = np.array([0, 1, 1, 2, 2])
X = np.array([1, 2, 3, 4, 5])

向量中的值WX 获取附加值通过查找y :

for i, j in enumerate(y):
W[j] += X[i]

W计算为[1, 5, 9] ,(手动快速检查)。

现在,如何对这段代码进行矢量化?我们不能做一个简单的W[y] += X[y]y其中包含重复值,并且不同的总和将在索引 1 和 2 处相互覆盖。

可以做的是将这些值广播到新的维度 len(y)然后对这个新创建的维度进行总结。

N = W.shape[0]
select = (np.arange(N) == y[:, None]).astype(np.int)

取索引范围W ( [0, 1, 2] ),并设置它们匹配的值 y在新维度中为 1,否则为 0。select包含这个数组:

array([[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[0, 0, 1]])

它有len(y) == len(X)行和len(W)列并显示每个 y/行, W 的索引是什么它有助于。

让我们将 X 与这个数组相乘,mult = select * X[:, None] :

array([[1, 0, 0],
[0, 2, 0],
[0, 3, 0],
[0, 0, 4],
[0, 0, 5]])

我们已经有效地将 X 展开到一个新的维度中,并以某种方式对其进行排序,我们可以通过对新创建的维度求和将其形成 W 形状。各行的总和就是我们要添加到 W 的向量。 :

sum_Xy = np.sum(mult, axis=0)  # [1, 5, 9]
W += sum_Xy

select的计算和mult可以与 np.einsum 结合使用:

# `select` has shape (len(y)==len(X), len(W)), or `yw`
# `X` has shape len(X)==len(y), or `y`
# we want something `len(W)`, or `w`, and to reduce the other dimension
sum_Xy = np.einsum("yw,y->w", select, X)

这就是一维示例。对于问题中提出的二维问题,它是完全相同的方法:引入一个附加维度,广播 y索引,然后用 einsum 减少附加维度.

如果您了解一维示例的每个步骤如何工作,我相信您可以弄清楚代码如何在二维中执行此操作,因为这只是获得正确索引的问题(W 行, X 列)。

关于python - 高维度的 Einsum,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52672653/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com