gpt4 book ai didi

python - 如何获得比 numpy.dot 更快的代码用于矩阵乘法?

转载 作者:IT老高 更新时间:2023-10-28 21:16:03 27 4
gpt4 key购买 nike

这里 Matrix multiplication using hdf5我使用 hdf5 (pytables) 进行大矩阵乘法,但我很惊讶,因为使用 hdf5 它比使用普通 numpy.dot 并在 RAM 中存储矩阵更快,这种行为的原因是什么?

也许python中有一些更快的矩阵乘法函数,因为我仍然使用numpy.dot进行小块矩阵乘法。

这里有一些代码:

假设矩阵可以放入 RAM:在矩阵 10*1000 x 1000 上进行测试。

使用默认的 numpy(我认为没有 BLAS 库)。普通的 numpy 数组在 RAM 中:时间 9.48

如果 A、B 在 RAM 中,C 在磁盘上:时间 1.48

如果 A、B、C 在磁盘上:时间 372.25

如果我使用带有 MKL 的 numpy,结果是:0.15,0.45,43.5。

结果看起来很合理,但我仍然不明白为什么在第一种情况下 block 乘法更快(当我们将 A、B 存储在 RAM 中时)。

n_row=1000
n_col=1000
n_batch=10

def test_plain_numpy():
A=np.random.rand(n_row,n_col)# float by default?
B=np.random.rand(n_col,n_row)
t0= time.time()
res= np.dot(A,B)
print (time.time()-t0)

#A,B in RAM, C on disk
def test_hdf5_ram():
rows = n_row
cols = n_col
batches = n_batch

#using numpy array
A=np.random.rand(n_row,n_col)
B=np.random.rand(n_col,n_row)

#settings for all hdf5 files
atom = tables.Float32Atom() #if store uint8 less memory?
filters = tables.Filters(complevel=9, complib='blosc') # tune parameters
Nchunk = 128 # ?
chunkshape = (Nchunk, Nchunk)
chunk_multiple = 1
block_size = chunk_multiple * Nchunk

#using hdf5
fileName_C = 'CArray_C.h5'
shape = (A.shape[0], B.shape[1])

h5f_C = tables.open_file(fileName_C, 'w')
C = h5f_C.create_carray(h5f_C.root, 'CArray', atom, shape, chunkshape=chunkshape, filters=filters)

sz= block_size

t0= time.time()
for i in range(0, A.shape[0], sz):
for j in range(0, B.shape[1], sz):
for k in range(0, A.shape[1], sz):
C[i:i+sz,j:j+sz] += np.dot(A[i:i+sz,k:k+sz],B[k:k+sz,j:j+sz])
print (time.time()-t0)

h5f_C.close()
def test_hdf5_disk():
rows = n_row
cols = n_col
batches = n_batch

#settings for all hdf5 files
atom = tables.Float32Atom() #if store uint8 less memory?
filters = tables.Filters(complevel=9, complib='blosc') # tune parameters
Nchunk = 128 # ?
chunkshape = (Nchunk, Nchunk)
chunk_multiple = 1
block_size = chunk_multiple * Nchunk


fileName_A = 'carray_A.h5'
shape_A = (n_row*n_batch, n_col) # predefined size

h5f_A = tables.open_file(fileName_A, 'w')
A = h5f_A.create_carray(h5f_A.root, 'CArray', atom, shape_A, chunkshape=chunkshape, filters=filters)

for i in range(batches):
data = np.random.rand(n_row, n_col)
A[i*n_row:(i+1)*n_row]= data[:]

rows = n_col
cols = n_row
batches = n_batch

fileName_B = 'carray_B.h5'
shape_B = (rows, cols*batches) # predefined size

h5f_B = tables.open_file(fileName_B, 'w')
B = h5f_B.create_carray(h5f_B.root, 'CArray', atom, shape_B, chunkshape=chunkshape, filters=filters)

sz= rows/batches
for i in range(batches):
data = np.random.rand(sz, cols*batches)
B[i*sz:(i+1)*sz]= data[:]


fileName_C = 'CArray_C.h5'
shape = (A.shape[0], B.shape[1])

h5f_C = tables.open_file(fileName_C, 'w')
C = h5f_C.create_carray(h5f_C.root, 'CArray', atom, shape, chunkshape=chunkshape, filters=filters)

sz= block_size

t0= time.time()
for i in range(0, A.shape[0], sz):
for j in range(0, B.shape[1], sz):
for k in range(0, A.shape[1], sz):
C[i:i+sz,j:j+sz] += np.dot(A[i:i+sz,k:k+sz],B[k:k+sz,j:j+sz])
print (time.time()-t0)

h5f_A.close()
h5f_B.close()
h5f_C.close()

最佳答案

np.dot 分派(dispatch)到 BLAS

  • NumPy 已编译为使用 BLAS,
  • BLAS 实现在运行时可用,
  • 您的数据具有 float32float64complex32complex64
  • 数据在内存中适当对齐。

否则,它默认使用自己的慢速矩阵乘法例程。

描述了检查您的 BLAS 链接 here .简而言之,检查您的 NumPy 安装中是否有文件 _dotblas.so 或类似文件。如果有,请检查它链接到哪个 BLAS 库;引用 BLAS 很慢,ATLAS 很快,OpenBLAS 和供应商特定版本(如英特尔 MKL)甚至更快。注意多线程 BLAS 实现,因为它们 don't play nicely使用 Python 的 multiprocessing

接下来,通过检查数组的 flags 检查数据对齐情况。在 1.7.2 之前的 NumPy 版本中,np.dot 的两个参数都应该是 C 顺序的。在 NumPy >= 1.7.2 中,这不再重要,因为已经引入了 Fortran 数组的特殊情况。

>>> X = np.random.randn(10, 4)
>>> Y = np.random.randn(7, 4).T
>>> X.flags
C_CONTIGUOUS : True
F_CONTIGUOUS : False
OWNDATA : True
WRITEABLE : True
ALIGNED : True
UPDATEIFCOPY : False
>>> Y.flags
C_CONTIGUOUS : False
F_CONTIGUOUS : True
OWNDATA : False
WRITEABLE : True
ALIGNED : True
UPDATEIFCOPY : False

如果你的 NumPy 没有与 BLAS 链接,要么(简单)重新安装它,要么(硬)使用 SciPy 中的 BLAS gemm(广义矩阵乘法)函数:

>>> from scipy.linalg import get_blas_funcs
>>> gemm = get_blas_funcs("gemm", [X, Y])
>>> np.all(gemm(1, X, Y) == np.dot(X, Y))
True

这看起来很简单,但它几乎没有任何错误检查,所以你必须真正知道你在做什么。

关于python - 如何获得比 numpy.dot 更快的代码用于矩阵乘法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/19839539/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com