gpt4 book ai didi

python - 用另一个索引多维数组的多个维度 - NumPy/Python

转载 作者:太空宇宙 更新时间:2023-11-03 15:04:27 24 4
gpt4 key购买 nike

假设我有以下形式的张量:

import numpy as np
a = np.array([ [[1,2],
[3,4]],
[[5,6],
[7,3]]
])
# a.shape : (2,2,2) is a tensor containing 2x2 matrices
indices = np.argmax(a, axis=2)
#print indices
for mat in a:
max_i = np.argmax(mat,axis=1)
# Not really working I would like to
# change 4 in the first matrix to -1
# and 3 in the last to -1
mat[max_i] = -1

print a

现在我想做的是使用索引作为 a 上的掩码,用 say -1 替换每个最大元素。有没有一种 NumPy 的方法可以做到这一点?到目前为止,我所知道的只是使用 for 循环。

最佳答案

这是使用 linear indexing 的一种方法3D -

m,n,r = a.shape
offset = n*r*np.arange(m)[:,None] + r*np.arange(n)
np.put(a,indices + offset,-1)

sample 运行-

In [92]: a
Out[92]:
array([[[28, 59, 26, 70],
[57, 28, 71, 49],
[33, 6, 10, 90]],

[[24, 16, 83, 67],
[96, 16, 72, 56],
[74, 4, 71, 81]]])

In [93]: indices = np.argmax(a, axis=2)

In [94]: m,n,r = a.shape
...: offset = n*r*np.arange(m)[:,None] + r*np.arange(n)
...: np.put(a,indices + offset,-1)
...:

In [95]: a
Out[95]:
array([[[28, 59, 26, -1],
[57, 28, -1, 49],
[33, 6, 10, -1]],

[[24, 16, -1, 67],
[-1, 16, 72, 56],
[74, 4, 71, -1]]])

这里是另一种使用线性索引的方法,但是是2D -

m,n,r = a.shape
a.reshape(-1,r)[np.arange(m*n),indices.ravel()] = -1

运行时测试和验证输出 -

In [156]: def vectorized_app1(a,indices): # 3D linear indexing
...: m,n,r = a.shape
...: offset = n*r*np.arange(m)[:,None] + r*np.arange(n)
...: np.put(a,indices + offset,-1)
...:
...: def vectorized_app2(a,indices): # 2D linear indexing
...: m,n,r = a.shape
...: a.reshape(-1,r)[np.arange(m*n),indices.ravel()] = -1
...:

In [157]: # Generate random 3D array and the corresponding indices array
...: a = np.random.randint(0,99,(100,100,100))
...: indices = np.argmax(a, axis=2)
...:
...: # Make copies for feeding into functions
...: ac1 = a.copy()
...: ac2 = a.copy()
...:

In [158]: vectorized_app1(ac1,indices)

In [159]: vectorized_app2(ac2,indices)

In [160]: np.allclose(ac1,ac2)
Out[160]: True

In [161]: # Make copies for feeding into functions
...: ac1 = a.copy()
...: ac2 = a.copy()
...:

In [162]: %timeit vectorized_app1(ac1,indices)
1000 loops, best of 3: 311 µs per loop

In [163]: %timeit vectorized_app2(ac2,indices)
10000 loops, best of 3: 145 µs per loop

关于python - 用另一个索引多维数组的多个维度 - NumPy/Python,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34551061/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com