gpt4 book ai didi

python-3.x - TensorFlow:dataset.train.next_batch 是如何定义的?

转载 作者:行者123 更新时间:2023-12-03 21:20:47 27 4
gpt4 key购买 nike

我正在尝试学习 TensorFlow 并在以下位置学习示例:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb

然后我在下面的代码中有一些问题:

for epoch in range(training_epochs):
# Loop over all batches
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
# Run optimization op (backprop) and cost op (to get loss value)
_, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
# Display logs per epoch step
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch+1),
"cost=", "{:.9f}".format(c))

由于 mnist 只是一个数据集, mnist.train.next_batch 到底是什么?意思是? dataset.train.next_batch 怎么样?定义?

谢谢!

最佳答案

mnist对象从 read_data_sets() function 返回在 tf.contrib.learn 中定义模块。 mnist.train.next_batch(batch_size)方法实现here ,它返回两个数组的元组,其中第一个表示一批 batch_size MNIST 图像,第二个代表一批 batch-size对应于这些图像的标签。

图像以大小为 [batch_size, 784] 的二维 NumPy 数组形式返回。 (因为 MNIST 图像中有 784 个像素),并且标签作为大小为 [batch_size] 的一维 NumPy 数组返回。 (如果 read_data_sets() 是用 one_hot=False 调用的)或大小为 [batch_size, 10] 的二维 NumPy 数组(如果 read_data_sets() 是用 one_hot=True 调用的)。

关于python-3.x - TensorFlow:dataset.train.next_batch 是如何定义的?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41454511/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com