python - Efficient way to check for overlapping high-dimensional arrays between two ndarrays in Python

Reposted. Author: 太空宇宙. Updated: 2023-11-03 12:31:52

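For example, I have two ndarrays: train_dataset has shape (10000, 28, 28) and val_dataset has shape (2000, 28, 28).

Apart from iterating, is there an efficient way to use numpy array functions to find the overlap between the two ndarrays, i.e. the samples that appear in both?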
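For reference, here is a sketch of the plain iteration the question wants to avoid. `overlap_loop` is a hypothetical helper and does not come from the original post. It keys each sample by its raw bytes in a Python set, so it assumes both arrays share the same dtype. The array sizes are scaled down from the question's.

```python
import numpy as np

def overlap_loop(train, val):
    """Hypothetical loop-based baseline: True for each train sample that also occurs in val."""
    # tobytes() serialises each sample in C order, so equal samples give equal keys
    # (assuming train and val share the same dtype).
    val_keys = {sample.tobytes() for sample in val}
    return np.array([sample.tobytes() in val_keys for sample in train], dtype=bool)

rng = np.random.default_rng(0)
train_dataset = rng.standard_normal((1000, 28, 28))
# The first 50 validation samples are copied from the training set.
val_dataset = np.concatenate([train_dataset[:50], rng.standard_normal((150, 28, 28))])

mask = overlap_loop(train_dataset, val_dataset)
print(mask.sum())  # 50
```

Every sample still passes through the Python interpreter one at a time. The vectorised approach in the answer avoids that.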

Best answer

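A trick I learned from an excellent answer by Jaime is to use the np.void dtype to view each row of the input arrays as a single element. This lets you treat them as 1D arrays, which can then be passed to np.in1d or one of the other set routines.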
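Here is a minimal illustration of the np.void view on a toy array with made-up values. Each flattened row becomes one opaque element. Equality tests and NumPy's set routines then operate per row, comparing raw bytes.

```python
import numpy as np

# Three tiny 2x2 "images" (toy values); the first and the last are identical.
imgs = np.array([[[1, 2], [3, 4]],
                 [[5, 6], [7, 8]],
                 [[1, 2], [3, 4]]], dtype=np.int32)

# Flatten each image into one row and make sure the rows are C-contiguous.
rows = np.ascontiguousarray(imgs.reshape(imgs.shape[0], -1))

# An opaque np.void dtype exactly one row wide: 4 values * 4 bytes = 16 bytes.
t = np.dtype((np.void, rows.dtype.itemsize * rows.shape[1]))

# Reinterpret the same memory as one void element per image: shape (3, 1) -> (3,).
as_void = rows.view(t).ravel()

print(as_void.shape)                     # (3,)
print((as_void == as_void[0]).tolist())  # [True, False, True]
print(np.unique(as_void).size)           # 2
```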

import numpy as np

def find_overlap(A, B):

    if not A.dtype == B.dtype:
        raise TypeError("A and B must have the same dtype")
    if not A.shape[1:] == B.shape[1:]:
        raise ValueError("the shapes of A and B must be identical apart from "
                         "the row dimension")

    # reshape A and B to 2D arrays. force a copy if necessary in order to
    # ensure that they are C-contiguous.
    A = np.ascontiguousarray(A.reshape(A.shape[0], -1))
    B = np.ascontiguousarray(B.reshape(B.shape[0], -1))

    # void type that views each row in A and B as a single item
    t = np.dtype((np.void, A.dtype.itemsize * A.shape[1]))

    # use in1d to find rows in A that are also in B
    # (np.in1d is deprecated as of NumPy 2.0; np.isin is its replacement)
    return np.in1d(A.view(t), B.view(t))
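NumPy 2.0 deprecates np.in1d in favour of np.isin, so a version of the same function for current NumPy might look like the sketch below. The name `find_overlap_isin` is hypothetical. The logic is the answer's, with only the set routine swapped.

```python
import numpy as np

def find_overlap_isin(A, B):
    """Mask over the rows of A that also occur in B (hypothetical np.isin variant)."""
    if A.dtype != B.dtype:
        raise TypeError("A and B must have the same dtype")
    if A.shape[1:] != B.shape[1:]:
        raise ValueError("the shapes of A and B must be identical apart from "
                         "the row dimension")
    # Flatten each sample into one C-contiguous row (copies only if needed).
    A = np.ascontiguousarray(A.reshape(A.shape[0], -1))
    B = np.ascontiguousarray(B.reshape(B.shape[0], -1))
    # One opaque np.void element per row, as wide as the whole row in bytes.
    t = np.dtype((np.void, A.dtype.itemsize * A.shape[1]))
    # ravel() turns the (n, 1) views into 1-D arrays so the mask has shape (n,).
    return np.isin(A.view(t).ravel(), B.view(t).ravel())

rng = np.random.default_rng(0)
A = rng.standard_normal((1000, 28, 28))
dupe_idx = rng.choice(A.shape[0], size=200, replace=False)
B = A[dupe_idx]

mask = find_overlap_isin(A, B)
print(np.array_equal(np.flatnonzero(mask), np.sort(dupe_idx)))  # True
```

Rows are compared as raw bytes, in this variant and in the original find_overlap. As a result, 0.0 and -0.0 count as different values, while NaNs with identical bit patterns count as equal.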

For example:

gen = np.random.RandomState(0)

A = gen.randn(1000, 28, 28)
dupe_idx = gen.choice(A.shape[0], size=200, replace=False)
B = A[dupe_idx]

A_in_B = find_overlap(A, B)

print(np.all(np.where(A_in_B)[0] == np.sort(dupe_idx)))
# True

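This method is more memory-efficient than Divakar's, since it doesn't need to broadcast out to an (m, n, ...) boolean array. In fact, if A and B are row-major (C-contiguous), no copy is needed at all.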
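The no-copy claim can be checked with np.shares_memory. The sketch below repeats the reshaping steps of find_overlap on random data, with arbitrary array sizes:

```python
import numpy as np

A = np.random.default_rng(0).standard_normal((100, 28, 28))  # C-ordered (row-major)

# The same steps as find_overlap: reshape, ascontiguousarray, then the void view.
A2 = np.ascontiguousarray(A.reshape(A.shape[0], -1))
t = np.dtype((np.void, A2.dtype.itemsize * A2.shape[1]))
rows = A2.view(t)
print(np.shares_memory(A, rows))  # True: no data was copied

# A column-major (Fortran-ordered) array cannot be reshaped to C-ordered rows
# as a view, so reshape() has to copy it.
F = np.asfortranarray(A)
F2 = np.ascontiguousarray(F.reshape(F.shape[0], -1))
print(np.shares_memory(F, F2))  # False: F2 is a fresh copy
```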


For comparison, I slightly adapted Divakar's and B.M.'s solutions:

def divakar(A, B):
    # reshape (rather than assigning to .shape) so the caller's arrays are not modified
    A = A.reshape(A.shape[0], -1)
    B = B.reshape(B.shape[0], -1)
    return (B[:, None] == A).all(axis=2).any(0)

def bm(A, B):
    t = 'S' + str(A.size // A.shape[0] * A.dtype.itemsize)
    ma = np.frombuffer(np.ascontiguousarray(A), t)
    mb = np.frombuffer(np.ascontiguousarray(B), t)
    return (mb[:, None] == ma).any(0)

Benchmarks:

In [1]: na = 1000; nb = 200; rowshape = 28, 28

In [2]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
   ....: divakar(A, B)
   ....:
1 loops, best of 3: 244 ms per loop

In [3]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
   ....: bm(A, B)
   ....:
100 loops, best of 3: 2.81 ms per loop

In [4]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
   ....: find_overlap(A, B)
   ....:
100 loops, best of 3: 15 ms per loop

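As you can see, for small n B.M.'s solution is faster than mine (2.81 ms vs. 15 ms per loop here). However, np.in1d scales better than testing all pairs of elements for equality, with O(n log n) rather than O(n²) complexity: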
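To see where the O(n log n) behaviour comes from, here is a sketch of a sort-based membership test on the same np.void rows. It sorts B once, then binary-searches each row of A with np.searchsorted. This is a swapped-in illustration of the idea, not np.in1d's actual implementation, and `rows_in_sorted` is a hypothetical name. It assumes A and B share a dtype and a row shape.

```python
import numpy as np

def rows_in_sorted(A, B):
    """Hypothetical sort-based test: mask over the rows of A that also occur in B.

    Sorting B costs O(m log m) row comparisons and the binary searches
    O(n log m), instead of the O(n * m) of comparing every pair of rows.
    Assumes A and B have the same dtype and the same shape after axis 0.
    """
    row_len = int(np.prod(A.shape[1:]))
    A2 = np.ascontiguousarray(A.reshape(-1, row_len))
    B2 = np.ascontiguousarray(B.reshape(-1, row_len))
    t = np.dtype((np.void, A2.dtype.itemsize * row_len))
    a = A2.view(t).ravel()
    b = np.sort(B2.view(t).ravel())  # unstructured voids sort by their raw bytes
    if b.size == 0:
        return np.zeros(a.size, dtype=bool)
    # Position of the first element of b that is >= each row of a ...
    pos = np.searchsorted(b, a)
    # ... clamped, because rows larger than all of b get position len(b).
    pos = np.minimum(pos, b.size - 1)
    # A row of A is in B exactly when the element at that position equals it.
    return b[pos] == a

rng = np.random.default_rng(1)
A = rng.standard_normal((500, 28, 28))
idx = rng.choice(500, size=100, replace=False)
mask = rows_in_sorted(A, A[idx])
print(np.array_equal(np.flatnonzero(mask), np.sort(idx)))  # True
```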

In [5]: na = 10000; nb = 2000; rowshape = 28, 28

In [6]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
   ....: bm(A, B)
   ....:
1 loops, best of 3: 271 ms per loop

In [7]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
   ....: find_overlap(A, B)
   ....:
10 loops, best of 3: 123 ms per loop

For arrays of this size, Divakar's solution was intractable on my laptop, since it requires generating a 15 GB intermediate array and I only have 8 GB of RAM.
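The 15 GB figure follows from the shape of the broadcast comparison `B[:, None] == A` in divakar. It stores one byte per boolean for every pair of rows and every one of the 28 × 28 values. A quick back-of-the-envelope check:

```python
import numpy as np

# Shape of B[:, None] == A in divakar() for the second benchmark:
# (nb, na, 28 * 28) booleans, each stored in one byte.
na, nb, row_len = 10_000, 2_000, 28 * 28
n_bytes = nb * na * row_len * np.dtype(bool).itemsize

print(n_bytes)        # 15680000000
print(n_bytes / 1e9)  # 15.68 (decimal gigabytes)
```

That is about 14.6 GiB, well beyond 8 GB of RAM.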

A similar question, "python - Efficient way to check for overlapping high-dimensional arrays between two ndarrays in Python", can be found on Stack Overflow: https://stackoverflow.com/questions/34980550/
