gpt4 book ai didi

python - 恢复和评估 TensorFlow 模型

转载 作者:行者123 更新时间:2023-11-30 09:52:43 24 4
gpt4 key购买 nike

我在训练期间保存了这个模型，但在加载并评估它时遇到了困难。

我尝试了几种不同的方法，但都无法加载保存的模型并对其进行评估，以获得对一些测试样本（图像文件）的预测。

有人可以帮忙吗？在我看来这应该并不难，但我一定漏掉了什么。

#!/usr/bin/python
import tensorflow as tf

# Expected input image size: HEIGHT x WIDTH pixels with CHANNELS channels.
HEIGHT = 64
WIDTH = 128
CHANNELS = 3

# Label layout: the network predicts NUMBERS positions per image
# (presumably digits), each classified into one of CLASSES values.
NUMBERS = 4
CLASSES = 10

# Batching and dataset sizes.
BATCH_SIZE = 128
NUM_EXAMPLES_PER_EPOCH = 50000
VALIDATION_SIZE = 10000


def _conv_pool(inputs, scope, in_channels, out_channels, pool_hw):
    """Build one 5x5 SAME conv -> bias -> ReLU -> max-pool stage under ``scope``.

    Variables are created as ``<scope>/kernel`` and ``<scope>/biases``; these
    are the names a saved checkpoint is keyed by.

    Args:
        inputs: 4-D tensor in NHWC layout (tf.nn.conv2d's default format).
        scope: variable scope name, e.g. "conv_pool_1".
        in_channels: number of channels in ``inputs``.
        out_channels: number of convolution filters.
        pool_hw: (height, width) of the max-pool window; stride equals window.

    Returns:
        The pooled activation tensor.
    """
    with tf.variable_scope(scope):
        kernel = tf.get_variable(
            name="kernel",
            shape=[5, 5, in_channels, out_channels],
            initializer=tf.truncated_normal_initializer(stddev=0.05),
            dtype=tf.float32)
        biases = tf.get_variable(
            name="biases",
            shape=[out_channels],
            initializer=tf.constant_initializer(value=0.),
            dtype=tf.float32)
        conv = tf.nn.conv2d(input=inputs,
                            filter=kernel,
                            strides=[1, 1, 1, 1],
                            padding="SAME")
        conv_bias = tf.nn.bias_add(value=conv, bias=biases, name="add_biases")
        relu = tf.nn.relu(conv_bias)
        window = [1, pool_hw[0], pool_hw[1], 1]
        return tf.nn.max_pool(value=relu,
                              ksize=window,
                              strides=window,
                              padding="SAME",
                              name="pooling")


def _dense(inputs, scope, in_dim, out_dim):
    """Return ``inputs @ weights + biases`` with variables under ``scope``.

    Variables are created as ``<scope>/weights`` and ``<scope>/biases``.
    No activation is applied.
    """
    with tf.variable_scope(scope):
        weights = tf.get_variable(
            name="weights",
            shape=[in_dim, out_dim],
            initializer=tf.truncated_normal_initializer(stddev=0.05),
            dtype=tf.float32)
        biases = tf.get_variable(
            name="biases",
            shape=[out_dim],
            initializer=tf.constant_initializer(value=0.),
            dtype=tf.float32)
        return tf.nn.xw_plus_b(x=inputs, weights=weights, biases=biases)


def inference(inputs):
    """Build the CNN and return per-position class logits.

    Three conv/pool stages feed a 2048-unit ReLU layer and a linear output
    layer. The batch dimension is not fixed: any batch size works, including
    a ``None`` placeholder used to evaluate a few test images. Variable names
    (conv_pool_1/kernel, ..., fully_conn/weights, output/biases) are unchanged,
    so checkpoints saved from the earlier graph still restore.

    Args:
        inputs: float32 image batch in NHWC layout with CHANNELS channels.
            The spatial dimensions must be statically known, because the
            width of the fully connected layer is derived from them.

    Returns:
        Logits tensor of shape [batch, NUMBERS, CLASSES].

    Raises:
        ValueError: if the spatial dimensions are not statically known.
    """
    pool = _conv_pool(inputs, "conv_pool_1", CHANNELS, 48, (2, 2))
    # The 2x1 window halves only the height in this stage.
    pool = _conv_pool(pool, "conv_pool_2", 48, 64, (2, 1))
    pool = _conv_pool(pool, "conv_pool_3", 64, 128, (2, 2))

    # Flatten everything except the batch axis. The width comes from the
    # static shape so the weight matrix has a fixed size. -1 (instead of a
    # hard-coded BATCH_SIZE) lets the batch dimension vary.
    dims = pool.get_shape()[1:].num_elements()
    if dims is None:
        raise ValueError("inference() needs statically known spatial "
                         "dimensions; got shape %s" % pool.get_shape())
    reshape = tf.reshape(pool, shape=[-1, dims], name="reshape")

    conn = tf.nn.relu(_dense(reshape, "fully_conn", dims, 2048))
    logits = _dense(conn, "output", 2048, NUMBERS * CLASSES)
    return tf.reshape(logits, shape=[-1, NUMBERS, CLASSES])


def loss(logits, labels):
    """Return the mean softmax cross-entropy and record it in a collection.

    Args:
        logits: unnormalised scores, e.g. the [batch, NUMBERS, CLASSES]
            output of inference(). Softmax is taken over the last axis.
        labels: one-hot (or probability) targets with the same shape as
            ``logits``.

    Returns:
        Scalar float tensor, averaged over every batch item and position.
        It is also added to the "loss" graph collection.
    """
    # Keyword arguments are required. Since TF 1.0 the first positional
    # parameter of this op is a private sentinel, so the old positional call
    # (logits, labels) raises ValueError.
    cross_entropy_per_number = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    cross_entropy = tf.reduce_mean(cross_entropy_per_number)
    tf.add_to_collection("loss", cross_entropy)
    return cross_entropy


def evaluation(logits, labels, per_sequence=False):
    """Return classification accuracy as a scalar tensor named "accuracy".

    Args:
        logits: scores whose class axis is axis 2, e.g. the
            [batch, NUMBERS, CLASSES] output of inference().
        labels: one-hot targets with the same layout as ``logits``.
        per_sequence: if False (default, the original behaviour), every
            position is scored independently. If True, a sample counts as
            correct only when all positions along axis 1 are correct.

    Returns:
        Scalar float32 tensor in [0, 1].
    """
    prediction = tf.argmax(logits, 2)
    actual = tf.argmax(labels, 2)
    equal = tf.equal(prediction, actual)
    if per_sequence:
        # Collapse the per-position axis: correct only if every position is.
        equal = tf.reduce_all(equal, 1)
    accuracy = tf.reduce_mean(tf.cast(equal, tf.float32), name="accuracy")
    return accuracy


def train(loss, learning_rate=0.00001):
    """Create the op that takes one plain gradient-descent step on ``loss``.

    Args:
        loss: scalar tensor to minimise. This parameter shadows the
            module-level loss() function inside this body.
        learning_rate: step size for tf.train.GradientDescentOptimizer.

    Returns:
        The op returned by the optimizer's ``minimize``.
    """
    return tf.train.GradientDescentOptimizer(
        learning_rate=learning_rate).minimize(loss)

最佳答案

你是如何保存模型的？你有没有试过下面这种方式？（保存）

# Create the Saver after the model's variables exist in the graph.
saver = tf.train.Saver()
# The with-block guarantees the session is closed once saving is done.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Run the training steps here. Saving right after initialisation would
    # only store the untrained, randomly initialised weights.
    # save() writes my-model.* files (including my-model.meta) and the
    # ./checkpoint index that tf.train.latest_checkpoint('./') reads.
    saver.save(sess, 'my-model')

（加载）

# Open a session on the default graph; the imported graph is added to it.
sess = tf.Session()
# Recreate the saved graph structure (and a Saver for it) from the .meta file.
restorer = tf.train.import_meta_graph('my-model.meta')
# Locate the newest checkpoint listed in ./checkpoint, then load its values.
checkpoint_path = tf.train.latest_checkpoint('./')
restorer.restore(sess, checkpoint_path)

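恢复之后，可以用 `sess.graph.get_tensor_by_name('名称:0')` 取出输入占位符和输出 logits 张量，再通过 `sess.run(logits, feed_dict={输入: 测试图像})` 得到预测；为此最好在构建图时用 `name=` 参数给这两个张量起明确的名字。官方参考：https://www.tensorflow.org/versions/master/api_docs/python/state_ops/exporting_and_importing_meta_graphs（或者将 URL 中的 master 替换为版本号，例如 r0.12）。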

关于python - 恢复和评估 Tensorflow 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41984876/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com