gpt4 book ai didi

python - Numpy:有效地找到行式公共(public)元素

转载 作者:太空狗 更新时间:2023-10-30 00:58:14 30 4
gpt4 key购买 nike

假设我们有两个具有相同行数的二维 numpy 数组 ab。进一步假设我们知道 ab 的每一行 i 最多有一个共同元素,尽管这个元素可能出现多次。我们怎样才能尽可能高效地找到这个元素?

一个例子:

import numpy as np

a = np.array([[1, 2, 3],
[2, 5, 2],
[5, 4, 4],
[2, 1, 3]])

b = np.array([[4, 5],
[3, 2],
[1, 5],
[0, 5]])

desiredResult = np.array([[np.nan],
[2],
[5],
[np.nan]])

通过沿第一个轴应用 intersect1d 很容易想出一个直接的实现:

from intertools import starmap

desiredResult = np.array(list(starmap(np.intersect1d, zip(a, b))))

显然,使用 python 的内置集合操作甚至更快。将结果转换为所需的形式很容易。

但是,我需要一个尽可能高效的实现。因此,我不喜欢 starmap,因为我认为它需要对每一行进行 python 调用。我想要一个纯矢量化的选项,并且会很高兴,如果这甚至利用了我们的额外知识,即每行最多有一个公共(public)值。

有没有人知道我可以如何加快任务速度并更优雅地实现解决方案?我可以使用 C 代码或 cython,但编码工作不应太多。

最佳答案

方法 #1

这是一个基于 searchsorted2d 的向量化-

# Sort each row of a and b in-place
a.sort(1)
b.sort(1)

# Use 2D searchsorted row-wise between a and b
idx = searchsorted2d(a,b)

# "Clip-out" out of bounds indices
idx[idx==a.shape[1]] = 0

# Get mask of valid ones i.e. matches
mask = np.take_along_axis(a,idx,axis=1)==b

# Use argmax to get first match as we know there's at most one match
match_val = np.take_along_axis(b,mask.argmax(1)[:,None],axis=1)

# Finally use np.where to choose between valid match
# (decided by any one True in each row of mask)
out = np.where(mask.any(1)[:,None],match_val,np.nan)

方法 #2

基于 Numba 的内存效率 -

from numba import njit

@njit(parallel=True)
def numba_f1(a,b,out):
n,a_ncols = a.shape
b_ncols = b.shape[1]
for i in range(n):
for j in range(a_ncols):
for k in range(b_ncols):
m = a[i,j]==b[i,k]
if m:
break
if m:
out[i] = a[i,j]
break
return out

def find_first_common_elem_per_row(a,b):
out = np.full(len(a),np.nan)
numba_f1(a,b,out)
return out

方法 #3

这是另一个基于堆叠和排序的矢量化 -

r = np.arange(len(a))
ab = np.hstack((a,b))
idx = ab.argsort(1)
ab_s = ab[r[:,None],idx]
m = ab_s[:,:-1] == ab_s[:,1:]
m2 = (idx[:,1:]*m)>=a.shape[1]
m3 = m & m2
out = np.where(m3.any(1),b[r,idx[r,m3.argmax(1)+1]-a.shape[1]],np.nan)

方法 #4

对于一个优雅的方法,我们可以使用广播作为一种资源匮乏的方法-

m = (a[:,None]==b[:,:,None]).any(2)
out = np.where(m.any(1),b[np.arange(len(a)),m.argmax(1)],np.nan)

关于python - Numpy:有效地找到行式公共(public)元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56895349/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com