
python - A robust way to keep the n largest elements in each row or column of a matrix

Reposted. Author: 行者123. Updated: 2023-12-04 08:27:18

I want to build a sparse matrix from a dense matrix such that only the n largest elements in each row or each column are preserved. I do the following:

import numpy as np
import scipy.sparse as spsp

def sparsify(K, min_nnz = 5):
    '''
    This function eliminates the elements that are not among the min_nnz
    largest elements of their row or of their column.

    Parameters
    ----------
    K : ndarray
        K - the input matrix
    min_nnz : int
        the minimal number of elements in row or column to be preserved

    '''
    # Keep an element if it is >= the min_nnz-th largest value of its row
    # (threshold shape (n_rows, 1)) or of its column (threshold shape (1, n_cols)).
    cond = np.bitwise_or(K >= -np.partition(-K, min_nnz - 1, axis = 1)[:, min_nnz - 1][:, None],
                         K >= -np.partition(-K, min_nnz - 1, axis = 0)[min_nnz - 1, :][None, :])

    return spsp.csr_matrix(np.where(cond, K, 0))
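This approach works as expected, but it does not seem to be the most efficient or robust one. What would you recommend as a better approach?
Usage example: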
A = np.random.rand(10, 10)
A_sp = sparsify(A, min_nnz = 3)
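The thresholds in `sparsify` rely on `np.partition`: negating `K` turns "n largest" into "n smallest", so after partitioning `-K` with `kth = n - 1`, the value at position `n - 1` of each row, negated back, is the n-th largest entry of that row. A small check on a fixed matrix (the values are made up for illustration, not taken from the question), compared against a full sort:

```python
import numpy as np

# Small illustrative matrix (hypothetical values, not from the question).
K = np.array([[3.0, 1.0, 4.0, 1.5],
              [9.0, 2.0, 6.0, 5.0],
              [0.5, 8.0, 7.0, 2.5]])
n = 2

# n-th largest value of every row: partition -K so that the n smallest
# values of -K (= the n largest of K) occupy positions 0..n-1, then read n-1.
row_thresh = -np.partition(-K, n - 1, axis=1)[:, n - 1]

# Same thing via a full descending sort, for comparison (O(m log m) per row).
row_thresh_sorted = -np.sort(-K, axis=1)[:, n - 1]

# Broadcasting the (rows, 1) threshold keeps every entry >= its row's threshold.
row_mask = K >= row_thresh[:, None]
print(row_thresh)            # [3. 6. 7.]
print(row_mask.sum(axis=1))  # [2 2 2]
```

`np.partition` only needs a selection (linear time on average per row) rather than a full sort, which is why it is the usual choice here.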

Best answer

Instead of building another dense matrix, you can use coo_matrix to construct the result from only the values you need:

return spsp.coo_matrix((K[cond], np.where(cond)), shape = K.shape)
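A quick sketch checking that both constructions give the same matrix, on an arbitrary random matrix with an arbitrary mask (`K > 0.5` is only a stand-in for the top-n condition). It relies on `K[cond]` and `np.where(cond)` enumerating the selected entries in the same row-major order:

```python
import numpy as np
import scipy.sparse as spsp

rng = np.random.default_rng(0)
K = rng.random((6, 5))
cond = K > 0.5  # stand-in mask; any boolean array of K's shape works

# Question's route: materialize a dense copy with zeros, then convert.
dense_route = spsp.csr_matrix(np.where(cond, K, 0))

# Answer's route: pass only the kept values plus their (rows, cols) tuple,
# which is exactly the coordinate pair coo_matrix expects next to the data.
coo_route = spsp.coo_matrix((K[cond], np.where(cond)), shape=K.shape)

print(coo_route.nnz == cond.sum())                                # True
print(np.array_equal(coo_route.toarray(), dense_route.toarray()))  # True
```

The result is in COO format; calling `.tocsr()` on it gives a CSR matrix if you need fast row slicing or arithmetic afterwards.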
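As for the rest, you can short-circuit the second dimension: keep the column-wise top min_nnz elements first, then run the row-wise partition only on the rows that still have fewer than min_nnz kept elements. How much time you save will depend entirely on your input: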
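import numpy as np
import scipy.sparse as spsp

def sparsify(K, min_nnz = 5):
    '''
    This function keeps at least min_nnz of the largest elements in every
    row and every column and eliminates all other elements.

    Parameters
    ----------
    K : ndarray
        K - the input matrix
    min_nnz : int
        the minimal number of elements in row or column to be preserved

    '''
    # Column pass: keep the min_nnz largest elements of every column.
    cond = K >= -np.partition(-K, min_nnz - 1, axis = 0)[min_nnz - 1, :]
    # Rows that still have fewer than min_nnz kept elements.
    mask = cond.sum(1) < min_nnz
    # Row pass, restricted to those rows: add their min_nnz largest elements.
    cond[mask] = np.bitwise_or(cond[mask],
                               K[mask] >= -np.partition(-K[mask],
                                                        min_nnz - 1,
                                                        axis = 1)[:, min_nnz - 1][:, None])

    # Build the sparse matrix directly from the kept values and their coordinates.
    return spsp.coo_matrix((K[cond], np.where(cond)), shape = K.shape)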
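Note that the short-circuit changes what is kept, not only how fast it runs: a row that already has min_nnz elements from the column pass does not get its own largest elements added. A sketch comparing the two rules as boolean masks on an arbitrary random matrix (the helper names `union_mask` and `short_circuit_mask` are hypothetical, written only for this comparison). It also counts stored entries per row and column from the COO `row`/`col` arrays with `np.bincount`, which avoids the dense round trip used in the test below:

```python
import numpy as np
import scipy.sparse as spsp

def union_mask(K, min_nnz):
    """Question's rule: keep an entry in the top min_nnz of its row OR its column."""
    k = min_nnz - 1
    row_keep = K >= -np.partition(-K, k, axis=1)[:, k][:, None]
    col_keep = K >= -np.partition(-K, k, axis=0)[k, :][None, :]
    return row_keep | col_keep

def short_circuit_mask(K, min_nnz):
    """Answer's rule: column pass first, row pass only for rows still short."""
    k = min_nnz - 1
    cond = K >= -np.partition(-K, k, axis=0)[k, :]
    short = cond.sum(axis=1) < min_nnz
    cond[short] |= K[short] >= -np.partition(-K[short], k, axis=1)[:, k][:, None]
    return cond

rng = np.random.default_rng(42)  # arbitrary test matrix
K = rng.random((50, 40))
n = 5

cond_u = union_mask(K, n)
cond_s = short_circuit_mask(K, n)
S = spsp.coo_matrix((K[cond_s], np.where(cond_s)), shape=K.shape)

# Per-row / per-column counts read straight off the COO coordinates.
row_counts = np.bincount(S.row, minlength=K.shape[0])
col_counts = np.bincount(S.col, minlength=K.shape[1])

print(row_counts.min() >= n, col_counts.min() >= n)  # True True
print(bool(np.all(cond_u[cond_s])))                  # True: subset of the union rule
print(S.nnz, int(cond_u.sum()))
```

Both rules guarantee at least min_nnz entries per row and per column (as long as min_nnz does not exceed either dimension); the short-circuit version simply stores fewer of the extra row-wise entries.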
Test:
sparsify(A)
Out[]:
<10x10 sparse matrix of type '<class 'numpy.float64'>'
with 58 stored elements in COOrdinate format>

sparsify(A).toarray()
Out[]:
array([[0. , 0. , 0.61362248, 0. , 0.73648987,
0.64561856, 0.40727807, 0.61674005, 0.53533315, 0. ],
[0.8888361 , 0.64548039, 0.94659603, 0.78474203, 0. ,
0. , 0.78809603, 0.88938798, 0. , 0.37631541],
[0.69356682, 0. , 0. , 0. , 0. ,
0.7386594 , 0.71687659, 0.67750768, 0.58002451, 0. ],
[0.67241433, 0.71923718, 0.95888737, 0. , 0. ,
0. , 0.82773085, 0.69788448, 0.63736915, 0.4263064 ],
[0. , 0.65831794, 0. , 0. , 0.59850093,
0. , 0. , 0.61913869, 0.65024867, 0.50860294],
[0.75522891, 0. , 0.93342402, 0.8284258 , 0.64471939,
0.6990814 , 0. , 0. , 0. , 0.32940821],
[0. , 0.88458635, 0.62460096, 0.60412265, 0.66969674,
0. , 0.40318741, 0. , 0. , 0.44116059],
[0. , 0. , 0.500971 , 0.92291245, 0. ,
0.8862903 , 0. , 0.375885 , 0.49473635, 0. ],
[0.86920647, 0.85157893, 0.89883006, 0. , 0.68427193,
0.91195162, 0. , 0. , 0.94762875, 0. ],
[0. , 0.6435456 , 0. , 0.70551006, 0. ,
0.8075527 , 0. , 0.9421039 , 0.91096934, 0. ]])

sparsify(A).toarray().astype(bool).sum(0)
Out[]: array([5, 6, 7, 5, 5, 6, 5, 7, 7, 5])

sparsify(A).toarray().astype(bool).sum(1)
Out[]: array([6, 7, 5, 7, 5, 6, 6, 5, 6, 5])
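One robustness caveat applies to both versions: the `>=` threshold keeps every tie, so a row or column with many equal values keeps more than min_nnz entries (on a constant matrix, everything is kept). If exactly min_nnz entries per row or column is the goal, a different technique, `np.argpartition` combined with `np.put_along_axis`, selects exactly n indices per slice and breaks ties arbitrarily. A sketch under that assumption; `top_n_mask` and `sparsify_exact` are hypothetical names, not part of the answer:

```python
import numpy as np
import scipy.sparse as spsp

def top_n_mask(K, n, axis):
    """Boolean mask with exactly n True entries per slice along `axis`:
    the n largest values of each slice (ties broken arbitrarily)."""
    # After argpartition of -K, positions 0..n-1 along `axis` hold the
    # indices of the n smallest values of -K, i.e. the n largest of K.
    idx = np.argpartition(-K, n - 1, axis=axis)
    idx = np.take(idx, np.arange(n), axis=axis)
    mask = np.zeros(K.shape, dtype=bool)
    np.put_along_axis(mask, idx, True, axis=axis)
    return mask

def sparsify_exact(K, min_nnz=5):
    """Hypothetical variant: union of exact per-row and per-column top-min_nnz."""
    cond = top_n_mask(K, min_nnz, axis=0) | top_n_mask(K, min_nnz, axis=1)
    return spsp.coo_matrix((K[cond], np.where(cond)), shape=K.shape)

K = np.ones((6, 6))  # worst case for threshold selection: all values tie
thr_mask = K >= -np.partition(-K, 1, axis=1)[:, 1][:, None]
print(thr_mask.sum())                        # 36: the threshold keeps everything
print(top_n_mask(K, 2, axis=1).sum(axis=1))  # [2 2 2 2 2 2]
print(top_n_mask(K, 2, axis=0).sum(axis=0))  # [2 2 2 2 2 2]
S = sparsify_exact(K, 2)
print(S.nnz <= 24)                           # True: at most 2*6 + 2*6 entries
```

On inputs without ties, `top_n_mask` selects the same entries as the threshold comparison, so this only changes behavior when values repeat.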

For python - A robust way to keep the n largest elements in each row or column of a matrix, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/65198203/
