gpt4 book ai didi

python - numpy.ndarray 枚举适当的维度子集?

转载 作者:太空狗 更新时间:2023-10-29 22:23:28 26 4
gpt4 key购买 nike

(在这篇文章中,让 np 成为 numpy 的简写。)

假设 a 是一个 (n + k) 维 np.ndarray 对象,对于一些整数n> 1 和 k> 1。(IOW,n + k> 3 是 a.ndim).我想枚举 a 的第一个 n 维度;这意味着,在每次迭代中,枚举器/迭代器都会生成一对,其第一个元素是 iin 索引,第二个元素是 ka[ii] 处的维度子ndarray

当然,编写一个函数来执行此操作并不难(事实上,我在下面给出了这样一个函数的示例),但我想知道这一点:

does numpy provide any special syntax or functions for carrying out this type of "partial" enumeration?

(通常,当我想遍历多维np.ndarray 对象时,我使用np.ndenumerate,但它在这里没有帮助,因为(如据我所知)np.ndenumerate 将遍历所有 n + k 个维度。)

假设上述问题的答案是肯定的,那么后续就是:

what about the case where the n dimensions to iterate over are not contiguous?

(在这种情况下,枚举器/迭代器在每次迭代中返回的对的第一个元素将是 r> n 元素的元组,其中一些将是表示“全部”的特殊值,例如 slice(None);这对的第二个元素仍将是长度为 kndarray >.)

谢谢!


下面的代码有望阐明问题规范。函数 partial_enumerate 使用可用于该目的的任何特殊 numpy 结构来完成我想做的事情。 partial_enumerate 的定义之后是 n = k = 2 的简单示例。

import numpy as np
import itertools as it
def partial_enumerate(nda, n):
"""Enumerate over the first N dimensions of the numpy.ndarray NDA.

Returns an iterator of pairs. The first element of each pair is a tuple
of N integers, corresponding to a partial index I into NDA; the second element
is the subarray of NDA at I.
"""

# ERROR CHECKING & HANDLING OMITTED
for ii in it.product(*[range(d) for d in nda.shape[:n]]):
yield ii, nda[ii]

a = np.zeros((2, 3, 4, 5))
for ii, vv in partial_enumerate(a, 2):
print ii, vv.shape

输出的每一行都是“一对元组”,其中第一个元组表示 a 中的部分 n 坐标集,第二个表示a 在这些部分坐标处的 k 维子数组的形状; (第二对的值对于所有行都是相同的,正如数组的规律性所预期的那样):

(0, 0) (4, 5)
(0, 1) (4, 5)
(0, 2) (4, 5)
(1, 0) (4, 5)
(1, 1) (4, 5)
(1, 2) (4, 5)

相比之下,在这种情况下迭代 np.ndenumerate(a) 将导致 a.size 迭代,每次访问 a.

最佳答案

您可以使用 numpy 广播规则生成笛卡尔积。 numpy.ix_ 函数创建适当数组的列表。它等效于以下内容:

>>> def pseudo_ix_gen(*arrays):
... base_shape = [1 for arr in arrays]
... for dim, arr in enumerate(arrays):
... shape = base_shape[:]
... shape[dim] = len(arr)
... yield numpy.array(arr).reshape(shape)
...
>>> def pseudo_ix_(*arrays):
... return list(pseudo_ix_gen(*arrays))

或者,更简洁:

>>> def pseudo_ix_(*arrays):
... shapes = numpy.diagflat([len(a) - 1 for a in arrays]) + 1
... return [numpy.array(a).reshape(s) for a, s in zip(arrays, shapes)]

结果是一个可广播数组列表:

>>> numpy.ix_(*[[2, 4], [1, 3], [0, 2]])
[array([[[2]],

[[4]]]), array([[[1],
[3]]]), array([[[0, 2]]])]

将此与 numpy.ogrid 的结果进行比较:

>>> numpy.ogrid[0:2, 0:2, 0:2]
[array([[[0]],

[[1]]]), array([[[0],
[1]]]), array([[[0, 1]]])]

如您所见,它是相同的,但是 numpy.ix_ 允许您使用非连续索引。现在,当我们应用 numpy 广播规则时,我们得到笛卡尔积:

>>> list(numpy.broadcast(*numpy.ix_(*[[2, 4], [1, 3], [0, 2]])))
[(2, 1, 0), (2, 1, 2), (2, 3, 0), (2, 3, 2),
(4, 1, 0), (4, 1, 2), (4, 3, 0), (4, 3, 2)]

如果我们不是将 numpy.ix_ 的结果传递给 numpy.broadcast,而是使用它来索引数组,我们会得到:

>>> a = numpy.arange(6 ** 4).reshape((6, 6, 6, 6))
>>> a[numpy.ix_(*[[2, 4], [1, 3], [0, 2]])]
array([[[[468, 469, 470, 471, 472, 473],
[480, 481, 482, 483, 484, 485]],

[[540, 541, 542, 543, 544, 545],
[552, 553, 554, 555, 556, 557]]],


[[[900, 901, 902, 903, 904, 905],
[912, 913, 914, 915, 916, 917]],

[[972, 973, 974, 975, 976, 977],
[984, 985, 986, 987, 988, 989]]]])

但是,买者自负。可广播数组对于索引很有用,但如果您真的想枚举值,最好使用 itertools.product:

>>> %timeit list(itertools.product(range(5), repeat=5))
10000 loops, best of 3: 196 us per loop
>>> %timeit list(numpy.broadcast(*numpy.ix_(*([range(5)] * 5))))
100 loops, best of 3: 2.74 ms per loop

因此,如果您无论如何都要合并一个 for 循环,那么 itertools.product 可能会更快。尽管如此,您仍然可以使用上述方法在纯 numpy 中获得一些类似的数据结构:

>> pgrid_idx = numpy.ix_(*[[2, 4], [1, 3], [0, 2]])
>>> sub_indices = numpy.rec.fromarrays(numpy.indices((6, 6, 6)))
>>> a[pgrid_idx].reshape((8, 6))
array([[468, 469, 470, 471, 472, 473],
[480, 481, 482, 483, 484, 485],
[540, 541, 542, 543, 544, 545],
[552, 553, 554, 555, 556, 557],
[900, 901, 902, 903, 904, 905],
[912, 913, 914, 915, 916, 917],
[972, 973, 974, 975, 976, 977],
[984, 985, 986, 987, 988, 989]])
>>> sub_indices[pgrid_idx].reshape((8,))
rec.array([(2, 1, 0), (2, 1, 2), (2, 3, 0), (2, 3, 2),
(4, 1, 0), (4, 1, 2), (4, 3, 0), (4, 3, 2)],
dtype=[('f0', '<i8'), ('f1', '<i8'), ('f2', '<i8')])

关于python - numpy.ndarray 枚举适当的维度子集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/9570050/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com