gpt4 book ai didi

python - 使用 Cython 包装 LAPACKE 函数

转载 作者:太空狗 更新时间:2023-10-30 01:15:14 25 4
gpt4 key购买 nike

我正在尝试包装 LAPACK 函数 dgtsv (三对角方程组的求解器)使用 Cython。

我遇到了 this previous answer ,但由于 dgtsv 不是包装在 scipy.linalg 中的 LAPACK 函数之一,我认为我无法使用这种特殊方法。相反,我一直在尝试关注 this example .

这是我的 lapacke.pxd 文件的内容:

ctypedef int lapack_int

cdef extern from "lapacke.h" nogil:

int LAPACK_ROW_MAJOR
int LAPACK_COL_MAJOR

lapack_int LAPACKE_dgtsv(int matrix_order,
lapack_int n,
lapack_int nrhs,
double * dl,
double * d,
double * du,
double * b,
lapack_int ldb)

...这是我在 _solvers.pyx 中的薄 Cython 包装器:

#!python

cimport cython
from lapacke cimport *

cpdef TDMA_lapacke(double[::1] DL, double[::1] D, double[::1] DU,
double[:, ::1] B):

cdef:
lapack_int n = D.shape[0]
lapack_int nrhs = B.shape[1]
lapack_int ldb = B.shape[0]
double * dl = &DL[0]
double * d = &D[0]
double * du = &DU[0]
double * b = &B[0, 0]
lapack_int info

info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR, n, nrhs, dl, d, du, b, ldb)

return info

...这是一个 Python 包装器和测试脚本:

import numpy as np
from scipy import sparse
from cymodules import _solvers


def trisolve_lapacke(dl, d, du, b, inplace=False):

if (dl.shape[0] != du.shape[0] or dl.shape[0] != d.shape[0] - 1
or b.shape != d.shape):
raise ValueError('Invalid diagonal shapes')

if b.ndim == 1:
# b is (LDB, NRHS)
b = b[:, None]

# be sure to force a copy of d and b if we're not solving in place
if not inplace:
d = d.copy()
b = b.copy()

# this may also force copies if arrays are improperly typed/noncontiguous
dl, d, du, b = (np.ascontiguousarray(v, dtype=np.float64)
for v in (dl, d, du, b))

# b will now be modified in place to contain the solution
info = _solvers.TDMA_lapacke(dl, d, du, b)
print info

return b.ravel()


def test_trisolve(n=20000):

dl = np.random.randn(n - 1)
d = np.random.randn(n)
du = np.random.randn(n - 1)

M = sparse.diags((dl, d, du), (-1, 0, 1), format='csc')
x = np.random.randn(n)
b = M.dot(x)

x_hat = trisolve_lapacke(dl, d, du, b)

print "||x - x_hat|| = ", np.linalg.norm(x - x_hat)

不幸的是,test_trisolve 只是在调用 _solvers.TDMA_lapacke 时出现段错误。我很确定我的 setup.py 是正确的 - ldd _solvers.so 显示 _solvers.so 被链接到正确的共享库在运行时。

我不太确定如何从这里开始 - 有什么想法吗?


简要更新:

对于较小的 n 值,我往往不会立即得到段错误,但我确实得到了无意义的结果(||x - x_hat|| 应该非常接近0):

In [28]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 6.23202576396

In [29]: test_trisolve2.test_trisolve(10)
-7
||x - x_hat|| = 3.88623414288

In [30]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 2.60190676562

In [31]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 3.86631743386

In [32]: test_trisolve2.test_trisolve(10)
Segmentation fault

通常 LAPACKE_dgtsv 返回代码 0(这应该表示成功),但偶尔我会得到 -7,这意味着参数 7 ( b) 具有非法值。发生的情况是只有 b 的第一个值实际上被修改了。如果我继续调用 test_trisolve,即使 n 很小,我最终也会遇到段错误。

最佳答案

好吧,我最终弄明白了 - 看来我误解了在这种情况下行和列主要指的是什么。

由于 C 连续数组遵循行优先顺序,我假设我应该将 LAPACK_ROW_MAJOR 指定为 LAPACKE_dgtsv 的第一个参数。

其实如果我改变

info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR, ...)

info = LAPACKE_dgtsv(LAPACK_COL_MAJOR, ...)

那么我的函数就可以工作了:

test_trisolve2.test_trisolve()
0
||x - x_hat|| = 6.67064747632e-12

这对我来说似乎很违反直觉 - 谁能解释为什么会这样?

关于python - 使用 Cython 包装 LAPACKE 函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/23200085/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com