gpt4 book ai didi

python - 如何沿批量维度广播 numpy 索引?

转载 作者:行者123 更新时间:2023-12-02 05:22:54 24 4
gpt4 key购买 nike

例如,np.array([[1,2],[3,4]])[np.triu_indices(2)] 的形状为 (3,),是上三角条目的扁平列表。但是,如果我有一批 2x2 矩阵:

foo = np.repeat(np.array([[[1,2],[3,4]]]), 30, axis=0)

并且我想获得每个矩阵的上三角索引,尝试的天真的做法是:

foo[:,np.triu_indices(2)]

但是,这个对象实际上具有 (30,2,3,2) 形状(与我们预期的 (30,3) 相反,如果我们有批量提取上三角条目。

我们如何沿着批量维度广播元组索引?

最佳答案

获取元组并使用它们来索引最后两个暗淡 -

r,c = np.triu_indices(2)
out = foo[:,r,c]

或者,带有 Ellipsis 的单行代码适用于 3D2D 数组 -

foo[(Ellipsis,)+np.triu_indices(2)]

它同样适用于 2D 数组 -

out = foo[r,c] # foo as 2D input array
<小时/>

遮蔽方式

3D阵列案例

我们还可以使用掩码进行基于掩码的方式 -

foo[:,~np.tri(2,k=-1, dtype=bool)]

二维数组案例

foo[~np.tri(2,k=-1, dtype=bool)]

关于python - 如何沿批量维度广播 numpy 索引?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58100302/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com