gpt4 book ai didi

c++ - 如何从libtorch输出中删除乘法器并显示最终结果?

转载 作者:行者123 更新时间:2023-12-02 09:59:14 25 4
gpt4 key购买 nike

当我尝试在屏幕上显示/打印一些张量时,我遇到类似于以下内容的东西,而不是得到最终结果,libtorch似乎用张数显示张量(即0.01*之类,如下所示):

offsets.shape: [1, 4, 46, 85]
probs.shape: [46, 85]
offsets: (1,1,.,.) =
0.01 *
0.1006 1.2322
-2.9587 -2.2280

(1,2,.,.) =
0.01 *
1.3772 1.3971
-1.2813 -0.8563

(1,3,.,.) =
0.01 *
6.2367 9.2561
3.5719 5.4744

(1,4,.,.) =
0.2901 0.2963
0.2618 0.2771
[ CPUFloatType{1,4,2,2} ]
probs: 0.0001 *
1.4593 1.0351
6.6782 4.9104
[ CPUFloatType{2,2} ]
如何禁用此行为并获得最终输出?我试图将其显式转换为浮点型,希望这将导致最终的输出被存储/显示,但这也不起作用。

最佳答案

基于libtorch的输出张量的源代码,在存储库中搜索“*”字符串后,事实证明,此“ pretty-print ”是在aten / src / ATen / core / Formatting.cpp转换单元中完成的。标尺和星号在此处前置:

static void printScale(std::ostream & stream, double scale) {
FormatGuard guard(stream);
stream << defaultfloat << scale << " *" << std::endl;
}
然后在Tensor的所有坐标上除以 scale:
if(scale != 1) {
printScale(stream, scale);
}
double* tensor_p = tensor.data_ptr<double>();
for(int64_t i = 0; i < tensor.size(0); i++) {
stream << std::setw(sz) << tensor_p[i]/scale << std::endl;
}
基于此翻译单元,这是完全不可配置的。
我想您这里有两个选择:
  • 调整功能并最小化编辑现有功能,以满足您的要求。
  • 删除(或添加#ifdef)Formatting.cpp中的Tensor的<<运算符重载,并提供您自己的实现。但是,在构建libtorch时,必须将其链接到包含该方法的实现的目标。

  • 但是,这两种选择都要求您更改第三方代码,我认为这很糟糕。

    关于c++ - 如何从libtorch输出中删除乘法器并显示最终结果?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63559554/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com