gpt4 book ai didi

python - 通过 meshgrid 查找成对的 numpy 数组的所有组合

转载 作者:太空宇宙 更新时间:2023-11-04 09:54:49 25 4
gpt4 key购买 nike

假设我有一个由值对组成的 numpy 数组。我想找到这些对的所有组合,而不会将它们拆开。特别是,我希望为此找到一个 numpy.meshgrid 解决方案。

想象一个构造如下的数组:

ab = np.array([[1,10], [2,20], [3,30], [4,40]])

那么我想要的输出是

>>> out: ([1,10], [2,20])
([1,10], [3,30])
([1,10], [4,40])
([2,20], [3,30])
([2,20], [4,40])
([3,30], [4,40])

输出可以是np.array,也可以是元组(之后我可以进行相应的转换)。请注意我的结果中如何省略重复项,忽略我的夫妇的顺序(如果 [[1,10], [2,20]] 已经存在,我不想要 [[2,20], [1,10]] 在我的输出中)。对于实际情况,ab 的大小为 30,000,因此速度是另一个问题。

这就是我首先尝试 meshgrid 的原因。对于单个值的简单情况,这很容易完成(但是,仍然有重复项):

a = np.array([1,2,3,4])
mesh = np.array(np.meshgrid(a,a)).T.reshape(-1,2)
>>> out: [[1 1]
[1 2]
[1 3]
[1 4]
[2 1]
[...]
[4 4]]

但对于我的双子,我的尝试

mesh = np.array(np.meshgrid(ab,ab)).T

给我

[[[ 1  1]
[ 1 10]
[ 1 2]
[ 1 20]
[ 1 3]
[ 1 30]
[ 1 4]
[ 1 40]]

[[10 1]
[10 10]
[10 2]
[10 20]
...
[40 3]
[40 30]
[40 4]
[40 40]]]

换句话说:meshgrid 打破了我的对。我认为解决方案就在附近,但我自己想不出。感谢任何帮助,谢谢!

最佳答案

不要认为 meshgrid 会起作用,因为它会创建所有可能的组合(稍后会过滤掉)。为了解决这个问题,可以提出两种方法。

方法 #1

我们可以获取那些没有重复的成对组合的行索引,然后简单地对行进行索引以获得所需的输出,就像这样 -

In [99]: r,c = np.triu_indices(len(ab),1)

In [100]: np.hstack(( ab[r], ab[c] ))
Out[100]:
array([[ 1, 10, 2, 20],
[ 1, 10, 3, 30],
[ 1, 10, 4, 40],
[ 2, 20, 3, 30],
[ 2, 20, 4, 40],
[ 3, 30, 4, 40]])

要获得所需的 3D 数组输出,请沿第二个轴堆叠 -

In [115]: np.stack(( ab[r], ab[c] ), axis=1)
Out[115]:
array([[[ 1, 10],
[ 2, 20]],

[[ 1, 10],
[ 3, 30]],

[[ 1, 10],
[ 4, 40]],

[[ 2, 20],
[ 3, 30]],

[[ 2, 20],
[ 4, 40]],

[[ 3, 30],
[ 4, 40]]])

作为函数:

def pairwise_combs1(ab):
r,c = np.triu_indices(len(ab),1)
return np.stack(( ab[r], ab[c] ), axis=1)

方法 #2 另一种使用切片数组初始化 的方法旨在提高内存效率和性能 -

def pairwise_combs2(ab):
n = len(ab)
N = n*(n-1)//2
out = np.empty((N,2,2),dtype=ab.dtype)
idx = np.concatenate(( [0], np.arange(n-1,0,-1).cumsum() ))
start, stop = idx[:-1], idx[1:]
for j,i in enumerate(range(n-1)):
out[start[j]:stop[j],0] = ab[j]
out[start[j]:stop[j],1] = ab[j+1:]
return out

运行时测试

In [166]: ab = np.random.randint(0,9,(1000,2))

In [167]: %timeit pairwise_combs1(ab)
10 loops, best of 3: 20 ms per loop

In [168]: %timeit pairwise_combs2(ab)
100 loops, best of 3: 6.25 ms per loop

In [169]: np.allclose(pairwise_combs1(ab), pairwise_combs2(ab))
Out[169]: True

关于python - 通过 meshgrid 查找成对的 numpy 数组的所有组合,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46339926/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com