gpt4 book ai didi

python - numpy索引解释 ndarray[(4, 2), (5, 3)]

转载 作者:行者123 更新时间:2023-12-04 01:09:14 25 4
gpt4 key购买 nike

问题

请帮助理解将元组 (i, j) 编入 ndarray 的 Numpy 索引的设计决策或合理性。

背景

当索引是单个元组(4, 2)时,则(i=row, j=column)。

shape = (6, 7)
X = np.zeros(shape, dtype=int)
X[(4, 2)] = 1
X[(5, 3)] = 1
print("X is :\n{}\n".format(X))
---
X is :
[[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 0 0 0 0] <--- (4, 2)
[0 0 0 1 0 0 0]] <--- (5, 3)

但是,当索引是多个元组 (4, 2), (5, 3) 时,则 (i=row, j=row) for (4, 2) and (i=column, j=column) for (5, 3).

shape = (6, 7)
Y = np.zeros(shape, dtype=int)
Y[(4, 2), (5, 3)] = 1
print("Y is :\n{}\n".format(Y))
---
Y is :
[[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 1 0 0 0] <--- (2, 3)
[0 0 0 0 0 0 0]
[0 0 0 0 0 1 0] <--- (4, 5)
[0 0 0 0 0 0 0]]

It means you are constructing a 2d array R, such that R=A[B, C].This means that the value forrij=abijcij.

So it means that the item located at R[0,0] is the item in A withas row index B[0,0] and as column index C[0,0]. The item R[0,1]is the item in A with row index B[0,1] and as column indexC[0,1], etc.

multi_index: A tuple of integer arrays, one array for each dimension.

为什么不总是(i=行,j=列)?如果一直是(i=row, j=column)会怎么样?


已更新

根据 Akshay 和@DaniMesejo 的回答,理解:

X[
(4), # dimension 2 indices with only 1 element
(2) # dimension 1 indices with only 1 element
] = 1

Y[
(4, 2, ...), # dimension 2 indices
(5, 3, ...) # dimension 1 indices (dimension 0 is e.g. np.array(3) whose shape is (), in my understanding)
] = 1

最佳答案

很容易理解它是如何工作的(以及这个设计决策背后的动机)。

Numpy 将其 ndarray 存储为连续的内存块。每个元素在前一个元素之后每隔 n 个字节按顺序存储。

(引用自 excellent SO post 的图片)

所以如果你的 3D 数组看起来像这样 -

enter image description here

然后在内存中将其存储为 -

enter image description here

当检索一个元素(或一个元素 block )时,NumPy 计算有多少 strides (bytes) 需要遍历得到下一个元素in that direction/axis .所以,对于上面的例子,axis=2它必须遍历 8 个字节(取决于 datatype )但对于 axis=1它必须遍历 8*4字节,和 axis=0它需要8*8字节。

考虑到这一点,让我们看看您要尝试做什么。

print(X)
print(X.strides)
[[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 0 0 0 0]
[0 0 0 1 0 0 0]]

#Strides (bytes) required to traverse in each axis.
(56, 8)

对于您的数组,获取 axis=0 中的下一个元素,我们需要遍历56 bytes ,以及 axis=1 中的下一个元素, 我们需要 8 bytes .

当您索引 (4,2) 时, NumPy 正在运行 56*4 axis=0 中的字节数和 8*2 axis=1 中的字节数访问它。同样,如果你想访问 (4,2)(5,3) ,它将必须访问 56*(4,5)axis=08*(2,3)axis=1 .

这就是设计之所以如此的原因,因为它与 NumPy 实际上使用 strides 索引元素的方式一致。 .

X[(axis0_indices), (axis1_indices), ..]

X[(4, 5), (2, 3)] #(row indices), (column indices)
array([1, 1])

通过这种设计,也可以轻松扩展到更高维度的张量(例如 8 维数组)! 如果您分别提及每个索引元组,则需要元素 * 计算的维数才能获取这些元组。使用这种设计时,它可以将步幅值广播到每个轴的元组并更快地获取这些值!

关于python - numpy索引解释 ndarray[(4, 2), (5, 3)],我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65389298/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com