gpt4 book ai didi

python - TensorFlow 从 mnist 数据集中选择标签

转载 作者:行者123 更新时间:2023-11-28 17:15:22 24 4
gpt4 key购买 nike

我正在使用 tensorflow.examples.tutorials.mnist 来训练具有 5 个隐藏层的神经网络。

这是我训练神经网络的方式:

with tf.Session() as sess:
init.run()
for epoch in range(n_epochs):
for iteration in range(len(mnist.test.labels)//batch_size):
X_batch, y_batch = mnist.train.next_batch(batch_size)
sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
acc_test = accuracy.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels})
print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)

我想训练神经网络仅识别 0 到 4 之间的数字。我将 logits 层更改为具有 5 个输出。

如何过滤 TensorFlow 提供的 mnist 数据集,以便只获取 0 到 4 之间的数字?

最佳答案

有很多方法可以做到这一点。其中之一是当您提取 X_batch, y_batch = mnist.train.next_batch(batch_size) 时。在此步骤中,您的 y_batch 将具有有关数字值的信息(数字值或数字的单热值)。

您迭代批处理中的示例并检查数字是否是您关心的数字。如果是,则将其添加到 cleaned_up_batch。效率不是很高,但它会起作用。


回复评论:

效率不高,因为您可能需要多次过滤相同的数据。我认为这不会成为问题,因为 MNIST 非常小。通常的做法是只过滤一次,创建一个新的数据集并编写自己的函数从中获取下一批(实际上很容易,因为您只是从数据集中随机选择 k 个元素)

关于python - TensorFlow 从 mnist 数据集中选择标签,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44532036/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com