gpt4 book ai didi

python - 定义 numpy 索引数组

转载 作者:太空宇宙 更新时间:2023-11-04 11:08:23 25 4
gpt4 key购买 nike

我对 numpy 索引有点困惑。假设我有一个三维数组,例如:

test_arr = np.arange(3*2*3).reshape(3,2,3)
test_arr
array([[[ 0, 1, 2],
[ 3, 4, 5]],

[[ 6, 7, 8],
[ 9, 10, 11]],

[[12, 13, 14],
[15, 16, 17]]])

我想通过一个沿维度 1 的 bool 数组对其进行索引:

dim1_idx = np.array([True, False])
test_arr[:, dim1_idx, :]

这给了我

array([[[ 0,  1,  2]],

[[ 6, 7, 8]],

[[12, 13, 14]]])

到目前为止一切都很好。

我的问题是,有没有一种方法可以让我提前定义这个 bool 索引数组——比如(这行不通):

all_dim_idx = dim1_idx[np.newaxis, :, np.newaxis]
test_arr[all_dim_idx]

我意识到这不是因为它不能以一种方式广播使 all_dim_idx 数组适合 test_arr。我可以使用 np.tile 或 np.reshape 使索引数组适合更大的数组,但是(并且不能再推广到其他数组形状)我只是觉得可能有更好的方法。谁能赐教一下?

提前致谢!

最佳答案

In [600]: test_arr = np.arange(3*2*3).reshape(3,2,3)                            
In [601]: dim1_idx = np.array([True, False])

定义一个索引元组:

In [602]: idx = (slice(None), dim1_idx, slice(None))                            
In [603]: test_arr[idx]
Out[603]:
array([[[ 0, 1, 2]],

[[ 6, 7, 8]],

[[12, 13, 14]]])

关于python - 定义 numpy 索引数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58868215/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com