gpt4 book ai didi

python - 在 numpy 中查找对角线总和(更快)

转载 作者:太空狗 更新时间:2023-10-29 21:40:39 25 4
gpt4 key购买 nike

我有一些像这样的 board numpy 数组:

array([[0, 0, 0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 1, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 1, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 1, 0],
[1, 0, 0, 0, 0, 1, 0, 0]])

我使用以下代码求出棋盘从 -7 到 8 的每条第 n 条对角线上的元素总和(及其镜像版本)。

n = 8
rate = [b.diagonal(i).sum()
for b in (board, board[::-1])
for i in range(-n+1, n)]

经过一些分析后,这个操作占用了大约 2/3 的整体运行时间,这似乎是因为 2 个因素:

  • .diagonal 方法构建一个新数组而不是 View (看起来 numpy 1.7 将有一个新的 .diag 方法来解决这个问题)
  • 迭代是在列表推导中用 python 完成的

那么,有什么方法可以更快地找到这些和(可能在 numpy 的 C 层中)?


经过更多测试,我可以通过缓存此操作将总时间减少 7.5 倍...也许我在寻找错误的瓶颈?


还有一点:

刚刚找到了替换 diagonal(i).sum() 东西的 .trace 方法并且......性能没有太大改进(大约 2到 4%)。

所以问题应该是理解力。有什么想法吗?

最佳答案

有一个可能的解决方案使用 stride_tricks。这部分基于 this question 的答案中提供的大量信息。 ,但我认为问题已经足够不同,不能算作重复。这是应用于方阵的基本思想——请参阅下面的实现更通用解决方案的函数。

>>> cols = 8
>>> a = numpy.arange(cols * cols).reshape((cols, cols))
>>> fill = numpy.zeros((cols - 1) * cols, dtype='i8').reshape((cols - 1, cols))
>>> stacked = numpy.vstack((a, fill, a))
>>> major_stride, minor_stride = stacked.strides
>>> strides = major_stride, minor_stride * (cols + 1)
>>> shape = (cols * 2 - 1, cols)
>>> numpy.lib.stride_tricks.as_strided(stacked, shape, strides)
array([[ 0, 9, 18, 27, 36, 45, 54, 63],
[ 8, 17, 26, 35, 44, 53, 62, 0],
[16, 25, 34, 43, 52, 61, 0, 0],
[24, 33, 42, 51, 60, 0, 0, 0],
[32, 41, 50, 59, 0, 0, 0, 0],
[40, 49, 58, 0, 0, 0, 0, 0],
[48, 57, 0, 0, 0, 0, 0, 0],
[56, 0, 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 7],
[ 0, 0, 0, 0, 0, 0, 6, 15],
[ 0, 0, 0, 0, 0, 5, 14, 23],
[ 0, 0, 0, 0, 4, 13, 22, 31],
[ 0, 0, 0, 3, 12, 21, 30, 39],
[ 0, 0, 2, 11, 20, 29, 38, 47],
[ 0, 1, 10, 19, 28, 37, 46, 55]])
>>> diags = numpy.lib.stride_tricks.as_strided(stacked, shape, strides)
>>> diags.sum(axis=1)
array([252, 245, 231, 210, 182, 147, 105, 56, 7, 21, 42, 70, 105,
147, 196])

当然,我不知道这实际上有多快。但我敢打赌它会比 Python 列表理解更快。

为方便起见,这里有一个完全通用的 diagonals 函数。它假定您要沿最长轴移动对角线:

def diagonals(a):
rows, cols = a.shape
if cols > rows:
a = a.T
rows, cols = a.shape
fill = numpy.zeros(((cols - 1), cols), dtype=a.dtype)
stacked = numpy.vstack((a, fill, a))
major_stride, minor_stride = stacked.strides
strides = major_stride, minor_stride * (cols + 1)
shape = (rows + cols - 1, cols)
return numpy.lib.stride_tricks.as_strided(stacked, shape, strides)

关于python - 在 numpy 中查找对角线总和(更快),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/10792897/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com