gpt4 book ai didi

python - 使用 numpy.where 防止越界

转载 作者:行者123 更新时间:2023-12-01 01:36:06 25 4
gpt4 key购买 nike

我正在尝试根据索引数组查找数组中的值。该索引数组可以包含可能超出范围的索引。在这种情况下,我想返回一个特定值(此处为 0)。

(我可以使用 for 循环,但这太慢了。)

所以我正在这样做:

data = np.arange(1000).reshape(10, 10, 10)
i = np.arange(9).reshape(3, 3)
i[0, 0] = 10
condition = (i[:, 0] < 10) & (i[:, 1] < 10) & (i[:, 2] < 10)
values = np.where(condition, data[i[:, 0], i[:, 1], i[:, 2]], 0)

但是我仍然收到越界错误:

IndexError: index 10 is out of bounds for axis 0 with size 10

我猜这是因为第二个参数不是延迟计算的,而是在函数调用之前计算的。

numpy 中是否有解决方案可以根据条件访问数组但仍保留顺序?通过保留顺序,我的意思是我无法首先过滤数组,因为我可能会失去最终结果中的顺序。最后,在该特定示例中,当索引超出范围时,我仍然希望值数组包含 0。所以最终的结果是:

array([ 0, 345, 678])

最佳答案

索引数组的每一列都存储每个维度的索引。我们可以生成有效掩码(在边界范围内)并将其中的无效掩码设置为零。即越界情况将通过 [0,0,0] 进行索引,然后让数组通过此修改后的版本进行索引,最后再次使用掩码来重置无效的情况,就像这样 -

shp = data.shape
valid_mask = (i < shp).all(1)
i[~valid_mask] = 0
out = np.where(valid_mask,data[tuple(i.T)],0)

在不更改 i 的情况下对其进行修改的紧凑版本将是 -

out = np.where(valid_mask,data[tuple(np.where(valid_mask,i.T,0))],0)

关于python - 使用 numpy.where 防止越界,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52381171/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com