gpt4 book ai didi

python - Numpy:用同一行中其他元素的最大值替换一行中的每个元素

转载 作者:太空宇宙 更新时间:2023-11-03 13:30:03 24 4
gpt4 key购买 nike

假设我们有一个像这样的二维数组:

>>> a
array([[1, 1, 2],
[0, 2, 2],
[2, 2, 0],
[0, 2, 0]])

对于每一行,我想用同一行中其他两个元素中的最大值替换每个元素。

我已经找到了如何使用 numpy.amax 和恒等数组分别对每一列执行此操作,如下所示:

>>> np.amax(a*(1-np.eye(3)[0]), axis=1)
array([ 2., 2., 2., 2.])
>>> np.amax(a*(1-np.eye(3)[1]), axis=1)
array([ 2., 2., 2., 0.])
>>> np.amax(a*(1-np.eye(3)[2]), axis=1)
array([ 1., 2., 2., 2.])

但我想知道是否有一种方法可以避免 for 循环并直接获得结果,在这种情况下应该如下所示:

>>> numpy_magic(a)
array([[2, 2, 1],
[2, 2, 2],
[2, 2, 2],
[2, 0, 2]])

编辑:在控制台中玩了几个小时后,我终于想出了我正在寻找的解决方案。准备好接受一些令人兴奋的一行代码:

np.amax(a[[range(a.shape[0])]*a.shape[1],:][(np.eye(a.shape[1]) == 0)[:,[range(a.shape[1])*a.shape[0]]].reshape(a.shape[1],a.shape[0],a.shape[1])].reshape((a.shape[1],a.shape[0],a.shape[1]-1)),axis=2).transpose()
array([[2, 2, 1],
[2, 2, 2],
[2, 2, 2],
[2, 0, 2]])

Edit2:Paul 提出了一个更具可读性和更快的替代方案:

np.max(a[:, np.where(~np.identity(a.shape[1], dtype=bool))[1].reshape(a.shape[1], -1)], axis=-1)

在对这 3 个备选方案进行计时后,Paul 的两个解决方案在每种情况下都快 4 倍(我已经针对 200 行的 2、3 和 4 列进行了基准测试)。祝贺这些令人惊叹的代码!

最后编辑(抱歉):在用更快的 np.eye 替换 np.identity 之后,我们现在有了最快和最简洁的解决方案:

np.max(a[:, np.where(~np.eye(a.shape[1], dtype=bool))[1].reshape(a.shape[1], -1)], axis=-1)

最佳答案

这里有两种解决方案,一种专为 max 设计,另一种更通用,也适用于其他操作。

利用每一行中除了可能的一个最大值之外的所有元素都是整行的最大值这一事实,我们可以使用 argpartition 廉价地找到最大的两个元素的索引。然后在最大的位置,我们将第二大的值和其他所有地方的最大值放在一起。也适用于超过 3 列。

>>> a
array([[6, 0, 8, 8, 0, 4, 4, 5],
[3, 1, 5, 0, 9, 0, 3, 6],
[1, 6, 8, 3, 4, 7, 3, 7],
[2, 1, 6, 2, 9, 1, 8, 9],
[7, 3, 9, 5, 3, 7, 4, 3],
[3, 4, 3, 5, 8, 2, 2, 4],
[4, 1, 7, 9, 2, 5, 9, 6],
[5, 6, 8, 5, 5, 3, 3, 3]])
>>>
>>> M, N = a.shape
>>> result = np.empty_like(a)
>>> largest_two = np.argpartition(a, N-2, axis=-1)
>>> rng = np.arange(M)
>>> result[...] = a[rng, largest_two[:, -1], None]
>>> result[rng, largest_two[:, -1]] = a[rng, largest_two[:, -2]]>>>
>>> result
array([[8, 8, 8, 8, 8, 8, 8, 8],
[9, 9, 9, 9, 6, 9, 9, 9],
[8, 8, 7, 8, 8, 8, 8, 8],
[9, 9, 9, 9, 9, 9, 9, 9],
[9, 9, 7, 9, 9, 9, 9, 9],
[8, 8, 8, 8, 5, 8, 8, 8],
[9, 9, 9, 9, 9, 9, 9, 9],
[8, 8, 6, 8, 8, 8, 8, 8]])

此解决方案取决于最大值的特定属性。

例如,一个更通用的解决方案也适用于 sum 而不是 max。将 a 的两个副本粘合在一起(并排,而不是彼此重叠)。所以行类似于 a0 a1 a2 a3 a0 a1 a2 a3。对于索引 x,我们可以通过切片 [x+1:x+4] 得到除 ax 之外的所有索引。为了进行矢量化,我们使用 stride_tricks:

>>> a
array([[2, 6, 0],
[5, 0, 0],
[5, 0, 9],
[6, 4, 4],
[5, 0, 8],
[1, 7, 5],
[9, 7, 7],
[4, 4, 3]])
>>> M, N = a.shape
>>> aa = np.c_[a, a]
>>> ast = np.lib.stride_tricks.as_strided(aa, (M, N+1, N-1), aa.strides + aa.strides[1:])
>>> result = np.max(ast[:, 1:, :], axis=-1)
>>> result
array([[6, 2, 6],
[0, 5, 5],
[9, 9, 5],
[4, 6, 6],
[8, 8, 5],
[7, 5, 7],
[7, 9, 9],
[4, 4, 4]])

# use sum instead of max
>>> result = np.sum(ast[:, 1:, :], axis=-1)
>>> result
array([[ 6, 2, 8],
[ 0, 5, 5],
[ 9, 14, 5],
[ 8, 10, 10],
[ 8, 13, 5],
[12, 6, 8],
[14, 16, 16],
[ 7, 7, 8]])

关于python - Numpy:用同一行中其他元素的最大值替换一行中的每个元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48393431/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com