gpt4 book ai didi

python - 加速 numpy kronecker 产品

转载 作者:太空狗 更新时间:2023-10-29 21:35:28 24 4
gpt4 key购买 nike

我正在从事我的第一个大型 Python 项目。我有一个函数,其中包含以下代码:

            # EXPAND THE EXPECTED VALUE TO APPLY TO ALL STATES,
# THEN UPDATE fullFnMat
EV_subset_expand = np.kron(EV_subset, np.ones((nrows, 1)))
fullFnMat[key] = staticMat[key] + EV_subset_expand

在我的代码分析器中,这个 kronecker 产品似乎占用了大量时间。

Function                                                                                        was called by...
ncalls tottime cumtime
/home/stevejb/myhg/dpsolve/ootest/tests/ddw2011/profile_dir/BellmanEquation.py:17(bellmanFn) <- 19 37.681 38.768 /home/stevejb/myhg/dpsolve/ootest/tests/ddw2011/profile_dir/dpclient.py:467(solveTheModel)
{numpy.core.multiarray.concatenate} <- 342 27.319 27.319 /usr/lib/pymodules/python2.7/numpy/lib/shape_base.py:665(kron)
/home/stevejb/myhg/dpsolve/ootest/tests/ddw2011/profile_dir/dpclient.py:467(solveTheModel) <- 1 11.041 91.781 <string>:1(<module>)
{method 'argsort' of 'numpy.ndarray' objects} <- 19 7.692 7.692 /usr/lib/pymodules/python2.7/numpy/core/fromnumeric.py:597(argsort)
/usr/lib/pymodules/python2.7/numpy/core/numeric.py:789(outer) <- 171 2.526 2.527 /usr/lib/pymodules/python2.7/numpy/lib/shape_base.py:665(kron)
{method 'max' of 'numpy.ndarray' objects} <- 209 2.034 2.034 /home/stevejb/myhg/dpsolve/ootest/tests/ddw2011/profile_dir/dpclient.py:391(getValPolMatrices)

有没有办法在 Numpy 中获得更快的 kronecker 产品?看起来它不应该花这么长时间。

最佳答案

您当然可以查看np.kron 的源代码。它可以在 numpy/lib/shape_base.py 中找到,您可以查看是否可以进行改进或简化以提高效率。或者,您可以使用 Cython 或其他一些绑定(bind)到低级语言来编写自己的代码,以尝试获得更好的性能。

或者正如@matt 所建议的那样,像下面这样的东西本身可能会更快:

import numpy as np
nrows = 10
a = np.arange(100).reshape(10,10)
b = np.tile(a,nrows).reshape(nrows*a.shape[0],-1) # equiv to np.kron(a,np.ones((nrows,1)))

或:

b = np.repeat(a,nrows*np.ones(a.shape[0],np.int),axis=0)

时间:

In [80]: %timeit np.tile(a,nrows).reshape(nrows*a.shape[0],-1)
10000 loops, best of 3: 25.5 us per loop

In [81]: %timeit np.kron(a,np.ones((nrows,1)))
10000 loops, best of 3: 117 us per loop

In [91]: %timeit np.repeat(a,nrows*np.ones(a.shape[0],np.int),0)
100000 loops, best of 3: 12.8 us per loop

在上面的示例中,对大小数组使用 np.repeat 可以实现相当不错的 10 倍加速,这还算不错。

关于python - 加速 numpy kronecker 产品,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/7193870/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com