gpt4 book ai didi

c++ - 在割炬C++中创建BoolTensor蒙版

转载 作者:行者123 更新时间:2023-12-01 14:57:14 24 4
gpt4 key购买 nike

我正在尝试创建BoolTensor类型的C++中的火炬 mask 。维度n中的第一个one元素需要为False,其余元素需要为True
这是我的尝试,但我不知道这是否正确(大小是元素数):

src_mask = torch::BoolTensor({6, 1});
src_mask[:size,:] = 0;
src_mask[size:,:] = 1;

最佳答案

我不确定在这里确切地了解您的目标,因此这是我将伪代码转换为C++的最佳尝试。
首先,使用libtorch,您可以通过torch::TensorOptions结构声明张量的类型(类型名称以小写k开头)
其次,借助torch::Tensor::slice函数,可以进行类似python的 slice (请参阅herethere)。
最后,您会得到类似:

// Creates a tensor of boolean, initially all ones
auto options = torch::TensorOptions().dtype(torch::kBool));
torch::Tensor bool_tensor = torch::ones({6,1}, options);
// Set the slice to 0
int size = 3;
bool_tensor.slice(/*dim=*/0, /*start=*/0, /*end=*/size) = 0;

std::cout << bool_tensor << std::endl;
请不要因为这会将第一个 size行设置为0。我假设这就是“x维度中的第一个元素”的意思。

关于c++ - 在割炬C++中创建BoolTensor蒙版,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63186307/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com