gpt4 book ai didi

Python 如何提高 numpy 数组的性能?

转载 作者:太空宇宙 更新时间:2023-11-04 00:32:05 24 4
gpt4 key购买 nike

我有一个全局 numpy.array 数据,它是一个 200*200*3 3d 数组,包含 3d 空间中的 40000 个点。

我的目标是计算每个点到单位立方体四个角的距离 ((0, 0, 0),(1, 0, 0),(0, 1, 0),(0, 0 , 1)), 所以我可以确定哪个角离它最近。

def dist(*point):
return np.linalg.norm(data - np.array(rgb), axis=2)

buffer = np.stack([dist(0, 0, 0), dist(1, 0, 0), dist(0, 1, 0), dist(0, 0, 1)]).argmin(axis=0)

我编写了这段代码并对其进行了测试,每次运行大约需要 10 毫秒。我的问题是如何提高这段代码的性能,最好在不到 1 毫秒内运行。

最佳答案

你可以使用 Scipy cdist -

# unit cube coordinates as array
uc = np.array([[0, 0, 0],[1, 0, 0], [0, 1, 0], [0, 0, 1]])

# buffer output
buf = cdist(data.reshape(-1,3), uc).argmin(1).reshape(data.shape[0],-1)

运行时测试

# Original approach
def org_app():
return np.stack([dist(0, 0, 0), dist(1, 0, 0), \
dist(0, 1, 0), dist(0, 0, 1)]).argmin(axis=0)

时间 -

In [170]: data = np.random.rand(200,200,3)

In [171]: %timeit org_app()
100 loops, best of 3: 4.24 ms per loop

In [172]: %timeit cdist(data.reshape(-1,3), uc).argmin(1).reshape(data.shape[0],-1)
1000 loops, best of 3: 1.25 ms per loop

关于Python 如何提高 numpy 数组的性能?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45439382/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com