gpt4 book ai didi

python - 切片 jax.numpy 数组时性能下降

转载 作者:行者123 更新时间:2023-12-05 03:47:47 25 4
gpt4 key购买 nike

在尝试对大型数组进行 SVD 压缩时,我遇到了一些我在 Jax 中无法理解的行为。这是示例代码:

@jit 
def jax_compress(L):
U, S, _ = jsc.linalg.svd(L,
full_matrices = False,
lapack_driver = 'gesvd',
check_finite=False,
overwrite_a=True)

maxS=jnp.max(S)
chi = jnp.sum(S/maxS>1E-1)

return chi, jnp.asarray(U)

当考虑这段代码时,Jax/jit 比 SciPy 有巨大的性能提升,但最终我想减少 U 的维数,我通过将它包装在函数中来实现:

def jax_process(A):

chi, U = jax_compress(A)

return U[:,0:chi]

就计算时间而言,这一步的成本高得令人难以置信,比 SciPy 等价物还要多,从这个比较中可以看出:

benchmark of jax and scipy

sc_compresssc_process 是上述 jax 代码的 SciPy 等效项。如您所见,在 SciPy 中对数组进行切片几乎不需要任何成本,但在应用于命中函数的输出时却非常昂贵。有没有人对这种行为有所了解?

最佳答案

我对 JAX 和 PyTorch 之间的切片速度进行了类似的比较。 dynamic_slice 比常规切片快得多,但仍然比 torch 中的等效切片慢得多。由于我是 JAX 的新手,我不确定原因是什么,但这可能与复制和引用有关,因为 JAX 数组是不可变的。

JAX(没有@jit)

key = random.PRNGKey(0)
j = random.normal(key, (32, 2, 1024, 1024, 3))
%timeit j[..., 100:600, 100:600, :].block_until_ready()
%timeit dynamic_slice(j, [0, 0, 100, 100, 0], [32, 2, 500, 500, 3]).block_until_ready()
2.78 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
993 µs ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

torch

t = torch.randn((32, 2, 1024, 1024, 3)).cuda()

%%timeit
t[..., 100:600, 100:600, :]
torch.cuda.synchronize()
7.63 µs ± 22.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

关于python - 切片 jax.numpy 数组时性能下降,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64750139/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com