gpt4 book ai didi

python - Numpy intersect1d 与矩阵作为元素的数组

转载 作者:太空宇宙 更新时间:2023-11-03 14:57:25 27 4
gpt4 key购买 nike

我有两个数组,一个是 (200000, 28, 28) 的形状,另一个是 (10000, 28, 28) 的形状,所以实际上是两个以矩阵为元素的数组。现在我想计算并获取在两个数组中重叠的所有元素(格式为 (N, 28, 28))。对于普通的 for 循环,它会变慢,所以我尝试使用 numpys intersect1d 方法,但我不知道如何将它应用于这种类型的数组。

最佳答案

使用来自 this question about unique rows 的方法

def intersect_along_first_axis(a, b):
# check that casting to void will create equal size elements
assert a.shape[1:] == b.shape[1:]
assert a.dtype == b.dtype

# compute dtypes
void_dt = np.dtype((np.void, a.dtype.itemsize * np.prod(a.shape[1:])))
orig_dt = np.dtype((a.dtype, a.shape[1:]))

# convert to 1d void arrays
a = np.ascontiguousarray(a)
b = np.ascontiguousarray(b)
a_void = a.reshape(a.shape[0], -1).view(void_dt)
b_void = b.reshape(b.shape[0], -1).view(void_dt)

# intersect, then convert back
return np.intersect1d(b_void, a_void).view(orig_dt)

请注意,使用 void 对 float 不安全,因为它会导致 -0 不等于 0

关于python - Numpy intersect1d 与矩阵作为元素的数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41416626/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com