gpt4 book ai didi

machine-learning - Tensorflow 理解 tf.train.shuffle_batch

转载 作者:行者123 更新时间:2023-11-30 08:29:17 24 4
gpt4 key购买 nike

我有一个训练数据文件,大约有 100K 行,并且我在每个训练步骤上运行一个简单的 tf.train.GradientDescentOptimizer。该设置本质上直接取自 Tensorflow 的 MNIST 示例。代码转载如下:

x = tf.placeholder(tf.float32, [None, 21])
W = tf.Variable(tf.zeros([21, 2]))
b = tf.Variable(tf.zeros([2]))
y = tf.nn.softmax(tf.matmul(x, W) + b)

y_ = tf.placeholder(tf.float32, [None, 2])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

鉴于我正在从文件中读取训练数据,我正在使用 tf.train.string_input_ Producer 和 tf.decode_csv 从 csv 中读取行,并且然后tf.train.shuffle_batch创建批处理,然后进行训练。

我对 tf.train.shuffle_batch 的参数应该是什么感到困惑。我阅读了 Tensorflow 的文档,但我仍然不确定“最佳”batch_size、capacity 和 min_after_dequeue 值是多少。任何人都可以帮助阐明我如何为这些参数选择正确的值,或者将我链接到可以了解更多信息的资源吗?谢谢——

这是 API 链接:https://www.tensorflow.org/versions/r0.9/api_docs/python/io_ops.html#shuffle_batch

最佳答案

有一点关于要使用的线程数

https://www.tensorflow.org/versions/r0.9/how_tos/reading_data/index.html#batching

不幸的是,我认为批量大小没有一个简单的答案。网络的有效批量大小取决于很多细节关于网络。在实践中,如果您关心最佳性能你将需要进行大量的试验和错误(也许开始来自类似网络使用的值)。

关于machine-learning - Tensorflow 理解 tf.train.shuffle_batch,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38033079/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com