gpt4 book ai didi

go - 无法使用 TensorFlow Go API 进行预测

转载 作者:IT王子 更新时间:2023-10-29 01:47:16 24 4
gpt4 key购买 nike

我有一个使用 Tensorflow Python API 编码的 MLP。以下是代码片段:

# tf Graph input
x = tf.placeholder("float", [None, 11],name="x")
y = tf.placeholder("float", [None])

# Store layers weight & bias
weights = {
'h1': tf.Variable(tf.random_normal([11, 32], 0, 0.1)),
'h2': tf.Variable(tf.random_normal([32, 200], 0, 0.1)),
'out': tf.Variable(tf.random_normal([200, 1], 0, 0.1))
}

biases = {
'b1': tf.Variable(tf.random_normal([32], 0, 0.1)),
'b2': tf.Variable(tf.random_normal([200], 0, 0.1)),
'out': tf.Variable(tf.random_normal([1], 0, 0.1))
}

# Create model
def multilayer_perceptron(x, weights, biases):
# Hidden layer with RELU activation
layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])
layer_1 = tf.nn.relu(layer_1)

# Hidden layer with RELU activation
layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
layer_2 = tf.nn.relu(layer_2)

# Output layer with linear activation
out_layer = tf.matmul(layer_2, weights['out']) + biases['out']
return out_layer

# Construct model
pred = multilayer_perceptron(x, weights, biases)
pred = tf.identity(pred, name="pred")

模型已使用 saved_model_builder.SavedModelBuilder 方法进行训练和保存。使用 Python API 的预测可以使用以下代码完成:

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.saved_model.loader.load(sess, ["tag"], "/tmp/saved_models")
my_pred= sess.graph.get_tensor_by_name('pred:0')
predictions = sess.run(my_pred, feed_dict={x: pred_data})
print("The prediction is", predictions)

我正在尝试使用以下代码片段使用 Go API 做出相同的预测:

df := []float32{9.5,0.0,7.5,0.0,0.0,2.0,0.0,0.0,0.0,0.0,1505292248.0}
tensor, terr := tf.NewTensor(df)
result, runErr := model.Session.Run(
map[tf.Output]*tf.Tensor{
model.Graph.Operation("x").Output(0): tensor,
},
[]tf.Output{
model.Graph.Operation("pred").Output(0),
},
nil,
)

但是,我遇到了以下错误:

Error running the session with input, err: In[0] is not a matrix
[[Node: MatMul = MatMul[T=DT_FLOAT, _output_shapes=[[?,32]], transpose_a=false, transpose_b=false, _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_x_0_0, Variable/read)]]

谁能指出这个错误的原因?

最佳答案

错误很明显:In[0] 不是矩阵

您的 In[0] 是:df := []float32{9.5,0.0,7.5,0.0,0.0,2.0,0.0,0.0,0.0,0.0,1505292248.0}

这是一维张量,不是矩阵。

matmul 节点要求它的两个参数都是矩阵,因此是二维张量。

因此,您必须更改 df 定义才能定义二维张量,如下所示:

df := [][]float32{{9.5},{0.0},{7.5},{0.0},{0.0},{2.0},{0.0},{0.0},{0.0},{0.0},{1505292248.0}}

关于如何思考/调试 tensorflow + go 代码的一个很好的引用是:https://pgaleone.eu/tensorflow/go/2017/05/29/understanding-tensorflow-using-go/

关于go - 无法使用 TensorFlow Go API 进行预测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46368514/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com