gpt4 book ai didi

python - 从一维张量中提取前 k 个值索引

转载 作者:行者123 更新时间:2023-12-04 20:59:38 24 4
gpt4 key购买 nike

给定 Torch 中的一维张量 ( torch.Tensor ),其中包含可以比较的值(比如浮点数),我们如何提取该张量中前 k 个值的索引?
除了蛮力方法,我正在寻找 Torch/lua 提供的一些 API 调用,它可以有效地执行此任务。

最佳答案

截至拉取请求 #496 Torch 现在包含一个名为 torch.topk 的内置 API。 .例子:

> t = torch.Tensor{9, 1, 8, 2, 7, 3, 6, 4, 5}

-- obtain the 3 smallest elements
> res = t:topk(3)
> print(res)
1
2
3
[torch.DoubleTensor of size 3]

-- you can also get the indices in addition
> res, ind = t:topk(3)
> print(ind)
2
4
6
[torch.LongTensor of size 3]

-- alternatively you can obtain the k largest elements as follow
-- (see the API documentation for more details)
> res = t:topk(3, true)
> print(res)
9
8
7
[torch.DoubleTensor of size 3]

在撰写本文时,CPU 实现遵循 sort and narrow approach (有计划在 future 改进它)。话虽如此,目前正在为 cutorch 优化 GPU 实现 reviewed .

关于python - 从一维张量中提取前 k 个值索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34750268/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com