gpt4 book ai didi

arrays - 查找数组中最大值索引的最快方法是什么?

转载 作者:行者123 更新时间:2023-11-29 08:17:34 24 4
gpt4 key购买 nike

我有一个 f32 类型的二维数组(来自 ndarray::ArrayView2),我想找到每一行中最大值的索引,并将索引值到另一个数组。

Python 中的等价物是这样的:

import numpy as np

for i in range (0, max_val, batch_size):
sims = xp.dot(batch, vectors.T)
# sims is the dot product of batch and vectors.T
# the shape is, for example, (1024, 10000)

best_rows[i: i+batch_size] = sims.argmax(axis = 1)

在 Python 中,函数 .argmax 非常快,但我在 Rust 中没有看到类似的函数。最快的方法是什么?

最佳答案

考虑一般 Ord 类型的简单情况:答案会略有不同,具体取决于您是否知道值是 Copy,但这是代码:

fn position_max_copy<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
slice.iter().enumerate().max_by_key(|(_, &value)| value).map(|(idx, _)| idx)
}

fn position_max<T: Ord>(slice: &[T]) -> Option<usize> {
slice.iter().enumerate().max_by(|(_, value0), (_, value1)| value0.cmp(value1)).map(|(idx, _)| idx)
}

基本思想是我们将数组中的每个项目(实际上,一个切片——无论它是 Vec 还是数组或更奇特的东西都没有关系)与其索引配对,使用 std::iter::Iterator 函数只根据值(不是索引)找到最大值,然后只返回索引。如果切片为空 None 将被返回。根据文档,将返回最右边的索引;如果您需要最左边的,请执行 rev() after enumerate()

rev()enumerate()max_by_key()max_by() 记录在 here ; slice::iter() 被记录为 here(但作为 rust 开发者,在没有文档的情况下,它需要出现在你要记忆的 list 上); mapOption::map() 记录的 here(同上)。哦,cmpOrd::cmp 但大多数时候您可以使用不需要它的 Copy 版本(例如,如果你在比较整数)。


现在问题来了:f32 不是 Ord 因为 IEEE float 的工作方式。大多数语言都忽略了这一点并且有微妙的错误算法。在 Ord 上提供总订单的最受欢迎的箱子(通过声明所有 NaN 都相等,并且大于所有数字)似乎是 ordered-float 。假设它实现正确,它应该非常轻量级。它确实引入了 num_traits,但这是最流行的数字库的一部分,因此很可能已经被其他依赖项引入。

在这种情况下,您可以通过将 ordered_float::OrderedFloat(元组类型的“构造函数”)映射到切片迭代器(slice.iter().map( ordered_float::OrderedFloat)).由于您只需要最大元素的位置,因此之后无需提取 f32。

关于arrays - 查找数组中最大值索引的最快方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57813951/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com