gpt4 book ai didi

python - 是否可以使用 Numpy 实现此版本的矩阵乘法?

转载 作者:行者123 更新时间:2023-12-02 05:50:39 25 4
gpt4 key购买 nike

我希望快速评估下面的函数,它在高层次上类似于矩阵乘法。对于大型矩阵,下面的实现比矩阵的 numpy 乘法慢几个数量级,这让我相信有更好的方法使用 numpy 来实现这一点。有没有办法使用 numpy 函数而不是 for 循环来实现这一点?我正在使用的矩阵每个维度都有 10K-100K 个元素,因此非常需要这种优化。

一种方法是使用 3D numpy 数组,但事实证明这太大而无法存储。我还研究了 np.vectorize 这似乎不合适。

非常感谢您的指导。

编辑:感谢大家的精彩见解和答案。非常感谢您的投入。将日志移出循环可以极大地提高运行时间,并且有趣的是,k 查找非常重要。如果可以的话,我有一个后续行动:如果内部循环表达式变为 C[i,j] += A[i,k] * np.log(A[i,k] ] + B[k,j])?日志可以像以前一样移出,但前提是 A[i,k] 取幂,这很昂贵并且消除了移出日志带来的 yield 。

import numpy as np
import numba
from numba import njit, prange

@numba.jit(fastmath=True, parallel=True)
def f(A, B):
   
    C = np.zeros((A.shape[0], B.shape[1]))

    for i in prange(A.shape[0]):
        for j in prange(B.shape[1]):
            for k in prange(A.shape[1]):
               
                C[i,j] += np.log(A[i,k] + B[k,j])
                #matrix mult. would be: C[i,j] += A[i,k] * B[k,j]

    return C

#A = np.random.rand(100000, 100000)
#B = np.random.rand(100000, 100000)
#f(A, B)

最佳答案

由于log(a) + log(b) == log(a * b),您可以通过用乘法代替加法并仅在以下位置执行对数来节省大量对数计算:结束,这应该会节省你很多时间。

import numpy as np
import numba as nb

@nb.njit(fastmath=True, parallel=True)
def f(A, B):
C = np.ones((A.shape[0], B.shape[1]), A.dtype)
for i in nb.prange(A.shape[0]):
for j in nb.prange(B.shape[1]):
# Accumulate product
for k in nb.prange(A.shape[1]):
C[i,j] *= (A[i,k] + B[k,j])
# Apply logarithm at the end
return np.log(C)

# For comparison
@nb.njit(fastmath=True, parallel=True)
def f_orig(A, B):
C = np.zeros((A.shape[0], B.shape[1]), A.dtype)
for i in nb.prange(A.shape[0]):
for j in nb.prange(B.shape[1]):
for k in nb.prange(A.shape[1]):
C[i,j] += np.log(A[i,k] + B[k,j])
return C

# Test
np.random.seed(0)
a, b = np.random.random((1000, 100)), np.random.random((100, 2000))
print(np.allclose(f(a, b), f_orig(a, b)))
# True
%timeit f(a, b)
# 36.2 ms ± 2.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f_orig(a, b)
# 296 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

关于python - 是否可以使用 Numpy 实现此版本的矩阵乘法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60169100/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com