gpt4 book ai didi

c++ - PyTorch C++ API中的 `randperm`不应该返回默认类型为int的张量吗?

转载 作者:行者123 更新时间:2023-12-02 09:59:19 27 4
gpt4 key购买 nike

当我尝试使用C++ PyTorch API使用randperm生成置换整数索引的列表时,所得张量具有CPUFloatType{10}的元素类型而不是整数类型:

int N_SAMPLES = 10;               
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES);
cout << shuffled_indices << endl;
退货
 9
3
8
6
2
5
4
7
1
0
[ CPUFloatType{10} ]
不能用于张量索引,因为元素类型是float而不是整数类型。当tryig使用 my_tensor.index(shuffled_indices)我得到
terminate called after throwing an instance of 'c10::IndexError'
what(): tensors used as indices must be long, byte or bool tensors
环境:
  • python-pytorch,Arch Linux上的1.6.0-2版本
  • g++(GCC)10.1.0

  • 为什么会这样?

    最佳答案

    这是因为用torch创建的任何张量的默认类型始终是float。否则,必须使用TensorOptions参数struct指定它:

    int N_SAMPLES = 10;               
    torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES, torch::TensorOptions().dtype(at::kLong));
    cout << shuffled_indices.dtype() << endl;
    >>> long

    关于c++ - PyTorch C++ API中的 `randperm`不应该返回默认类型为int的张量吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63500671/

    27 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com