gpt4 book ai didi

python - ndarrays的条件过滤

转载 作者:行者123 更新时间:2023-12-03 16:20:40 25 4
gpt4 key购买 nike

假设我有以下数组:

Input = np.array([[[[17.63,  0.  , -0.71, 29.03],
[17.63, -0.09, 0.71, 56.12],
[ 0.17, 1.24, -2.04, 18.49],
[ 1.41, -0.8 , 0.51, 11.85],
[ 0.61, -0.29, 0.15, 36.75]]],


[[[ 0.32, -0.14, 0.39, 24.52],
[ 0.18, 0.25, -0.38, 18.08],
[ 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ],
[ 0.43, 0. , 0.3 , 0. ]]],


[[[ 0.75, -0.38, 0.65, 19.51],
[ 0.37, 0.27, 0.52, 24.27],
[ 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ]]]])

Input.shape
(3, 1, 5, 4)

与此 Input 数组一起是所有输入的相应 Label 数组,因此:

Label = np.array([0, 1, 2])

Label.shape
(3,)

我需要一些方法来检查 Input 的所有嵌套数组,以仅选择具有足够数据点的数组。

我的意思是我想要一种方法来消除(或者我应该说删除)最后 3 行的条目全为零的所有数组。在执行此操作的同时,消除该数组的相应 Label

预期输出:

Input_filtered
array([[[[17.63, 0. , -0.71, 29.03],
[17.63, -0.09, 0.71, 56.12],
[ 0.17, 1.24, -2.04, 18.49],
[ 1.41, -0.8 , 0.51, 11.85],
[ 0.61, -0.29, 0.15, 36.75]]],


[[[ 0.32, -0.14, 0.39, 24.52],
[ 0.18, 0.25, -0.38, 18.08],
[ 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ],
[ 0.43, 0. , 0.3 , 0. ]]]])

Label_filtered
array([0, 1])

我需要什么技巧?

最佳答案

您应该只能使用矢量化 numpy 命令来执行此操作。

filter_ = np.any(Input[:, :, -3:], axis=(1, 2, 3))
labels_filtered = Label[filter_]
inputs_filtered = Input[[filter_]]

对于您提供的示例集,与 anon01 的解决方案相比,每个循环产生 4.95 µs ± 9.69 ns(每个循环 100000 个循环),而 anon01 的解决方案每个循环产生 17.1 µs ± 111 ns(每个循环 100000 个循环)。改进应该在更大的阵列上更加显着。

如果您的数据具有不同的维度,您可以更改轴参数。对于任意数量的轴,它可能如下所示:

filter_ = np.any(Input[:, :, -3:], axis=tuple(range(1, Input.ndim)))

关于python - ndarrays的条件过滤,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63402552/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com