gpt4 book ai didi

c++ - 如何在 C++ 中设置输入张量的值?

转载 作者:塔克拉玛干 更新时间:2023-11-03 00:23:28 25 4
gpt4 key购买 nike

我正在尝试通过 ios 上的预训练模型运行示例。 session->Run() 根据我的理解将张量作为输入。我已经初始化了一个张量,但我该如何设置它的值呢?我没有太多使用 C++ 的经验。

我已经成功创建了一个测试模型,它接受形状为 {1, 1, 10} 的 3 维张量。

我从 Tensorflow 的简单示例中提取了以下代码行来创建输入张量。

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/ios_examples/simple/RunModelViewController.mm#L189

tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1,1,10}));

从这里开始,我不知道如何设置 input_tensor 的数据。我想将张量设置为 {{{.0, .1, .2, .3, .4, .5, .6, .7, .8, .9}}}

最佳答案

我有一个类似的问题,并试图在 C++ 中为在 Python 中训练的模型设置张量输入值。该模型是一个带有一个隐藏层的简单神经网络,用于学习计算异或运算。

我首先按照这篇精彩帖子的步骤 1-4 创建了一个包含图结构和模型参数的输出图文件:https://medium.com/@hamedmp/exporting-trained-tensorflow-models-to-c-the-right-way-cf24b609d183#.j4l51ptvb .

然后在C++(TensorFlow iOS 简单示例)中,我使用了以下代码:

tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({4,2}));

// input_tensor_mapped is an interface to the data of a tensor and used to copy data into the tensor
auto input_tensor_mapped = input_tensor.tensor<float, 2>();

// set the (4,2) possible input values for XOR
input_tensor_mapped(0, 0) = 0.0;
input_tensor_mapped(0, 1) = 0.0;
input_tensor_mapped(1, 0) = 0.0;
input_tensor_mapped(1, 1) = 1.0;
input_tensor_mapped(2, 0) = 1.0;
input_tensor_mapped(2, 1) = 0.0;
input_tensor_mapped(3, 0) = 1.0;
input_tensor_mapped(3, 1) = 1.0;

tensorflow::Status run_status = session->Run({{input_layer, input_tensor}},
{output_layer}, {}, &outputs);

在此之后,GetTopN(output->flat<float>(), kNumResults, kThreshold, &top_results);返回相同的 4 个值(0.94433498、0.94425952、0.06565627、0.05823805),就像我训练模型后在 top_results 中用于异或的 Python 测试代码中一样。

因此,如果您的张量的形状是 {1,1,10},您可以按如下方式设置值:

auto input_tensor_mapped = input_tensor.tensor<float, 3>();
input_tensor_mapped(0, 0, 0) = 0.0;
input_tensor_mapped(0, 0, 1) = 0.1;
....
input_tensor_mapped(0, 0, 9) = 0.9;

来源:How do I pass an OpenCV Mat into a C++ Tensorflow graph? 的答案很有帮助。

关于c++ - 如何在 C++ 中设置输入张量的值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38857282/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com