gpt4 book ai didi

python - 从按列排序的方阵中获取下对角线索引

转载 作者:太空宇宙 更新时间:2023-11-03 14:45:41 25 4
gpt4 key购买 nike

我正在尝试在 matplotlib.pyplot.subplot 的下对角线中创建成对图.因此,我需要来自方阵下对角线的索引。由于情节的顺序,我需要按列对它们进行排序。例如,假设我有以下 4x4 矩阵:

[ 1,  2,  3,  4]
[ 5, 6, 8, 7]
[ 8, 9, 10, 11]
[12, 13, 14, 15]

我需要按以下顺序排列它们的索引:5、8、12、9、13、14。我怎样才能用几行代码完成它?我将分享我的解决方案,但我觉得我可以以更优雅的方式完成它。

我的解决方案

>>> import numpy as np
>>> n = 4 # Matrix order
>>> a = np.arange(1,n*n+1).reshape(n,n)
>>> a
array([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])
>>> index = np.triu_indices(n, 1)
>>> a.T[index]
array([ 5, 9, 13, 10, 14, 15])

上下文

接下来我要做的是:

>>> subplot_idx = a.T[index]
>>> for idx in subplot_idx:
... plt.subplot(n, n, idx)
... # plot something

最佳答案

一种更便宜的方法是避免索引创建部分并为 boolean-indexing 使用掩码.现在,由于它在 NumPy 中以行优先顺序排列并且我们需要较低的 diag 元素,因此我们需要在输入数组的转置版本上使用 upper diag 掩码(将 upper diag 元素设置为 True,其余为 False 的掩码)。我们将使用 broadcasting使用 ranged array outer comparison 高效地创建上层诊断掩码并索引到转置数组中。因此,对于输入数组 a 它将是 -

r = np.arange(len(a))
out = a.T[r[:,None] < r]

假设我们将使用小于 65536 x 65536 大小的数组的矩阵,我们可以为 r 使用较低的精度,从而实现显着的性能提升 -

r = np.arange(len(a), dtype=np.uint16)

相同的想法并使用 NumPy 内置 np.tri创建一个较低的诊断掩码并因此有一个优雅单行方式(如所要求的)将是 -

a.T[~np.tri(len(a), dtype=bool)]

sample 运行-

In [116]: a
Out[116]:
array([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])

In [117]: a.T[~np.tri(len(a), dtype=bool)]
Out[117]: array([ 5, 9, 13, 10, 14, 15])

基准测试

方法-

# Original soln
def extract_lower_diag_org(a):
n = len(a)
index = np.triu_indices(n, 1)
return a.T[index]

# Proposed soln
def extract_lower_diag_mask(a):
r = np.arange(len(a), dtype=np.uint16)
return a.T[r[:,None] < r]

更大阵列上的计时 -

In [142]: a = np.random.rand(5000,5000)

In [143]: %timeit extract_lower_diag_org(a)
1 loop, best of 3: 216 ms per loop

In [144]: %timeit extract_lower_diag_mask(a)
10 loops, best of 3: 50.2 ms per loop

In [145]: %timeit a.T[~np.tri(len(a), dtype=bool)]
10 loops, best of 3: 52.1 ms per loop

使用建议的基于掩码的方法,这些大型阵列的速度提高了 4x+

关于python - 从按列排序的方阵中获取下对角线索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49598488/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com