gpt4 book ai didi

python - 使用 NumPy 快速求出所有行对的差异

转载 作者:行者123 更新时间:2023-12-01 04:55:02 26 4
gpt4 key购买 nike

我使用的算法要求每个示例都有一个矩阵,例如 Xi ,它是 ai x b,并且对于每个 O(n ^2) 对示例,我找到每一行之间的差异 Xiu - Xjv,然后将外积相加 sum_u sum_v np.outer(Xiu - Xjv, Xiu - Xjv )

不幸的是,这个内部双和相当慢,并且导致大型数据集上的运行时间失控。现在我只是使用 for 循环来做到这一点。有没有一些Pythonic的方法来加速这个内部操作?我一直在想一个办法,但没有成功。

为了澄清,对于每个 n 示例,都有一个维度为 ai x b 的矩阵 Xi,其中 ai > 每个示例都不同。对于每对 (Xi, Xj) 我想遍历两个矩阵之间的所有 O(ai * bi) 对行并找到 Xiu - Xjv ,取其与其自身的外积np.outer(Xiu - Xjv, Xiu - Xjv),最后对所有这些外积求和。

例如:假设 D 是 [[1,2],[3,4]],我们只是针对这两个矩阵使用它。

那么,一步可能是 np.outer(D[0] - D[1], D[0] - D[1]) ,这将是矩阵 [4 ,4],[4,4].

简单来说,(0,0) 和 (1,1) 只是 0 矩阵,而 (0,1) 和 (1,0) 都是 4 矩阵,因此对的所有四个外积的最终总和将是[[8,8],[8,8]]

最佳答案

好吧,这很有趣。我仍然情不自禁地认为这一切都可以通过对 numpy.tensordot 的一次巧妙调用来完成,但无论如何,这似乎已经消除了所有 Python 级别的循环:

import numpy

def slow( a, b=None ):

if b is None: b = a
a = numpy.asmatrix( a )
b = numpy.asmatrix( b )

out = 0.0
for ai in a:
for bj in b:
dij = bj - ai
out += numpy.outer( dij, dij )
return out

def opsum( a, b=None ):

if b is None: b = a
a = numpy.asmatrix( a )
b = numpy.asmatrix( b )

RA, CA = a.shape
RB, CB = b.shape
if CA != CB: raise ValueError( "input matrices should have the same number of columns" )

out = -numpy.outer( a.sum( axis=0 ), b.sum( axis=0 ) );
out += out.T
out += RB * ( a.T * a )
out += RA * ( b.T * b )
return out

def test( a, b=None ):
print( "ground truth:" )
print( slow( a, b ) )
print( "optimized:" )
print( opsum( a, b ) )
print( "max abs discrepancy:" )
print( numpy.abs( opsum( a, b ) - slow( a, b ) ).max() )
print( "" )

# OP example
test( [[1,2], [3,4]] )

# non-symmetric example
a = [ [ 1, 2, 3 ], [-4, 5, 6 ], [7, -8, 9 ], [ 10, 11, -12 ] ]
a = numpy.matrix( a, dtype=float )
b = a[ ::2, ::-1 ] + 15
test( a, b )

# non-integer example
test( numpy.random.randn( *a.shape ), numpy.random.randn( *b.shape ) )

使用这个(相当任意的)示例输入,opsum 的时间(在 IPython 中使用 timeit opsum(a,b) 测量)看起来只有 3 倍左右 - 5 比好。但当然,它的扩展性要好得多:将数据点的数量扩大 100 倍,将特征数量扩大 10 倍,然后我们已经在研究大约速度提高 10,000 倍。

关于python - 使用 NumPy 快速求出所有行对的差异,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/27627896/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com