gpt4 book ai didi

python - 批量读取数据集进行训练

转载 作者:太空宇宙 更新时间:2023-11-03 19:52:37 29 4
gpt4 key购买 nike

我正在尝试读取 cifar10 数据集并将其用于训练模型,因此我尝试读取批处理并运行 session ,如下所示:

 # Optimizer
opt = tf.train.AdamOptimizer(0.0001)
global_step = tf.get_variable('global_step', initializer=tf.constant(0), trainable=False)
train_op = opt.apply_gradients(zip(grads, var_list), global_step=global_step)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

image_batch, label_batch = tf.train.batch([x_train, y_train], batch_size=batch_size)
#image_batch_uint8 = tf.cast(image_batch, tf.uint8)

# Train
with tf.Session() as sess:

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

sess.run(tf.global_variables_initializer())
for i in range(10000000):
_loss_value, _reward_value, _ = sess.run([loss, reward, train_op], feed_dict={
images_ph: image_batch,
labels_ph: label_batch
})
if i % 100 == 0:
print('iter: ', i, '\tloss: ', _loss_value, '\treward: ', _reward_value)

但是我收到此错误:

 File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 900, in run
run_metadata_ptr)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1086, in _run
'feed with key ' + str(feed) + '.')
The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.For reference, the tensor object was Tensor("batch:0", shape=(32, 50000, 32, 32, 3), dtype=uint8) which was passed to the feed with key Tensor("Placeholder:0", shape=(?, 1024), dtype=float32).

我做错了什么?我如何确保将所有数据集作为纪元提供,是否有更直接的方法来提供数据集?

最佳答案

错误是因为变量image_batchlabel_batch是张量。 feed 字典的语法为 {tensor1:value1,tensor2:value2.....}。因此,您需要输入 numpy 数组来代替 value1,value2..

所以你只需要执行value1,value2 = sess.run([image_batch,label_batch])

总体变化如下:

.
.
for i in range(10000000):

try:

raw_images, raw_labels = sess.run([image_batch, label_batch])
_loss_value, _reward_value, _ = sess.run([loss, reward, train_op], feed_dict={image_batch: raw_images, label_batch: raw_labels})

except tf.errors.OutOfRangeError:
print("Breaking...")
break
.
.
if i % 100 == 0:
print('iter: ', i, '\tloss: ', _loss_value, '\treward: ', _reward_value)

我认为使用 tf.train.Coordinator() 而不是我编写的 try..except block ,您还可以使用以下 block (在他们的网站上) :

try:
while not coord.should_stop():
...do some work...
except Exception as e:
coord.request_stop(e)

关于python - 批量读取数据集进行训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59736714/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com