gpt4 book ai didi

python - python中的莱文森算法

转载 作者:塔克拉玛干 更新时间:2023-11-03 03:14:30 24 4
gpt4 key购买 nike

我需要确定许多不同的 Toeplitz 矩阵是否是奇异的。例如,我希望能够准确计算出有多少 12 x 12 0-1 Toeplitz 矩阵是奇异的。这是执行此操作的一些代码。

import itertools
from scipy.linalg import toeplitz
import scipy.linalg
import numpy as np

n = 12
singcount = 0
for longtuple in itertools.product([0,1], repeat = 2*n-1):
A = toeplitz(longtuple[0:n], longtuple[n-1:2*n-1])
if (np.isclose(scipy.linalg.det(A),0)):
singcount +=1
print singcount

但是 scipy.linalg.det 是一种非常低效的方法。原则Levinson Recursion速度更快,但我看不到如何实现它。谁能帮我入门(或者有更快更好的方法)?

最佳答案

我们需要加速 toeplitzdet 调用:

  • 2**k 批量大小工作
  • 首先创建一个toeplitz索引
  • 在 NumPy 1.8 中,det 是一个通用的 ufunc,可以在一次调用中计算出 may det。

代码:

import itertools
import numpy as np
from scipy.linalg import toeplitz, det

原代码如下:

%%time
n = 12
todo = itertools.islice(itertools.product([0,1], repeat = 2*n-1), 0, 2**16)
r1 = []
for longtuple in todo:
A = toeplitz(longtuple[0:n], longtuple[n-1:2*n-1])
r1.append(det(A))

优化后的代码如下:

%%time
batch = 2**10
todo = itertools.islice(itertools.product([0,1], repeat = 2*n-1), 0, 2**16)
idx = toeplitz(range(n), range(n-1, 2*n-1))

r2 = []
while True:
rows = list(itertools.islice(todo, 0, batch))
if not rows:
break
rows_arr = np.array(rows)
A = rows_arr[:, idx]
r2.extend(np.linalg.det(A).tolist())

这是时间结果:

original: Wall time: 4.65 s
optimized: Wall time: 646 ms

我们检查结果:

np.allclose(r1, r2)

你可以通过unpackbits()来提高速度:

%%time
r3 = []
todo = np.arange(0, 2**16).astype(np.uint32).byteswap().view(np.uint8).reshape(-1, 4)
for i in range(todo.shape[0]//batch):
B = np.unpackbits(todo[i*batch:(i+1)*batch], axis=-1)
rows_arr = B[:, -23:]
A = rows_arr[:, idx]
r3.extend(np.linalg.det(A).tolist())

时间是:

Wall time: 494 ms

这里是 n=10 的 singcount 的完整代码:

%%time
count = 0
batch = 2**10
n = 10
n2 = 10*2-1
idx = toeplitz(range(n), range(n-1, 2*n-1))
todo = np.arange(0, 2**n2).astype(np.uint32).byteswap().view(np.uint8).reshape(-1, 4)
for i in range(todo.shape[0]//batch):
B = np.unpackbits(todo[i*batch:(i+1)*batch], axis=-1)
rows_arr = B[:, -n2:]
A = rows_arr[:, idx]
det = np.linalg.det(A)
count += np.sum(np.isclose(det, 0))
print count

输出是 43892,在我的电脑上用了 2.15s。

关于python - python中的莱文森算法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/21562651/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com