python - Tensorflow - Am I restoring the model correctly?


I have the following code, which runs without errors. My question is whether I am restoring the model correctly. In particular, I don't see any output from the statement print(v_).

So I would like to know whether I am doing the following correctly:

  1. Restoring the model
  2. Using the restored model

    import tensorflow as tf
    import cifar_tools  # helper module used below to load the data

    data, labels = cifar_tools.read_data('C:\\Users\\abc\\Desktop\\Testing')

    x = tf.placeholder(tf.float32, [None, 150 * 150])
    y = tf.placeholder(tf.float32, [None, 2])

    w1 = tf.Variable(tf.random_normal([5, 5, 1, 64]))
    b1 = tf.Variable(tf.random_normal([64]))

    w2 = tf.Variable(tf.random_normal([5, 5, 64, 64]))
    b2 = tf.Variable(tf.random_normal([64]))

    w3 = tf.Variable(tf.random_normal([38*38*64, 1024]))
    b3 = tf.Variable(tf.random_normal([1024]))

    w_out = tf.Variable(tf.random_normal([1024, 2]))
    b_out = tf.Variable(tf.random_normal([2]))

    def conv_layer(x, w, b):
        conv = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
        conv_with_b = tf.nn.bias_add(conv, b)
        conv_out = tf.nn.relu(conv_with_b)
        return conv_out

    def maxpool_layer(conv, k=2):
        return tf.nn.max_pool(conv, ksize=[1, k, k, 1], strides=[1, k, k, 1], padding='SAME')

    def model():
        x_reshaped = tf.reshape(x, shape=[-1, 150, 150, 1])

        conv_out1 = conv_layer(x_reshaped, w1, b1)
        maxpool_out1 = maxpool_layer(conv_out1)
        norm1 = tf.nn.lrn(maxpool_out1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
        conv_out2 = conv_layer(norm1, w2, b2)
        norm2 = tf.nn.lrn(conv_out2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
        maxpool_out2 = maxpool_layer(norm2)

        maxpool_reshaped = tf.reshape(maxpool_out2, [-1, w3.get_shape().as_list()[0]])
        local = tf.add(tf.matmul(maxpool_reshaped, w3), b3)
        local_out = tf.nn.relu(local)

        out = tf.add(tf.matmul(local_out, w_out), b_out)
        return out

    model_op = model()

    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=model_op, labels=y))
    train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)

    correct_pred = tf.equal(tf.argmax(model_op, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        onehot_labels = tf.one_hot(labels, 2, on_value=1., off_value=0., axis=-1)
        onehot_vals = sess.run(onehot_labels)
        batch_size = len(data)

        # Restore model
        saver = tf.train.import_meta_graph('C:\\Users\\abc\\Desktop\\Testing\\mymodel.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        all_vars = tf.get_collection('vars')
        for v in all_vars:
            v_ = sess.run(v)
            print(v_)

        for j in range(0, 5):
            print('EPOCH', j)
            for i in range(0, len(data), batch_size):
                batch_data = data[i:i+batch_size, :]
                batch_onehot_vals = onehot_vals[i:i+batch_size, :]
                _, accuracy_val = sess.run([train_op, accuracy], feed_dict={x: batch_data, y: batch_onehot_vals})
                print(i, accuracy_val)

            print('DONE WITH EPOCH')
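
Since the graph is already constructed in the script above, a minimal alternative sketch (not code from the question; assuming TF 1.x and that the checkpoint was written with the path shown in Edit 2 below) restores the saved values directly into the existing variables with a plain tf.train.Saver, rather than importing the meta graph a second time:

# Sketch only: reuse the graph defined above and load the checkpointed
# values straight into w1, b1, ..., w_out, b_out.
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, 'C:\\Users\\abc\\Desktop\\Testing\\mymodel')
    print(sess.run(b_out))  # prints restored values instead of fresh random ones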

Edit 1

Does restoring it this way work?

saver = tf.train.Saver()
saver = tf.train.import_meta_graph('C:\\Users\\Abder-Rahman\\Desktop\\Testing\\mymodel.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
print('model restored')
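
If the restore seems to run but nothing is printed, one way to see what the checkpoint on disk actually contains, independent of the graph, is a small inspection sketch like the following (assuming a TensorFlow 1.x release that provides tf.train.list_variables; the directory is the one used when saving):

import tensorflow as tf

# List every variable stored in the latest checkpoint of that directory.
ckpt = tf.train.latest_checkpoint('C:\\Users\\abc\\Desktop\\Testing')
if ckpt is None:
    print('no checkpoint found in that directory')
else:
    for name, shape in tf.train.list_variables(ckpt):
        print(name, shape)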

Edit 2

This is how I save the model:

# Save model
saver = tf.train.Saver()
saved_path = saver.save(sess, 'C:\\Users\\abc\\Desktop\\Testing\\mymodel')
print("The model is in this file: ", saved_path)

Thanks.

Best Answer

Your saving code is correct, but the variables have to be added to a collection before they can be retrieved from it:

tf.add_to_collection("vars", w1)
tf.add_to_collection("vars", b1)
...

and then all_vars = tf.get_collection('vars') will return them.
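
A minimal end-to-end sketch of this idea (assuming TF 1.x; the variable names and the './mymodel' path are illustrative, not taken from the question):

import tensorflow as tf

# Saving side: register the variables in a collection, then save.
w1 = tf.Variable(tf.random_normal([5, 5, 1, 64]), name='w1')
b1 = tf.Variable(tf.random_normal([64]), name='b1')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', b1)
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './mymodel')

# Restoring side: import the meta graph (which carries the collection),
# restore the values; the loop over the collection now prints the tensors.
tf.reset_default_graph()
new_saver = tf.train.import_meta_graph('./mymodel.meta')
with tf.Session() as sess:
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    for v in tf.get_collection('vars'):
        print(sess.run(v))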

Regarding python - Tensorflow - Am I restoring the model correctly?, a similar question was found on Stack Overflow: https://stackoverflow.com/questions/43057816/
