gpt4 book ai didi

c++ - PyTorch for C++ 中的按行元素索引

转载 作者:搜寻专家 更新时间:2023-10-31 02:04:26 24 4
gpt4 key购买 nike

我正在使用 C++ frontend for PyTorch并且正在努力解决一个相对基本的索引问题。

我有一个 8 by 6 张量,如下所示:

[ Variable[CUDAFloatType]{8,6} ] 
0 1 2 3 4 5
0 1.7107e-14 4.0448e-17 4.9708e-06 1.1664e-08 9.9999e-01 2.1857e-20
1 1.8288e-14 5.9356e-17 5.3042e-06 1.2369e-08 9.9999e-01 2.4799e-20
2 2.6828e-04 9.0390e-18 1.7517e-02 1.0529e-03 9.8116e-01 6.7854e-26
3 5.7521e-10 3.1037e-11 1.5021e-03 1.2304e-06 9.9850e-01 1.4888e-17
4 1.7811e-13 1.8383e-15 1.6733e-05 3.8466e-08 9.9998e-01 5.2815e-20
5 9.6191e-06 2.6217e-23 3.1345e-02 2.3024e-04 9.6842e-01 2.9435e-34
6 2.2653e-04 8.4642e-18 1.6085e-02 9.7405e-04 9.8271e-01 6.3059e-26
7 3.8951e-14 2.9903e-16 8.3518e-06 1.7974e-08 9.9999e-01 3.6993e-20

我有另一个张量,其中只有 8 个元素,例如:

[ Variable[CUDALongType]{8} ] 
0
3
4
4
4
4
4
4

我想使用第二张量索引我的第一个张量的行以产生:

        0           
0 1.7107e-14
1 1.2369e-08
2 9.8116e-01
3 9.9850e-01
4 9.9998e-01
5 9.6842e-01
6 9.8271e-01
7 9.9999e-01

我尝试了几种不同的方法,包括 index_select但它似乎产生了与输入具有相同维度的输出 (8x6)。

在 Python 中,我认为我可以使用 Python 的内置索引进行索引,如下所述:https://github.com/pytorch/pytorch/issues/1080

不幸的是,在 C++ 中,我只能用标量(零维张量)索引张量,所以我认为这种方法在这里不适合我。

如何在不使用循环的情况下实现我想要的结果?

最佳答案

事实证明,您可以通过几种不同的方式做到这一点。一种使用 gather,一种使用 index。来自PyTorch discussions我问了同样的问题:

使用 torch::gather

auto x = torch::randn({8, 6});
int64_t idx_data[8] = { 0, 3, 4, 4, 4, 4, 4, 4 };
auto idx = x.type().toScalarType(torch::kLong).tensorFromBlob(idx_data, 8);
auto result = x.gather(1, idx.unsqueeze(1));

使用特定于 C++ 的 torch::index

auto x = torch::randn({8, 6});
int64_t idx_data[8] = { 0, 3, 4, 4, 4, 4, 4, 4 };
auto idx = x.type().toScalarType(torch::kLong).tensorFromBlob(idx_data, 8);
auto rows = torch::arange(0, x.size(0), torch::kLong);
auto result = x.index({rows, idx});

关于c++ - PyTorch for C++ 中的按行元素索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53507039/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com