gpt4 book ai didi

python - 如何使用生成器训练 TensorFlow 网络以产生输入?

转载 作者:IT老高 更新时间:2023-10-28 20:33:57 26 4
gpt4 key购买 nike

TensorFlow docs描述了一系列使用 TFRecordReader、TextLineReader、QueueRunner 等和队列读取数据的方法。

我想做的事情要简单得多:我有一个 python 生成器函数,它生成无限序列的训练数据作为 (X, y) 元组(都是 numpy 数组,第一个维度是批量大小)。我只想使用该数据作为输入来训练网络。

是否有使用生成数据的生成器来训练 TensorFlow 网络的简单自包含示例? (沿用 MNIST 或 CIFAR 示例)

最佳答案

假设你有一个生成数据的函数:

 def generator(data): 
...
yield (X, y)

现在您需要另一个函数来描述您的模型架构。它可以是任何处理 X 并且必须将 y 预测为输出的函数(例如神经网络)。

假设您的函数接受 X 和 y 作为输入,以某种方式从 X 计算 y 的预测,并在 y 和预测的 y 之间返回损失函数(例如交叉熵或 MSE):

 def neural_network(X, y): 
# computation of prediction for y using X
...
return loss(y, y_pred)

要使您的模型正常工作,您需要为 X 和 y 定义占位符,然后运行 session :

 X = tf.placeholder(tf.float32, shape=(batch_size, x_dim))
y = tf.placeholder(tf.float32, shape=(batch_size, y_dim))

占位符类似于“自由变量”,您需要在 feed_dict 运行 session 时指定它:

 with tf.Session() as sess:
# variables need to be initialized before any sess.run() calls
tf.global_variables_initializer().run()

for X_batch, y_batch in generator(data):
feed_dict = {X: X_batch, y: y_batch}
_, loss_value, ... = sess.run([train_op, loss, ...], feed_dict)
# train_op here stands for optimization operation you have defined
# and loss for loss function (return value of neural_network function)

希望你会发现它有用。但是,请记住,这不是完全有效的实现,而是一个伪代码,因为您几乎没有指定任何细节。

关于python - 如何使用生成器训练 TensorFlow 网络以产生输入?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39325275/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com