
python - Update a 2D NumPy array indexed with a cross product that has repeated indices


I'm interested in a version of Increment Numpy multi-d array with repeated indices that uses cross-product indexing.
In particular, I want to speed up what the following code does by using matrix operations instead of the loops:

def get_s(image, grid_size):
    W, H = image.shape
    s = np.zeros((W, H))
    for w in range(W):
        for h in range(H):
            i, j = int(w / grid_size), int(h / grid_size)
            s[i, j] += image[w, h]
    return s
My idea is to compute all the (i, j) indices at once and use NumPy's ix_ to index the matrix s:
def get_s(image, grid_size):
    W, H = image.shape
    s = np.zeros((W, H))
    w_idx, h_idx = np.arange(W), np.arange(H)
    x_idx, y_idx = np.trunc(w_idx / grid_size).astype(int), np.trunc(h_idx / grid_size).astype(int)
    s[np.ix_(x_idx, y_idx)] += image
    return s
The code above is easier to understand with the example from the NumPy documentation:

Using ix_ one can quickly construct index arrays that will index the cross product. a[np.ix_([1,3],[2,5])] returns the array [[a[1,2] a[1,5]], [a[3,2] a[3,5]]].
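For concreteness, here is a small runnable version of the quoted example; the 6×6 array a is my own choice for illustration:

import numpy as np

a = np.arange(36).reshape(6, 6)

# np.ix_ builds open-mesh index arrays, so this selects rows 1 and 3
# crossed with columns 2 and 5.
print(a[np.ix_([1, 3], [2, 5])])
# [[ 8 11]
#  [20 23]]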


In my case, some indices are very likely to be repeated (e.g. with grid_size=2, int(0 / grid_size) = int(1 / grid_size) = 0). This is where the question Increment Numpy multi-d array with repeated indices comes in.
If an index is repeated, I want the matrix to be updated with the image values that many times. I haven't been able to solve this without extra loops (e.g. by compressing the indices; but then you essentially have to take the actual cross product of the indices of s and of the image).
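The core difficulty is that fancy-indexed in-place addition is buffered, so a repeated index is only applied once. A minimal sketch of the difference (my own illustration, not from the original question), using np.add.at, which performs unbuffered addition and therefore accumulates over repeated indices:

import numpy as np

a = np.zeros(3)
a[[0, 0, 1]] += 1           # buffered: the duplicate index 0 is applied only once
print(a)                    # [1. 1. 0.]

np.add.at(a, [0, 0, 1], 1)  # unbuffered: both occurrences of index 0 accumulate
print(a)                    # [3. 2. 0.]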

Best Answer

I don't think this is the best way to do it, but here is one way.

import numpy as np

image = np.arange(9).reshape(3, 3)
s = np.zeros((5, 5))
x_idx, y_idx = np.meshgrid([0, 0, 2], [1, 1, 2])
# find unique destinations
idxs = np.stack((x_idx.flatten(), y_idx.flatten())).T
idxs_unique, counts = np.unique(idxs, axis = 0, return_counts = True)
# create a mask for the source pixels and sum the source pixels headed to the same destination
idxs_repeated = idxs[None, :, :].repeat(len(idxs_unique), axis = 0)
image_mask = (idxs_repeated == idxs_unique[:, None, :]).all(-1)
pixel_sum = (image.flatten()[None, :]*image_mask).sum(-1)
# assign summed sources to destination
s[tuple(idxs_unique.T)] += pixel_sum
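As a quick sanity check (my own addition, assuming the arrays above are still in scope): the source pixels image[0, 0], image[0, 1], image[1, 0], image[1, 1] all map to destination (0, 1), so s[0, 1] should end up as 0 + 1 + 3 + 4 = 8.

# four source pixels (values 0, 1, 3, 4) share destination (0, 1)
print(s[0, 1])   # 8.0
# only image[2, 2] maps to destination (2, 2)
print(s[2, 2])   # 8.0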
Edit 1:
If you run into memory problems, you can do the image masking and summation in batches, as implemented below. I set the batch size to 10, but you can set that parameter to whatever works on your machine.
import numpy as np

image = np.arange(12).reshape(3, 4)
s = np.zeros((5, 5))
x_idx, y_idx = np.meshgrid([0, 0, 2], [1, 1, 2, 1])
idxs = np.stack((x_idx.flatten(), y_idx.flatten())).T
idxs_unique, counts = np.unique(idxs, axis = 0, return_counts = True)
batch_size = 10
pixel_sum = []
for i in range(len(idxs_unique)//batch_size + ((len(idxs_unique)%batch_size) != 0)):
    # process one batch of unique destinations at a time to bound memory use
    batch = idxs_unique[i*batch_size:(i+1)*batch_size, None, :]
    idxs_repeated = idxs[None, :, :].repeat(len(batch), axis = 0)
    image_mask = (idxs_repeated == batch).all(-1)
    pixel_sum.append((image.flatten()[None, :]*image_mask).sum(-1))
pixel_sum = np.concatenate(pixel_sum)
s[tuple(idxs_unique.T)] += pixel_sum
Edit 2:
If you use numba, the OP's approach seems to be faster.
import numpy as np
from numba import jit

@jit(nopython=True)
def get_s(image, grid_size):
    W, H = image.shape
    s = np.zeros((W, H))
    for w in range(W):
        for h in range(H):
            i, j = int(w / grid_size), int(h / grid_size)
            s[i, j] += image[w, h]
    return s

def get_s_vec(image, grid_size, batch_size = 10):
    W, H = image.shape
    s = np.zeros((W, H))
    w_idx, h_idx = np.arange(W), np.arange(H)
    x_idx, y_idx = np.trunc(w_idx / grid_size).astype(int), np.trunc(h_idx / grid_size).astype(int)
    y_idx, x_idx = np.meshgrid(y_idx, x_idx)
    idxs = np.stack((x_idx.flatten(), y_idx.flatten())).T
    idxs_unique, counts = np.unique(idxs, axis = 0, return_counts = True)
    pixel_sum = []
    for i in range(len(idxs_unique)//batch_size + ((len(idxs_unique)%batch_size) != 0)):
        batch = idxs_unique[i*batch_size:(i+1)*batch_size, None, :]
        idxs_repeated = idxs[None, :, :].repeat(len(batch), axis = 0)
        image_mask = (idxs_repeated == batch).all(-1)
        pixel_sum.append((image.flatten()[None, :]*image_mask).sum(-1))
    pixel_sum = np.concatenate(pixel_sum)
    s[tuple(idxs_unique.T)] += pixel_sum
    return s


print(f'loop result = {get_s(image, 2)}')
print(f'vector result = {get_s_vec(image, 2)}')

%timeit get_s(image, 2)
%timeit get_s_vec(image, 2)
Output:
loop result = [[10. 18.  0.  0.]
 [17. 21.  0.  0.]
 [ 0.  0.  0.  0.]]
vector result = [[10. 18.  0.  0.]
 [17. 21.  0.  0.]
 [ 0.  0.  0.  0.]]
The slowest run took 15.00 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 5: 751 ns per loop
1000 loops, best of 5: 195 µs per loop

Regarding python - Update a 2D NumPy array indexed with a cross product that has repeated indices, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/68573992/
