gpt4 book ai didi

c++ - tensorflow :转置需要一个大小为 1 的 vector 。但输入(1)是一个大小为 2 的 vector

转载 作者:太空宇宙 更新时间:2023-11-04 12:46:21 24 4
gpt4 key购买 nike

我想用训练好的RNN语言模型做推理。所以:
我使用 C++ 加载了经过训练的模型图

tensorflow::MetaGraphDef graph_def;
TF_CHECK_OK(ReadBinaryProto(Env::Default(), path_to_graph, &graph_def));
TF_CHECK_OK(session->Create(graph_def.graph_def()));

通过以下方式加载模型参数:

Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = path_to_ckpt;

TF_CHECK_OK(session_->Run({{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor} },{},{graph_def.saver_def().restore_op_name()},nullptr));

到目前为止,一切正常。
然后我想计算节点“output/output_batch_major”的值:

TF_CHECK_OK(session->Run(inputs,{"output/output_batch_major"},{"post_control_dependencies"}, &outputs));

我得到了错误:

2018-07-13 14:13:36.793495: F tf_lm_model_loader.cc:190] Non-OK-status: session->Run(inputs,{"output/output_batch_major"},{"post_control_dependencies"}, &outputs) status: Invalid argument: transpose expects a vector of size 1. But input(1) is a vector of size 2
[[Node: extern_data/placeholders/delayed/sequence_mask_time_major/transpose = Transpose[T=DT_BOOL, Tperm=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](extern_data/placeholders/delayed/SequenceMask/Less, extern_data/placeholders/delayed/sequence_mask_time_major/transpose/perm)]]
Aborted (core dumped)

我使用张量板检查了图形,extern_data/placeholders/delayed/sequence_mask_time_major/transpose/perm 是一个大小为 2 的 Tensor,这个 Tensor 是 输入(1)中的错误?我该如何解决这个问题?
有什么想法吗?提前致谢!

最佳答案

我的预测器的输入张量也有类似的问题。我将维度扩大一倍,问题就解决了。我建议首先在 python 中运行预测器。这有助于确定您传递给预测器的输入张量的大小。然后,在 C++ 中复制完全相同的大小。此外,根据您的代码片段,我不确定您如何定义 Run 方法的输入。我在我的代码中定义如下:

std::vector<std::pair<std::string, tensorflow::Tensor>> input = {
{"input_1", input_tensor }
};

其中“input_1”是我的输入层的名称。我希望这有帮助。

关于c++ - tensorflow :转置需要一个大小为 1 的 vector 。但输入(1)是一个大小为 2 的 vector ,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51326140/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com