gpt4 book ai didi

python - Cython:如何声明 numpy.argwhere()

转载 作者:行者123 更新时间:2023-11-28 18:16:03 34 4
gpt4 key购买 nike

我尝试将我的部分代码 Cythonize 如下,以希望获得一些速度:

# cython: boundscheck=False
import numpy as np
cimport numpy as np
import time

cpdef object my_function(np.ndarray[np.double_t, ndim = 1] array_a,
np.ndarray[np.double_t, ndim = 1] array_b,
int n_rows,
int n_columns):
cdef double minimum_of_neighbours, difference, change
cdef int i
cdef np.ndarray[np.int_t, ndim =1] locations
locations = np.argwhere(array_a > 0)

for i in locations:
minimum_of_neighbours = min(array_a[i - n_columns], array_a[i+1], array_a[i + n_columns], array_a[i-1])
if array_a[i] - minimum_of_neighbours < 0:
difference = minimum_of_neighbours - array_a[i]
change = min(difference, array_a[i] / 5.)
array_a[i] += change
array_b[i] -= change * 5.
print time.time()

return array_a, array_b

我可以毫无错误地编译它,但是当我使用这个函数时,我得到了这个错误:

from cythonized_code import my_function
import numpy as np

array_a = np.random.uniform(low=-100, high=100, size = 100).astype(np.double)
array_b = np.random.uniform(low=0, high=20, size = 100).astype(np.double)

a, b = my_function(array_a,array_b,5,20)

# which gives me this error:
# locations = np.argwhere(array_a > 0)
# ValueError: Buffer has wrong number of dimensions (expected 1, got 2)

我需要在这里声明 locations 类型吗?我之所以要声明它,是因为它在编译代码生成的带注释的 HTML 文件中是黄色的。

最佳答案

最好不要使用 python 功能 locations[i],因为它太昂贵了:Python 会从低级 c 整数(这是存储在 locations-numpy 数组中的内容),将其注册到垃圾收集器中,然后将其转换回 int,销毁 Python 对象 - 相当大的开销.

要直接访问低级 c 整数,需要将 locations 绑定(bind)到一种类型。正常的做法是查找 locations 具有哪些属性:

>>> locations.ndim
2
>>> locations.dtype
dtype('int64')

转换为 cdef np.ndarray[np.int64_t, ndim =2] locations

但是,由于 Cython-quirk,这将(可能,现在无法检查)不足以摆脱 Python-overhead:

for i in locations:
...

不会被解释为原始数组访问,但会调用 Python 机制。参见示例 here .

因此您必须将其更改为:

for index in range(len(locations)):
i=locations[index][0]

然后 Cython“理解”,您想要访问原始 c-int64 数组。


  • 实际上,这并不完全正确:在这种情况下,首先创建一个 nd.array(例如 locations[0]locations[1]) 然后调用 __Pyx_PyInt_As_int(这或多或少是 [PyLong_AsLongAndOverflow][2] 的别名),它创建了一个 PyLongObject,在临时 PyLongObjectnd.array 被破坏之前,从中获取 C-int 值。

在这里我们很幸运,因为长度为 1 的 numpy 数组可以转换为 Python 标量。如果 locations 的第二个维度是 >1,则代码将不起作用。

关于python - Cython:如何声明 numpy.argwhere(),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48258260/

34 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com