gpt4 book ai didi

python - 从 Cython 中的列表调用方法

转载 作者:行者123 更新时间:2023-12-01 02:27:22 27 4
gpt4 key购买 nike

我想并行执行 3 个函数,在 Cython 的 prange 循环中采用相同的输入。他们在相同的变量 TVdu 上添加值,并采用相同的变量。代码的目的是计算四个主要方向上像素的梯度,然后计算逐像素的总变化。

为此,我创建一个包含方法名称的列表并迭代该列表。我有这个代码:

cdef void TV_norm(float[:, :] ux, float[:, :] uy, float[:, :] output, float epsilon, float p) nogil:
cdef int M = ux.shape[0]
cdef int N = ux.shape[1]
cdef int i, j
cdef float inv_p = 1./p
cdef float eps = epsilon**p

with parallel(num_threads=64):
for i in prange(M, schedule="guided"):
for j in range(N):
output[i, j] += (abs(ux[i, j])** p + abs(uy[i, j])** p + eps) **inv_p

cdef void center_diff(float[:, :] u, float[:, :] TV, float[:, :] du, int di, int dj, float epsilon, float p):
ux = np.roll(u, (di, 0)) - u
uy = np.roll(u, (0, dj)) - u
TV_norm(ux, uy, TV, epsilon, p)
du -= ux + uy


cdef void i_diff(float[:, :] u, float[:, :] TV, float[:, :] du, int di, int dj, float epsilon, float p):
ux = u - np.roll(u, (-di, 0))
uy = np.roll(u, (-di, dj)) - np.roll(u, (-di, 0))
TV_norm(ux, uy, TV, epsilon, p)
du += ux


cdef void j_diff(float[:, :] u, float[:, :] TV, float[:, :] du, int di, int dj, float epsilon, float p):
ux = np.roll(u, (di, -dj)) - np.roll(u, (0, -dj))
uy = u - np.roll(u, (0, -dj))
TV_norm(ux, uy, TV, epsilon, p)
du += uy


cdef list divTV_dual(float[:, :] u, float epsilon=0, float p=1):
cdef np.ndarray[DTYPE_t, ndim=2] TV = np.zeros_like(u)
cdef np.ndarray[DTYPE_t, ndim=2] du = TV.copy()
cdef list shifts = [[1, 1],[-1, 1],[1,-1],[-1, -1]]
cdef list methods = [center_diff, i_diff, j_diff]

with nogil, parallel(num_threads=4):
for i in prange(4, schedule="static"):
with gil:
di = shifts[i][0]
dj = shifts[i][1]
for j in range(3):
methods[j](u, TV, du, di, dj, epsilon, p)

return [du, TV]

虽然它在纯 Python 中工作,但 Cython 在编译时失败:

/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py in run_cell_magic(self, magic_name, line, cell)
2129 magic_arg_s = self.var_expand(line, stack_depth)
2130 with self.builtin_trap:
-> 2131 result = fn(magic_arg_s, cell)
2132 return result
2133

<decorator-gen-127> in cython(self, line, cell)

/usr/local/lib/python3.5/dist-packages/IPython/core/magic.py in <lambda>(f, *a, **k)
185 # but it's overkill for just that one bit of state.
186 def magic_deco(arg):
--> 187 call = lambda f, *a, **k: f(*a, **k)
188
189 if callable(arg):

/usr/local/lib/python3.5/dist-packages/Cython/Build/IpythonMagic.py in cython(self, line, cell)
289 build_extension.build_temp = os.path.dirname(pyx_file)
290 build_extension.build_lib = lib_dir
--> 291 build_extension.run()
292 self._code_cache[key] = module_name
293

/usr/lib/python3.5/distutils/command/build_ext.py in run(self)
336
337 # Now actually compile and link everything.
--> 338 self.build_extensions()
339
340 def check_extensions_list(self, extensions):

/usr/lib/python3.5/distutils/command/build_ext.py in build_extensions(self)
445 self._build_extensions_parallel()
446 else:
--> 447 self._build_extensions_serial()
448
449 def _build_extensions_parallel(self):

/usr/lib/python3.5/distutils/command/build_ext.py in _build_extensions_serial(self)
470 for ext in self.extensions:
471 with self._filter_build_errors(ext):
--> 472 self.build_extension(ext)
473
474 @contextlib.contextmanager

/usr/lib/python3.5/distutils/command/build_ext.py in build_extension(self, ext)
530 debug=self.debug,
531 extra_postargs=extra_args,
--> 532 depends=ext.depends)
533
534 # XXX outdated variable, kept here in case third-part code

/usr/lib/python3.5/distutils/ccompiler.py in compile(self, sources, output_dir, macros, include_dirs, debug, extra_preargs, extra_postargs, depends)
572 except KeyError:
573 continue
--> 574 self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)
575
576 # Return *all* object filenames, not just the ones we just built.

/usr/lib/python3.5/distutils/unixccompiler.py in _compile(self, obj, src, ext, cc_args, extra_postargs, pp_opts)
118 extra_postargs)
119 except DistutilsExecError as msg:
--> 120 raise CompileError(msg)
121
122 def create_static_lib(self, objects, output_libname,

CompileError: command 'x86_64-linux-gnu-gcc' failed with exit status 1

有什么办法可以做到这一点吗?

<小时/>

编辑:

这个概念验证有效:

%%cython --compile-args=-O3 --compile-args=-ffast-math --compile-args=-fopenmp --link-args=-fopenmp

# cython: boundscheck=False
# cython: cdivision=True
# cython: wraparound=False
# cython: profile=True

cimport cython
from cython.parallel cimport parallel, prange

cdef foo(a):
print(a)

cdef bar(a):
print(a)

methods = [foo, bar]
cdef int i

with nogil, parallel():
for i in prange(2):
with gil:
methods[i]("a")

最佳答案

找到了...列表中调用的方法应该使用 cpdef 定义,而不是 cdef

我猜这是因为这些函数使用 numpy 类型和方法,所以它们需要暴露在 python 中。

关于python - 从 Cython 中的列表调用方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47245057/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com