gpt4 book ai didi

python - 结果稀疏性已知时的稀疏矩阵乘法(在 python|scipy|cython 中)

转载 作者:太空宇宙 更新时间:2023-11-04 10:33:26 25 4
gpt4 key购买 nike

假设我们想为给定的稀疏矩阵 A、B 计算 C=A*B,但对 C 的一个非常小的条目子集感兴趣,由索引对列表表示:
行数=[i1, i2, i3 ... ]
cols=[j1, j2, j3 ... ]
A 和 B 都很大(比如 50Kx50K),但非常稀疏(<1% 的条目非零)。

我们如何计算乘法的这个子集?

这是一个运行速度非常慢的简单实现:

def naive(A, B, rows, cols):
N = len(rows)
vals = []
for n in xrange(N):
v = A.getrow(rows[n]) * B.getcol(cols[n])
vals.append(v[0, 0])

R = sps.coo_matrix((np.array(vals), (np.array(rows), np.array(cols))), shape=(A.shape[0], B.shape[1]), dtype=np.float64)
return R

即使对于小矩阵,这也很糟糕:

import scipy.sparse as sps
import numpy as np
D = 1000

A = np.random.randn(D, D)
A[np.abs(A) > 0.1] = 0
A = sps.csr_matrix(A)
B = np.random.randn(D, D)
B[np.abs(B) > 0.1] = 0
B = sps.csr_matrix(B)

X = np.random.randn(D, D)
X[np.abs(X) > 0.1] = 0
X[X != 0] = 1
X = sps.csr_matrix(X)
rows, cols = X.nonzero()
naive(A, B, rows, cols)

在我的机器上,naive() 在 1 分钟后完成,大部分工作都花在构造行/列上(在 getrow()、getcol() 中)。
当然,将这个(非常小的)示例转换为密集矩阵,计算大约需要 100 毫秒:

A0 = np.array(A.todense())
B0 = np.array(B.todense())
X0 = np.array(X.todense())
A0.dot(B0) * X0

关于如何高效计算这种矩阵乘法有什么想法吗?

最佳答案

稀疏矩阵的格式在这里很重要。您始终需要 A 行和 B 列。因此,将 A 存储为 csr 并将 B 存储为 csc 来摆脱 getrow/getcol 开销。不幸的是,这只是故事的一小部分。

最佳解决方案在很大程度上取决于稀疏矩阵的结构(很多稀疏列/行等),但您可以尝试基于字典和集合的解决方案。对于每一行的矩阵 A,保留以下内容:

  • 该行上所有非零列索引的集合
  • 以非零索引为键,对应的非零值为值的字典

对于矩阵B,每一列都保留了类似的字典和集合。

要计算乘法结果中的元素 (M, N),将 A 的 M 行与 B 的 N 列相乘。乘法:

  • 找到非零集的交集
  • 计算非零元素(即上面的交集)的乘法和

在大多数情况下,这应该非常快,因为在稀疏矩阵中,集合交集通常非常小。

部分代码:

class rowarray():
def __init__(self, arr):
self.rows = []
for row in arr:
nonzeros = np.nonzero(row)[0]
nzvalues = { i: row[i] for i in nonzeros }
self.rows.append((set(nonzeros), nzvalues))

def __getitem__(self, key):
return self.rows[key]

def __len__(self):
return len(self.rows)


class colarray(rowarray):
def __init__(self, arr):
rowarray.__init__(self, arr.T)


def maybe_less_naive(A, B, rows, cols):
N = len(rows)
vals = []
for n in xrange(N):
nz1,v1 = A[rows[n]]
nz2,v2 = B[cols[n]]
# list of common non-zeros
nz = nz1.intersection(nz2)
# sum of non-zeros
vals.append(sum([ v1[i]*v2[i] for i in nz]))

R = sps.coo_matrix((np.array(vals), (np.array(rows), np.array(cols))), shape=(len(A), len(B)), dtype=np.float64)
return R

D = 1000

Ap = np.random.randn(D, D)
Ap[np.abs(Ap) > 0.1] = 0
A = rowarray(Ap)
Bp = np.random.randn(D, D)
Bp[np.abs(Bp) > 0.1] = 0
B = colarray(Bp)

X = np.random.randn(D, D)
X[np.abs(X) > 0.1] = 0
X[X != 0] = 1
X = sps.csr_matrix(X)
rows, cols = X.nonzero()
maybe_less_naive(A, B, rows, cols)

这样效率更高一些,乘法测试大约需要 2 秒(80 000 个元素)。结果似乎基本相同。


对性能的一些评论。

对每个输出元素执行两个操作:

  • 设置路口
  • 乘法

集合交集的复杂度应为 O(min(m,n)),其中 m 和 n 是每个操作数中非零的个数。这是矩阵大小的不变性,只有每行/列的平均非零数很重要。

乘法(和字典查找)的次数取决于在上面的交集中找到的非零数。

如果两个矩阵都以概率(密度)p 随机分布非零值,并且行/列长度为 n,则:

  • 设置交集:O(np)
  • 字典查找,乘法:O(np^2)

这表明对于真正稀疏的矩阵,找到交点是关键点。这也可以通过分析来验证;大部分时间都花在计算交点上。

当这反射(reflect)到现实世界中时,我们似乎花费大约 20 微秒来处理一行/一列 80 个非零值。这并不是快得让人眼花缭乱,而且代码当然可以做得更快。 Cython 可能是一种解决方案,但这可能是 Python 不是最佳解决方案的问题之一。当用 C 编写时,排序整数的简单线性匹配(合并排序类型算法)应该至少快一个数量级。

需要注意的一件重要事情是,该算法可以一次针对多个元素并行完成。无需满足于单个线程,因为只要一个线程处理一个输出点,计算就是独立的。

关于python - 结果稀疏性已知时的稀疏矩阵乘法(在 python|scipy|cython 中),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/25005130/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com