gpt4 book ai didi

python - 通过 numpy 坐标数组索引 numpy 数组

转载 作者:行者123 更新时间:2023-11-28 18:17:32 25 4
gpt4 key购买 nike

假设我们有

  • 一个n维的numpy.array A
  • 一个 dtype=int 且形状为 (n, m) 的 numpy.array B

我如何通过 B 对 A 进行索引,以便结果是形状为 (m,) 的数组,其值取自 B 的列指示的位置?

例如,当 B 是一个 python 列表时,请考虑执行我想要的操作的这段代码:

>>> a = np.arange(27).reshape(3,3,3)
>>> a[[0, 1, 2], [0, 0, 0], [1, 1, 2]]
array([ 1, 10, 20]) # the result we're after
>>> bl = [[0, 1, 2], [0, 0, 0], [1, 1, 2]]
>>> a[bl]
array([ 1, 10, 20]) # also works when indexing with a python list
>>> a[bl].shape
(3,)

但是,当B是一个numpy数组时,结果就不同了:

>>> b = np.array(bl)
>>> a[b].shape
(3, 3, 3, 3)

现在,我可以通过将 B 转换为元组来获得所需的结果,但这肯定不是正确/惯用的方法吗?

>>> a[tuple(b)]
array([ 1, 10, 20])

是否有一个 numpy 函数可以在不将 B 转换为元组的情况下实现相同的功能?

最佳答案

一种替代方法是转换为线性索引,然后使用 np.take 索引或索引到其扁平化版本中 -

np.take(a,np.ravel_multi_index(b, a.shape))
a.flat[np.ravel_multi_index(b, a.shape)]

自定义 np.ravel_multi_index 以提升性能

我们可以实现一个自定义版本来模拟 np.ravel_multi_index 的行为来提升性能,就像这样 -

def ravel_index(b, shp):
return np.concatenate((np.asarray(shp[1:])[::-1].cumprod()[::-1],[1])).dot(b)

使用它,可以通过两种方式找到所需的输出 -

np.take(a,ravel_index(b, a.shape))
a.flat[ravel_index(b, a.shape)]

基准测试

另外还结合了问题中基于 tuple 的方法和@Kanak 帖子中基于 map 的方法。

案例 #1:dims = 3

In [23]: a = np.random.randint(0,9,([20]*3))

In [24]: b = np.random.randint(0,20,(a.ndim,1000000))

In [25]: %timeit a[tuple(b)]
...: %timeit a[map(np.ravel, b)]
...: %timeit np.take(a,np.ravel_multi_index(b, a.shape))
...: %timeit a.flat[np.ravel_multi_index(b, a.shape)]
...: %timeit np.take(a,ravel_index(b, a.shape))
...: %timeit a.flat[ravel_index(b, a.shape)]
100 loops, best of 3: 6.56 ms per loop
100 loops, best of 3: 6.58 ms per loop
100 loops, best of 3: 6.95 ms per loop
100 loops, best of 3: 9.17 ms per loop
100 loops, best of 3: 6.31 ms per loop
100 loops, best of 3: 8.52 ms per loop

案例 #2:dims = 6

In [29]: a = np.random.randint(0,9,([10]*6))

In [30]: b = np.random.randint(0,10,(a.ndim,1000000))

In [31]: %timeit a[tuple(b)]
...: %timeit a[map(np.ravel, b)]
...: %timeit np.take(a,np.ravel_multi_index(b, a.shape))
...: %timeit a.flat[np.ravel_multi_index(b, a.shape)]
...: %timeit np.take(a,ravel_index(b, a.shape))
...: %timeit a.flat[ravel_index(b, a.shape)]
10 loops, best of 3: 40.9 ms per loop
10 loops, best of 3: 40 ms per loop
10 loops, best of 3: 20 ms per loop
10 loops, best of 3: 29.9 ms per loop
100 loops, best of 3: 15.7 ms per loop
10 loops, best of 3: 25.8 ms per loop

案例 #3:dims = 10

In [32]: a = np.random.randint(0,9,([4]*10))

In [33]: b = np.random.randint(0,4,(a.ndim,1000000))

In [34]: %timeit a[tuple(b)]
...: %timeit a[map(np.ravel, b)]
...: %timeit np.take(a,np.ravel_multi_index(b, a.shape))
...: %timeit a.flat[np.ravel_multi_index(b, a.shape)]
...: %timeit np.take(a,ravel_index(b, a.shape))
...: %timeit a.flat[ravel_index(b, a.shape)]
10 loops, best of 3: 60.7 ms per loop
10 loops, best of 3: 60.1 ms per loop
10 loops, best of 3: 27.8 ms per loop
10 loops, best of 3: 38 ms per loop
100 loops, best of 3: 18.7 ms per loop
10 loops, best of 3: 29.3 ms per loop

因此,在处理高维输入和大数据时寻找替代方案是有意义的。

关于python - 通过 numpy 坐标数组索引 numpy 数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47370718/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com