gpt4 book ai didi

python - 沿 NumPy 数组的轴计算唯一元素

转载 作者:太空狗 更新时间:2023-10-30 01:05:12 24 4
gpt4 key购买 nike

我有一个像这样的三维数组

A=np.array([[[1,1],[1,0]],[[1,2],[1,0]],[[1,0],[0,0]]])

现在我想获得一个在给定位置具有非零值的数组,前提是该位置仅出现唯一的非零值(或零)。如果在该位置仅出现零或多个非零值,则它应该为零。对于上面的例子,我想

[[1,0],[1,0]]

自从

  • A[:,0,0]中只有1
  • A[:,0,1]中有012,所以更多大于一个非零值
  • A[:,1,0]中有01,所以保留1
  • A[:,1,1]中只有0

我可以通过 np.count_nonzero(A, axis=0) 找到有多少非零元素,但我想保留 1 2,即使有几个。我查看了 np.unique,但它似乎不支持我想做的事情。

理想情况下,我想要一个像 np.count_unique(A, axis=0) 这样的函数,它会返回一个原始形状的数组,例如[[1, 3],[2, 1]],所以我可以检查是否出现 3 个或更多个,然后忽略那个位置。


我所能想出的只是一个列表理解迭代我想要获得的

[[len(np.unique(A[:, i, j])) for j in range(A.shape[2])] for i in range(A.shape[1])]

还有其他想法吗?

最佳答案

对于第二个任务,您可以使用 np.diff 保持在 numpy 级别。

def diffcount(A):
B=A.copy()
B.sort(axis=0)
C=np.diff(B,axis=0)>0
D=C.sum(axis=0)+1
return D

# [[1 3]
# [2 1]]

在大数组上似乎更快一些:

In [62]: A=np.random.randint(0,100,(100,100,100))

In [63]: %timeit diffcount(A)
46.8 ms ± 769 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [64]: timeit [[len(np.unique(A[:, i, j])) for j in range(A.shape[2])]\
for i in range(A.shape[1])]
149 ms ± 700 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

最后统计unique比排序简单,一个ln(A.shape[0])因子就可以赢了。

赢得这个因素的一种方法是使用集合机制:

In [81]: %timeit np.apply_along_axis(lambda a:len(set(a)),axis=0,A) 
183 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

不幸的是,这并没有更快。

另一种方法是手动完成:

def countunique(A,Amax):
res=np.empty(A.shape[1:],A.dtype)
c=np.empty(Amax+1,A.dtype)
for i in range(A.shape[1]):
for j in range(A.shape[2]):
T=A[:,i,j]
for k in range(c.size): c[k]=0
for x in T:
c[x]=1
res[i,j]= c.sum()
return res

在 python 级别:

In [70]: %timeit countunique(A,100)
429 ms ± 18.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

这对于纯 Python 方法来说还算不错。然后只需使用 numba 将此代码移动到低级别:

import numba    
countunique2=numba.jit(countunique)

In [71]: %timeit countunique2(A,100)
3.63 ms ± 70.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

这将很难改进很多。

关于python - 沿 NumPy 数组的轴计算唯一元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46893369/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com