gpt4 book ai didi

c++ - 如何通过单次运行将多个输入样本提供给 C++ 中的 tensorflow 模型

转载 作者:太空宇宙 更新时间:2023-11-04 12:45:54 25 4
gpt4 key购买 nike

使用 python tensorflow API 可以做这样的事情:

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
adder_node = a + b

sess.run(adder_node, {a: [1,3], b: [2, 4]})

结果:[ 3. 7.]

在 C++ 中是否有任何方法可以通过一次调用 Run 方法将多个输入提供给模型?我尝试使用 feed_dicts 的 std::vector

// prepare tensorflow inputs
std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict;
for(size_t i = 0; i < noutput_items; i++) {
tensorflow::TensorShape data_shape({1, d_vlen_in});
tensorflow::Tensor n_tensor(tensorflow::DT_FLOAT, data_shape);
auto n_data = n_tensor.flat<float>().data();
for(int j = 0 ; j < d_vlen_in ; j++) {
n_data[j] = in[j];
}
feed_dict.push_back(std::make_pair(d_layer_in, n_tensor));
in += d_vlen_in;
}

// prepare tensorflow outputs
std::vector<tensorflow::Tensor> outputs;

TF_CHECK_OK(d_session->Run(feed_dict, {d_layer_out}, {}, &outputs));

d_layer_ind_layer_outstd::strings,“input”是我的输入层/占位符。

但是它失败了:

Non-OK-status: d_session->Run(feed_dict, {d_layer_out}, {}, &outputs) status: Invalid argument: Endpoint "input" fed more than once.

有人知道这样做的方法吗?我的主要目标是提高吞吐量。

最佳答案

所以我找到了答案,而且很简单。feed 字典的几个元素用于设置多个输入变量,与 python 中相同。然而,输入张量的第一(或零)维度是批量维度,在某些情况下可以用作输入信号的时间维度。

// prepare tensorflow inputs
// dimension 0 is the batch dimension, i.e., time dimension
tensorflow::TensorShape data_shape({noutput_items, d_vlen_in});
tensorflow::Tensor in_tensor(tensorflow::DT_FLOAT, data_shape);
auto in_tensor_data = in_tensor.flat<float>().data();
for(size_t i = 0; i < noutput_items; i++) {
for(int j = 0 ; j < d_vlen_in ; j++) {
in_tensor_data[(i*d_vlen_in)+j] = in[(i*d_vlen_in)+j];
}
}
std::vector<std::pair<std::string, tensorflow::Tensor>> feed_dict = {
{ d_layer_in, in_tensor },
};

// prepare tensorflow outputs
std::vector<tensorflow::Tensor> outputs;

TF_CHECK_OK(d_session->Run(feed_dict, {d_layer_out}, {}, &outputs));

这当然会引入一些缓冲/延迟,但它让我的吞吐量增加了大约 1000 倍。

关于c++ - 如何通过单次运行将多个输入样本提供给 C++ 中的 tensorflow 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51611232/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com