gpt4 book ai didi

tensorflow tf.train.batch之数据批量读取方式

转载 作者:qq735679552 更新时间:2022-09-29 22:32:09 28 4
gpt4 key购买 nike

CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.

这篇CFSDN的博客文章tensorflow tf.train.batch之数据批量读取方式由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.

在进行大量数据训练神经网络的时候,可能需要批量读取数据。于是参考了这篇文章的代码,结果发现数据一直批量循环输出,不会在数据的末尾自动停止.

然后发现这篇博文说slice_input_producer()这个函数有一个形参num_epochs,通过设置它的值就可以控制全部数据循环输出几次.

于是我设置之后出现以下的报错:

?
1
2
3
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value input_producer / input_producer / limit_epochs / epochs
 
      [[Node: input_producer / input_producer / limit_epochs / CountUpTo = CountUpTo[T = DT_INT64, _class = [ "loc:@input_producer/input_producer/limit_epochs/epochs" ], limit = 2 , _device = "/job:localhost/replica:0/task:0/cpu:0" ](input_producer / input_producer / limit_epochs / epochs)]]

找了好久,都不知道为什么会错,于是只好去看看slice_input_producer()函数的源码,结果在源码中发现作者说这个num_epochs如果不是空的话,就是一个局部变量,需要先调用global_variables_initializer()函数初始化.

于是我调用了之后,一切就正常了,特此记录下来,希望其他人遇到的时候能够及时找到原因.

哈哈,这是笔者第一次通过阅读源码解决了问题,心情还是有点小激动。啊啊,扯远了,上最终成功的代码:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import pandas as pd
import numpy as np
import tensorflow as tf
 
 
def generate_data():
   num = 25
   label = np.asarray( range ( 0 , num))
   images = np.random.random([num, 5 ])
   print ( 'label size :{}, image size {}' . format (label.shape, images.shape))
   return images,label
 
def get_batch_data():
   label, images = generate_data()
   input_queue = tf.train.slice_input_producer([images, label], shuffle = False ,num_epochs = 2 )
   image_batch, label_batch = tf.train.batch(input_queue, batch_size = 5 , num_threads = 1 , capacity = 64 ,allow_smaller_final_batch = False )
   return image_batch,label_batch
 
 
images,label = get_batch_data()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer()) #就是这一行
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess,coord)
try :
   while not coord.should_stop():
     i,l = sess.run([images,label])
     print (i)
     print (l)
except tf.errors.OutOfRangeError:
   print ( 'Done training' )
finally :
   coord.request_stop()
coord.join(threads)
sess.close()

以上这篇tensorflow tf.train.batch之数据批量读取方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我.

原文链接:https://blog.csdn.net/liweibin1994/article/details/78306417 。

最后此篇关于tensorflow tf.train.batch之数据批量读取方式的文章就讲到这里了,如果你想了解更多关于tensorflow tf.train.batch之数据批量读取方式的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com