gpt4 book ai didi

python - Cython、numpy 加速

转载 作者:行者123 更新时间:2023-11-30 23:05:41 25 4
gpt4 key购买 nike

我正在尝试编写一种算法来计算二维数组的某些相邻元素的平均值。

我想看看是否可以使用 Cython 来加速它,但这是我第一次自己使用它。

Python 版本:

import numpy as np

def clamp(val, minval, maxval):
return max(minval, min(val, maxval))


def filter(arr, r):
M = arr.shape[0]
N = arr.shape[1]

new_arr = np.zeros([M, N], dtype=np.int)

for x in range(M):
for y in range(N):
# Corner elements
p1 = clamp(x-r, 0, M)
p2 = clamp(y-r, 0, N)
p3 = clamp(y+r, 0, N-1)
p4 = clamp(x+r, 0, M-1)

nbr_elements = (p3-p2-1)*2+(p4-p1-1)*2+4

tmp = 0

# End points
tmp += arr[p1, p2]
tmp += arr[p1, p3]
tmp += arr[p4, p2]
tmp += arr[p4, p3]

# The rest
tmp += sum(arr[p1+1:p4, p2])
tmp += sum(arr[p1+1:p4, p3])
tmp += sum(arr[p1, p2+1:p3])
tmp += sum(arr[p4, p2+1:p3])

new_arr[x, y] = tmp/nbr_elements

return new_arr

以及我对 Cython 实现的尝试。我发现如果重新实现 max/min/sum 会比使用 python 版本更快

Cython 版本:

from __future__ import division
import numpy as np
cimport numpy as np

DTYPE = np.int
ctypedef np.int_t DTYPE_t

cdef inline int int_max(int a, int b): return a if a >= b else b
cdef inline int int_min(int a, int b): return a if a <= b else b

def clamp(int val, int minval, int maxval):
return int_max(minval, int_min(val, maxval))

def cython_sum(np.ndarray[DTYPE_t, ndim=1] y):
cdef int N = y.shape[0]
cdef int x = y[0]
cdef int i
for i in xrange(1, N):
x += y[i]
return x


def filter(np.ndarray[DTYPE_t, ndim=2] arr, int r):
cdef M = im.shape[0]
cdef N = im.shape[1]

cdef np.ndarray[DTYPE_t, ndim=2] new_arr = np.zeros([M, N], dtype=DTYPE)
cdef int p1, p2, p3, p4, nbr_elements, tmp

for x in range(M):
for y in range(N):
# Corner elements
p1 = clamp(x-r, 0, M)
p2 = clamp(y-r, 0, N)
p3 = clamp(y+r, 0, N-1)
p4 = clamp(x+r, 0, M-1)

nbr_elements = (p3-p2-1)*2+(p4-p1-1)*2+4

tmp = 0

# End points
tmp += arr[p1, p2]
tmp += arr[p1, p3]
tmp += arr[p4, p2]
tmp += arr[p4, p3]

# The rest
tmp += cython_sum(arr[p1+1:p4, p2])
tmp += cython_sum(arr[p1+1:p4, p3])
tmp += cython_sum(arr[p1, p2+1:p3])
tmp += cython_sum(arr[p4, p2+1:p3])

new_arr[x, y] = tmp/nbr_elements

return new_arr

我做了一个测试脚本:

import time
import numpy as np

import square_mean_py
import square_mean_cy

N = 500

arr = np.random.randint(15, size=(N, N))
r = 8

# Timing

t = time.time()
res_py = square_mean_py.filter(arr, r)
print time.time()-t

t = time.time()
res_cy = square_mean_cy.filter(arr, r)
print time.time()-t

哪个打印

9.61458301544
1.44476890564

这相当于大约的加速。 7次。我已经看到很多 Cython 实现可以产生更好的加速,所以我在想,也许你们中的一些人看到了加速算法的潜在方法?

最佳答案

您的 Cython 脚本存在一些问题:

  1. 您没有向 Cython 提供一些关键信息,例如范围中使用的 x、y、MN 的类型。
  2. 我已经 cdef 编辑了 cython_sumclamp 这两个函数,因为您在 Python 级别不需要它们。
  3. filter 函数中出现的 im 是什么?我假设您的意思是arr

修复这些问题我将重写/修改您的 Cython 脚本,如下所示:

from __future__ import division
import numpy as np
cimport numpy as np
from cython cimport boundscheck, wraparound

DTYPE = np.int
ctypedef np.int_t DTYPE_t

cdef inline int int_max(int a, int b): return a if a >= b else b
cdef inline int int_min(int a, int b): return a if a <= b else b

cdef int clamp3(int val, int minval, int maxval):
return int_max(minval, int_min(val, maxval))

@boundscheck(False)
cdef int cython_sum2(DTYPE_t[:] y):
cdef int N = y.shape[0]
cdef int x = y[0]
cdef int i
for i in range(1, N):
x += y[i]
return x

@boundscheck(False)
@wraparound(False)
def filter3(DTYPE_t[:,::1] arr, int r):
cdef int M = arr.shape[0]
cdef int N = arr.shape[1]

cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] \
new_arr = np.zeros([M, N], dtype=DTYPE)
cdef int p1, p2, p3, p4, nbr_elements, tmp, x, y

for x in range(M):
for y in range(N):
# Corner elements
p1 = clamp3(x-r, 0, M)
p2 = clamp3(y-r, 0, N)
p3 = clamp3(y+r, 0, N-1)
p4 = clamp3(x+r, 0, M-1)

nbr_elements = (p3-p2-1)*2+(p4-p1-1)*2+4

tmp = 0

# End points
tmp += arr[p1, p2]
tmp += arr[p1, p3]
tmp += arr[p4, p2]
tmp += arr[p4, p3]

# The rest
tmp += cython_sum2(arr[p1+1:p4, p2])
tmp += cython_sum2(arr[p1+1:p4, p3])
tmp += cython_sum2(arr[p1, p2+1:p3])
tmp += cython_sum2(arr[p4, p2+1:p3])

new_arr[x, y] = <int>(tmp/nbr_elements)

return new_arr

这是我机器上的计时:

arr = np.random.randint(15, size=(500, 500))

Original (Python) version: 7.34 s
Your Cython version: 1.98 s
New Cython version: 0.0323 s

这比 Cython 脚本的速度快了近 60 倍,比原始 Python 脚本的速度快了 200 多倍。

关于python - Cython、numpy 加速,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33059600/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com