gpt4 book ai didi

python - 遍历 numpy 数组有效地测试多个元素

转载 作者:太空宇宙 更新时间:2023-11-03 12:53:04 27 4
gpt4 key购买 nike

我有以下代码迭代名为“m”的 2d numpy 数组。它的工作速度非常慢。如何使用 numpy 函数转换此代码以避免使用 for 循环?

pairs = []
for i in range(size):
for j in range(size):
if(i >= j):
continue
if(m[i][j] + m[j][i] >= 0.75):
pairs.append([i, j, m[i][j] + m[j][i]])

最佳答案

您可以通过 NumPy 使用矢量化方法。想法是:

  • 首先初始化一个矩阵m然后创建 m+m.T相当于m[i][j] + m[j][i]其中 m.T是矩阵转置并称它为summ
  • np.triu (summ)返回矩阵的上三角部分(这相当于在代码中使用 continue 忽略下三角部分)。这避免了明确的 if(i >= j):在你的代码中。在这里你必须使用k=1排除对角元素。默认情况下,k=0其中也包括对角线元素。
  • 然后使用 np.argwhere 获得点的索引其中总和 m+m.T大于等于0.75
  • 然后将这些索引和相应的值存储在一个列表中,以供以后处理/打印。

可验证示例(使用小型 3x3 随机数据集)

import numpy as np

np.random.seed(0)
m = np.random.rand(3,3)
summ = m + m.T

index = np.argwhere(np.triu(summ, k=1)>=0.75)

pairs = [(x,y, summ[x,y]) for x,y in index]
print (pairs)
# # [(0, 1, 1.2600725493693163), (0, 2, 1.0403505873343364), (1, 2, 1.537667113848736)]

进一步的性能提升

我刚刚找到了一种更快的方法来生成最终的 pairs列出避免显式 for 循环的

pairs = list(zip(index[:, 0], index[:, 1], summ[index[:,0], index[:,1]]))

关于python - 遍历 numpy 数组有效地测试多个元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54431792/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com