gpt4 book ai didi

c++ - PyTorch C++ 扩展 : Accessing data for Half Tensors

转载 作者:太空宇宙 更新时间:2023-11-04 12:31:11 25 4
gpt4 key购买 nike

我正在尝试使用 C++ 张量 API 为 PyTorch 编写 C++/CUDA 扩展,我希望我的代码能够同时使用 float32 和 float16(半精度)。我不确定如何访问来自 Python 的半张量的数据指针。

这是我对浮点张量的处理方式:

// Access data pointer for float Tensor A
torch::Tensor A;
float* ptr = A.data<float>();

这是我对半张量的尝试:

// CUDA float 16 type
// undefined symbol: _ZNK2at6Tensor4dataI6__halfEEPT_v
A.data<__half>();

// PyTorch float16 type
// error: no instance of function template "at::Tensor::data"
A.data<torch::ScalarType::Half>();

// Casting to __half*
// This compiles but throws and error if the requested pointer type doesn't match the Tensor type:
// RuntimeError: expected scalar type Float but found Half
(__half*)(A.data<float>());

我尝试查看 C++ api 源代码,但找不到任何其他看起来像 float16 类型的内容。

系统信息: python 3.6.2 torch 1.0.1

最佳答案

结果证明正确的类型是 at::Half

关于c++ - PyTorch C++ 扩展 : Accessing data for Half Tensors,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58613427/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com