gpt4 book ai didi

python - 每行的 Bin 元素 - NumPy 的矢量化 2D Bincount

转载 作者:太空狗 更新时间:2023-10-29 22:17:50 24 4
gpt4 key购买 nike

我有一个包含整数值的 NumPy 数组。矩阵的值范围从 0 到矩阵中的最大元素(换句话说,从 0 到最大数据元素的所有数字都出现在其中)。我需要构建有效(有效意味着快速全矢量化解决方案)来搜索每行中的元素数量并根据矩阵值对它们进行编码。

我找不到类似的问题,也找不到以某种方式帮助解决此问题的问题。

所以如果我在输入中有这个数据:

# shape is (N0=4, m0=4) 
1 1 0 4
2 4 2 1
1 2 3 5
4 4 4 1

期望的输出是:

# shape(N=N0, m=data.max()+1):
1 2 0 0 1 0
0 1 2 0 1 0
0 1 1 1 0 1
0 1 0 0 3 0

我知道如何通过简单地计算 data 每一行中的唯一值来解决这个问题,逐个迭代,然后合并结果,同时考虑 data大批。

虽然使用 NumPy 对其进行矢量化,但关键问题是逐个搜索每个数字的速度很慢,并且假设存在大量唯一数字,这不是有效的解决方案。通常 N 和唯一数字计数都相当大(顺便说一下,N 似乎比唯一数字计数大)。

有人有好主意吗?)

最佳答案

嗯,这基本上就是 np.bincount 的作用。处理 1D 数组。但是,我们需要在每一行上迭代地使用它(简单地考虑一下)。为了使其矢量化,我们可以将每一行偏移那个最大数。这个想法是为每一行设置不同的 bin,这样它们就不会受到具有相同编号的其他行元素的影响。

因此,实现将是 -

# Vectorized solution
def bincount2D_vectorized(a):
N = a.max()+1
a_offs = a + np.arange(a.shape[0])[:,None]*N
return np.bincount(a_offs.ravel(), minlength=a.shape[0]*N).reshape(-1,N)

sample 运行-

In [189]: a
Out[189]:
array([[1, 1, 0, 4],
[2, 4, 2, 1],
[1, 2, 3, 5],
[4, 4, 4, 1]])

In [190]: bincount2D_vectorized(a)
Out[190]:
array([[1, 2, 0, 0, 1, 0],
[0, 1, 2, 0, 1, 0],
[0, 1, 1, 1, 0, 1],
[0, 1, 0, 0, 3, 0]])

Numba 调整

我们可以引入numba进一步加速。现在,numba 允许进行一些微调。

  • 首先,它允许 JIT 编译。

  • 另外,最近他们推出了实验性的 parallel自动并行化已知具有并行语义的函数中的操作。

  • 最后的调整是使用 prange作为 range 的替代品。文档指出这并行运行循环,类似于 OpenMP 并行 for 循环和 Cython 的 prange。 prange 在更大的数据集上表现良好,这可能是因为设置并行工作所需的开销。

因此,通过这两个新的调整以及 njit对于非 Python 模式,我们将有三种变体 -

# Numba solutions
def bincount2D_numba(a, use_parallel=False, use_prange=False):
N = a.max()+1
m,n = a.shape
out = np.zeros((m,N),dtype=int)

# Choose fucntion based on args
func = bincount2D_numba_func0
if use_parallel:
if use_prange:
func = bincount2D_numba_func2
else:
func = bincount2D_numba_func1
# Run chosen function on input data and output
func(a, out, m, n)
return out

@njit
def bincount2D_numba_func0(a, out, m, n):
for i in range(m):
for j in range(n):
out[i,a[i,j]] += 1

@njit(parallel=True)
def bincount2D_numba_func1(a, out, m, n):
for i in range(m):
for j in range(n):
out[i,a[i,j]] += 1

@njit(parallel=True)
def bincount2D_numba_func2(a, out, m, n):
for i in prange(m):
for j in prange(n):
out[i,a[i,j]] += 1

为了稍后的完整性和测试,循环版本将是 -

# Loopy solution
def bincount2D_loopy(a):
N = a.max()+1
m,n = a.shape
out = np.zeros((m,N),dtype=int)
for i in range(m):
out[i] = np.bincount(a[i], minlength=N)
return out

运行时测试

案例#1:

In [312]: a = np.random.randint(0,100,(100,100))

In [313]: %timeit bincount2D_loopy(a)
...: %timeit bincount2D_vectorized(a)
...: %timeit bincount2D_numba(a, use_parallel=False, use_prange=False)
...: %timeit bincount2D_numba(a, use_parallel=True, use_prange=False)
...: %timeit bincount2D_numba(a, use_parallel=True, use_prange=True)
10000 loops, best of 3: 115 µs per loop
10000 loops, best of 3: 36.7 µs per loop
10000 loops, best of 3: 22.6 µs per loop
10000 loops, best of 3: 22.7 µs per loop
10000 loops, best of 3: 39.9 µs per loop

案例#2:

In [316]: a = np.random.randint(0,100,(1000,1000))

In [317]: %timeit bincount2D_loopy(a)
...: %timeit bincount2D_vectorized(a)
...: %timeit bincount2D_numba(a, use_parallel=False, use_prange=False)
...: %timeit bincount2D_numba(a, use_parallel=True, use_prange=False)
...: %timeit bincount2D_numba(a, use_parallel=True, use_prange=True)
100 loops, best of 3: 2.97 ms per loop
100 loops, best of 3: 3.54 ms per loop
1000 loops, best of 3: 1.83 ms per loop
100 loops, best of 3: 1.78 ms per loop
1000 loops, best of 3: 1.4 ms per loop

案例#3:

In [318]: a = np.random.randint(0,1000,(1000,1000))

In [319]: %timeit bincount2D_loopy(a)
...: %timeit bincount2D_vectorized(a)
...: %timeit bincount2D_numba(a, use_parallel=False, use_prange=False)
...: %timeit bincount2D_numba(a, use_parallel=True, use_prange=False)
...: %timeit bincount2D_numba(a, use_parallel=True, use_prange=True)
100 loops, best of 3: 4.01 ms per loop
100 loops, best of 3: 4.86 ms per loop
100 loops, best of 3: 3.21 ms per loop
100 loops, best of 3: 3.18 ms per loop
100 loops, best of 3: 2.45 ms per loop

似乎 numba 变体表现得非常好。从三种变体中选择一种将取决于输入数组形状参数,并且在某种程度上取决于其中唯一元素的数量。

关于python - 每行的 Bin 元素 - NumPy 的矢量化 2D Bincount,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46256279/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com