gpt4 book ai didi

python - 合并第 i^{th} 轴之前和之后的轴

转载 作者:行者123 更新时间:2023-11-30 22:25:11 25 4
gpt4 key购买 nike

我想 reshape 一个 numpy 数组 arr形状为(before, at, after)对于任何人axisarr 。如何更快地做到这一点?

轴已标准化:0 <= axis < arr.ndim

程序:

import numpy as np
def f(arr, axis):
shape = arr.shape
before = int(np.product(shape[:axis]))
at = shape[axis]
return arr.reshape(before, at, -1)

测试:

a = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
print(f(a, 2).shape)

结果:

(6, 4, 5)

最佳答案

shape是一个元组,想要的结果也是一个元组。与数组之间的转换以使用 np.prod 或其他数组函数需要时间。因此,如果我们可以使用简单的 Python 代码做同样的事情,我们可能会节省时间。

例如形状:

In [309]: shape
Out[309]: (2, 3, 4, 5)
In [310]: np.prod(shape)
Out[310]: 120
In [311]: functools.reduce(operator.mul,shape)
Out[311]: 120

In [312]: timeit np.prod(shape)
13.6 µs ± 30.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [313]: timeit functools.reduce(operator.mul,shape)
647 ns ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

Python 版本明显更快。我必须导入 functoolsoperator 才能获得与 sum (Python3) 等价的乘法。

或者获取新的形状元组:

In [314]: axis=2
In [315]: (functools.reduce(operator.mul,shape[:axis]),shape[axis],-1)
Out[315]: (6, 4, -1)
In [316]: timeit (functools.reduce(operator.mul,shape[:axis]),shape[axis],-1)
739 ns ± 30.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

比较建议的reduceat:

In [318]: tuple(np.multiply.reduceat(shape, (0, axis, axis+1)))
Out[318]: (6, 4, 5)
In [319]: timeit tuple(np.multiply.reduceat(shape, (0, axis, axis+1)))
11.3 µs ± 21.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

关于python - 合并第 i^{th} 轴之前和之后的轴,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47604047/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com