gpt4 book ai didi

python - 获取numpy中多维切片沿特定轴的索引

转载 作者:太空宇宙 更新时间:2023-11-03 14:53:21 27 4
gpt4 key购买 nike

我有这样的代码:

import numpy as np    
b = np.random.choice([0, 1], size=(12, 10, 2), p=[0.5, 0.5]) > 0.5
a = np.ones((12, 10, 2, 6, 4))

a = a[b]
print(a.shape)

我想知道每个选择来自第 1 轴(上面的 10 个)的哪个位置,例如,

a[0, 0, 0] = 0(来自 b[:, 0, :])

a[0, 0, 1] = 3(来自 b[:, 3, :])

a[6, 3, 1] = 1(来自 b[:, 1, :])

等等

我该怎么做?

这是没有随机选择的简化版本:

import numpy as np

b = np.array([[0, 1], [1, 1]]) > 0.5
a = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

a = a[b] #gives [[3,4],[5,6],[7,8]]

# Desired result: 1,0,1 as each element came from that index of axis 1 of b

[1, 0, 1] # The index along the axis of the last dim of b for each selection in a

最佳答案

and I want to know which position along axis 1 (the 10 above) that each selection came from, ...

当您执行 a = a[b] 时,a 的新元素将是与之前的 True 值相关的元素。在您的随机 b 数组中。为此,您可以使用 numpy.where() methodb 上了解哪些在哪里,如下所示:

import numpy as np

b = np.random.choice([0, 1], size=(12, 10, 2), p=[0.5, 0.5]) > 0.5 #random choice
a = np.ones((12, 10, 2, 6, 4))

a = a[b] #obtain those randomly selected items
print(a.shape)

indexes = np.where(b==True)
print(indexes[1]) #the axis 1 you desire

请注意,如果您希望获得其他 b 轴(例如轴i),您应该像 indexes[i ]。另请注意,这每次都会给出不同的值,因为它是随机的。

但是,用更简单的示例对其进行测试,我们也得到了所需的 [1,0,1]:

import numpy as np
b = np.array([[0, 1], [1, 1]]) > 0.5
a = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

a = a[b] #gives result [[3,4], [5,6], [7,8]], so they are 1,0,1
print(a.shape) #gives (3, 2)

indexes = np.where(b==True)
print(indexes[1]) #gives us the desired 1,0,1

关于python - 获取numpy中多维切片沿特定轴的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45764751/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com