gpt4 book ai didi

python - 如何在numpy中找到给定行值的列索引?

转载 作者:行者123 更新时间:2023-11-28 18:32:27 25 4
gpt4 key购买 nike

我想在 numpy 中执行一些相对简单的操作:

  • 如果行中有1,则返回包含1的列的索引+1。
  • 如果一行中有零个或多个1,则返回0。

但是我最终得到了一个相当复杂的代码:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

one_count = (predictions == 1).sum(1)
valid_rows_idx = np.where(one_count==1)

result = np.zeros(predictions.shape[0])
for idx in valid_rows_idx:
result[idx] = np.where(predictions[idx,:]==1)[1] + 1

如果我打印result,程序会打印[ 1. 0. 4. 0.] 这是期望的结果。

我想知道是否有更简单的方法使用 numpy 编写最后一行。

最佳答案

我不确定这是否更好,但您可以尝试使用 argmax为了那个原因。此外,您不需要使用 for 循环和 np.where 来获取有效索引:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

idx = (predictions == 1).sum(1) == 1
result = np.zeros(predictions.shape[0])
result[idx] = (predictions[idx]==1).argmax(axis=1) + 1

In [55]: result
Out[55]: array([ 1., 0., 4., 0.])

或者您可以使用 np.whereargmax 在一行中完成所有这些工作:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

In [72]: np.where((predictions==1).sum(1)==1, (predictions==1).argmax(axis=1)+1, 0)
Out[72]: array([1, 0, 4, 0])

关于python - 如何在numpy中找到给定行值的列索引?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35613034/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com