
python - Tensorflow - How to import the MNIST database

Reposted. Author: 行者123. Updated: 2023-12-04 13:45:13

I want to train a model on the MNIST database, and I am following the Tensorflow tutorial. The tutorial imports the data with mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True), but I need to load the files myself with something like:

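# TF 1.x location of the helpers that read_data_sets uses internally
from tensorflow.contrib.learn.python.learn.datasets.mnist import extract_images, extract_labels

with open('my/directory/train-images-idx3-ubyte.gz', 'rb') as f:
    train_images = extract_images(f)
with open('my/directory/train-labels-idx1-ubyte.gz', 'rb') as f:
    # Label files need extract_labels, not extract_images; one_hot=True
    # matches the [None, 10] label placeholder used below.
    train_labels = extract_labels(f, one_hot=True)
...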
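extract_images and extract_labels parse the IDX format that the MNIST .gz files use: two zero bytes, a data-type byte (0x08 for unsigned bytes), a byte giving the number of dimensions, then one big-endian 32-bit size per dimension, followed by the raw data. If those TF 1.x helpers are not available, a minimal NumPy-only sketch of the same parsing might look like this. `read_idx_gz` is a hypothetical name, and unlike extract_images it returns images without a trailing channel axis, i.e. shape (N, 28, 28).

```python
import gzip

import numpy as np


def read_idx_gz(path):
    """Hypothetical helper: parse a gzip-compressed IDX file into a NumPy array.

    Only the unsigned-byte data type (code 0x08) used by the MNIST files
    is handled.
    """
    with gzip.open(path, 'rb') as f:
        data = f.read()
    zero, dtype_code, ndim = data[0:2], data[2], data[3]
    if zero != b'\x00\x00' or dtype_code != 0x08:
        raise ValueError('not an unsigned-byte IDX file: %s' % path)
    # One big-endian uint32 per dimension follows the 4-byte magic number.
    dims = np.frombuffer(data, dtype='>u4', count=ndim, offset=4)
    shape = tuple(int(d) for d in dims)
    # Everything after the header is the raw uint8 payload.
    body = np.frombuffer(data, dtype=np.uint8, offset=4 + 4 * ndim)
    return body.reshape(shape)
```

For example, read_idx_gz('my/directory/train-labels-idx1-ubyte.gz') would return a 1-D uint8 array of integer labels, which still has to be one-hot encoded to match the [None, 10] placeholder.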

How do I adapt the following code to work with my train_images, train_labels, test_images and test_labels?
import argparse
import sys
import tempfile

import tensorflow as tf  # TF 1.x API (placeholders, sessions)
from tensorflow.examples.tutorials.mnist import input_data

FLAGS = None

# deepnn(x) is defined earlier in the tutorial's mnist_deep.py; it builds the
# convolutional network and returns (y_conv, keep_prob).


def main(_):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    # Create the model
    x = tf.placeholder(tf.float32, [None, 784])

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 10])

    # Build the graph for the deep net
    y_conv, keep_prob = deepnn(x)

    with tf.name_scope('loss'):
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_,
                                                                logits=y_conv)
    cross_entropy = tf.reduce_mean(cross_entropy)

    with tf.name_scope('adam_optimizer'):
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
        correct_prediction = tf.cast(correct_prediction, tf.float32)
    accuracy = tf.reduce_mean(correct_prediction)

    graph_location = tempfile.mkdtemp()
    print('Saving graph to: %s' % graph_location)
    train_writer = tf.summary.FileWriter(graph_location)
    train_writer.add_graph(tf.get_default_graph())

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(20000):
            batch = mnist.train.next_batch(50)
            if i % 100 == 0:
                train_accuracy = accuracy.eval(feed_dict={
                    x: batch[0], y_: batch[1], keep_prob: 1.0})
                print('step %d, training accuracy %g' % (i, train_accuracy))
            train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

        print('test accuracy %g' % accuracy.eval(feed_dict={
            x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str,
                        default='/tmp/tensorflow/mnist/input_data',
                        help='Directory for storing input data')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
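The only places this code touches the mnist object are mnist.train.next_batch(50) and mnist.test.images / mnist.test.labels. With plain NumPy arrays, one option is a small helper that reshuffles once per epoch and slices mini-batches. `ArrayBatcher` is a hypothetical name, not a TensorFlow API, and the sketch assumes the arrays are already shaped like the x and y_ placeholders: (N, 784) floats and (N, 10) one-hot labels.

```python
import numpy as np


class ArrayBatcher:
    """Hypothetical stand-in for mnist.train.next_batch over in-memory arrays."""

    def __init__(self, images, labels, seed=0):
        if len(images) != len(labels):
            raise ValueError('images and labels must have the same length')
        self._images = images
        self._labels = labels
        self._rng = np.random.default_rng(seed)
        self._order = self._rng.permutation(len(images))
        self._pos = 0

    def next_batch(self, batch_size):
        """Return (images, labels) for the next batch_size shuffled examples."""
        if self._pos + batch_size > len(self._order):
            # Not enough examples left: start a new epoch with a fresh shuffle.
            # Leftover examples from the previous epoch are simply skipped.
            self._order = self._rng.permutation(len(self._images))
            self._pos = 0
        idx = self._order[self._pos:self._pos + batch_size]
        self._pos += batch_size
        return self._images[idx], self._labels[idx]
```

In main, after creating train = ArrayBatcher(train_images, train_labels), the loop would call batch = train.next_batch(50) and keep feeding batch[0] and batch[1] as before; the final evaluation would feed test_images and test_labels in place of mnist.test.images and mnist.test.labels.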

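Best answer

The recommended way to import the MNIST dataset in TF2 is:
from tensorflow.keras.datasets import mnist
# Downloads mnist.npz to ~/.keras/datasets on first use.
# X_train: uint8, shape (60000, 28, 28); Y_train: integer labels, shape (60000,)
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()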
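load_data returns 28x28 uint8 images and integer class labels, whereas the question's graph expects x with shape [None, 784] and one-hot y_ with shape [None, 10]. To keep that placeholder-based model, the arrays could be converted along these lines. `to_placeholder_format` is a hypothetical helper; the scaling to [0, 1] mirrors what read_data_sets does for float32 images.

```python
import numpy as np


def to_placeholder_format(images, labels, num_classes=10):
    """Hypothetical helper: flatten images to float32 rows in [0, 1] and one-hot the labels.

    Works for (N, 28, 28) or (N, 28, 28, 1) image arrays and returns arrays
    shaped (N, 784) and (N, num_classes).
    """
    flat = images.reshape(len(images), -1).astype(np.float32) / 255.0
    # Row i of the identity matrix is the one-hot vector for class i.
    one_hot = np.eye(num_classes, dtype=np.float32)[labels]
    return flat, one_hot
```

For example, X_train, Y_train = to_placeholder_format(X_train, Y_train) yields arrays of shape (60000, 784) and (60000, 10) that can be fed to x and y_ directly.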

A similar question about python - Tensorflow - How to import the MNIST database can be found on Stack Overflow: https://stackoverflow.com/questions/50203272/
