gpt4 book ai didi

Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取

转载 作者:qq735679552 更新时间:2022-09-29 22:32:09 27 4
gpt4 key购买 nike

CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.

这篇CFSDN的博客文章Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.

单一数据读取方式:

  第一种:slice_input_producer() 。

?
1
2
# 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中,如[...]
[images, labels] = tf.train.slice_input_producer([images, labels], num_epochs = None , shuffle = True )

  第二种:string_input_producer() 。

?
1
2
3
4
5
# 需要定义文件读取器,然后通过读取器中的 read()方法来获取数据(返回值类型 key,value),再通过 Session.run(value)查看
file_queue = tf.train.string_input_producer(filename, num_epochs = None , shuffle = True )
 
reader = tf.WholeFileReader()      # 定义文件读取器
key, value = reader.read(file_queue)  # key:文件名;value:文件中的内容

  !!!num_epochs=None,不指定迭代次数,这样文件队列中元素个数也不限定(None*数据集大小).

  !!!如果它不是None,则此函数创建本地计数器 epochs,需要使用local_variables_initializer()初始化局部变量 。

  !!!以上两种方法都可以生成文件名队列.

(随机)批量数据读取方式:

?
1
2
3
batchsize = 2    # 每次读取的样本数量
tf.train.batch(tensors, batch_size = batchsize)
tf.train.shuffle_batch(tensors, batch_size = batchsize, capacity = batchsize * 10 , min_after_dequeue = batchsize * 5 ) # capacity > min_after_dequeue

  !!!以上所有读取数据的方法,在Session.run()之前必须开启文件队列线程 tf.train.start_queue_runners() 。

 TFRecord文件的打包与读取 。

 1、单一数据读取方式 。

第一种:slice_input_producer() 。

?
1
def slice_input_producer(tensor_list, num_epochs = None , shuffle = True , seed = None , capacity = 32 , shared_name = None , name = None )

案例1:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import tensorflow as tf
 
images = [ 'image1.jpg' , 'image2.jpg' , 'image3.jpg' , 'image4.jpg' ]
labels = [ 1 , 2 , 3 , 4 ]
 
# [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)
 
# 当num_epochs=2时,此时文件队列中只有 2*4=8个样本,所有在取第9个样本时会出错
# [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=2, shuffle=True)
 
data = tf.train.slice_input_producer([images, labels], num_epochs = None , shuffle = True )
print ( type (data))  # <class 'list'>
 
with tf.Session() as sess:
   # sess.run(tf.local_variables_initializer())
   sess.run(tf.local_variables_initializer())
   coord = tf.train.Coordinator() # 线程的协调器
   threads = tf.train.start_queue_runners(sess, coord) # 开始在图表中收集队列运行器
 
   for i in range ( 10 ):
     print (sess.run(data))
 
   coord.request_stop()
   coord.join(threads)
 
"""

运行结果:

[b'image2.jpg', 2] [b'image1.jpg', 1] [b'image3.jpg', 3] [b'image4.jpg', 4] [b'image2.jpg', 2] [b'image1.jpg', 1] [b'image3.jpg', 3] [b'image4.jpg', 4] [b'image2.jpg', 2] [b'image3.jpg', 3] """ 。

  !!!slice_input_producer() 中的第一个参数需要放在一个列表中,列表中的每个元素可以是 List 或 Tensor,如 [images,labels], 。

  !!!num_epochs设置 。

 第二种:string_input_producer() 。

?
1
def string_input_producer(string_tensor, num_epochs = None , shuffle = True , seed = None , capacity = 32 , shared_name = None , name = None , cancel_op = None )

文件读取器 。

  不同类型的文件对应不同的文件读取器,我们称为 reader对象; 。

  该对象的 read 方法自动读取文件,并创建数据队列,输出key/文件名,value/文件内容; 。

?
1
2
3
4
5
reader = tf.TextLineReader()   ### 一行一行读取,适用于所有文本文件
 
reader = tf.TFRecordReader()   ### A Reader that outputs the records from a TFRecords file
 
reader = tf.WholeFileReader()   ### 一次读取整个文件,适用图片

案例2:读取csv文件 。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import tensorflow as tf
 
filename = [ 'data/A.csv' , 'data/B.csv' , 'data/C.csv' ]
 
file_queue = tf.train.string_input_producer(filename, shuffle = True , num_epochs = 2 # 生成文件名队列
reader = tf.WholeFileReader()      # 定义文件读取器(一次读取整个文件)
# reader = tf.TextLineReader()      # 定义文件读取器(一行一行的读)
key, value = reader.read(file_queue)  # key:文件名;value:文件中的内容
print ( type (file_queue))
 
init = [tf.global_variables_initializer(), tf.local_variables_initializer()]
with tf.Session() as sess:
   sess.run(init)
   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(sess = sess, coord = coord)
   try :
     while not coord.should_stop():
       for i in range ( 6 ):
         print (sess.run([key, value]))
       break
   except tf.errors.OutOfRangeError:
     print ( 'read done' )
   finally :
     coord.request_stop()
   coord.join(threads)
 
"""
reader = tf.WholeFileReader()      # 定义文件读取器(一次读取整个文件)
运行结果:
[b'data/C.csv', b'7.jpg,7\n8.jpg,8\n9.jpg,9\n']
[b'data/B.csv', b'4.jpg,4\n5.jpg,5\n6.jpg,6\n']
[b'data/A.csv', b'1.jpg,1\n2.jpg,2\n3.jpg,3\n']
[b'data/A.csv', b'1.jpg,1\n2.jpg,2\n3.jpg,3\n']
[b'data/B.csv', b'4.jpg,4\n5.jpg,5\n6.jpg,6\n']
[b'data/C.csv', b'7.jpg,7\n8.jpg,8\n9.jpg,9\n']
"""
"""
reader = tf.TextLineReader()      # 定义文件读取器(一行一行的读)
运行结果:
[b'data/B.csv:1', b'4.jpg,4']
[b'data/B.csv:2', b'5.jpg,5']
[b'data/B.csv:3', b'6.jpg,6']
[b'data/C.csv:1', b'7.jpg,7']
[b'data/C.csv:2', b'8.jpg,8']
[b'data/C.csv:3', b'9.jpg,9']
"""

案例3:读取图片(每次读取全部图片内容,不是一行一行) 。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import tensorflow as tf
 
filename = [ '1.jpg' , '2.jpg' ]
filename_queue = tf.train.string_input_producer(filename, shuffle = False , num_epochs = 1 )
reader = tf.WholeFileReader()       # 文件读取器
key, value = reader.read(filename_queue)  # 读取文件 key:文件名;value:图片数据,bytes
 
with tf.Session() as sess:
   tf.local_variables_initializer().run()
   coord = tf.train.Coordinator()   # 线程的协调器
   threads = tf.train.start_queue_runners(sess, coord)
 
   for i in range (filename.__len__()):
     image_data = sess.run(value)
     with open ( 'img_%d.jpg' % i, 'wb' ) as f:
       f.write(image_data)
   coord.request_stop()
   coord.join(threads)

 2、(随机)批量数据读取方式:

  功能:shuffle_batch() 和 batch() 这两个API都是从文件队列中批量获取数据,使用方式类似; 。

案例4:slice_input_producer() 与 batch() 。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import tensorflow as tf
import numpy as np
 
images = np.arange( 20 ).reshape([ 10 , 2 ])
label = np.asarray( range ( 0 , 10 ))
images = tf.cast(images, tf.float32)   # 可以注释掉,不影响运行结果
label = tf.cast(label, tf.int32)     # 可以注释掉,不影响运行结果
 
batchsize = 6  # 每次获取元素的数量
input_queue = tf.train.slice_input_producer([images, label], num_epochs = None , shuffle = False )
image_batch, label_batch = tf.train.batch(input_queue, batch_size = batchsize)
 
# 随机获取 batchsize个元素,其中,capacity:队列容量,这个参数一定要比 min_after_dequeue 大
# image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=batchsize, capacity=64, min_after_dequeue=10)
 
with tf.Session() as sess:
   coord = tf.train.Coordinator()   # 线程的协调器
   threads = tf.train.start_queue_runners(sess, coord)   # 开始在图表中收集队列运行器
   for cnt in range ( 2 ):
     print ( "第{}次获取数据,每次batch={}..." . format (cnt + 1 , batchsize))
     image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
     print (image_batch_v, label_batch_v, label_batch_v.__len__())
 
   coord.request_stop()
   coord.join(threads)
 
"""

运行结果: 第1次获取数据,每次batch=6... [[ 0.  1.]  [ 2.  3.]  [ 4.  5.]  [ 6.  7.]  [ 8.  9.]  [10. 11.]] [0 1 2 3 4 5] 6 第2次获取数据,每次batch=6... [[12. 13.]  [14. 15.]  [16. 17.]  [18. 19.]  [ 0.  1.]  [ 2.  3.]] [6 7 8 9 0 1] 6 """ 。

 案例5:从本地批量的读取图片 --- string_input_producer() 与 batch() 。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import tensorflow as tf
import glob
import cv2 as cv
 
def read_imgs(filename, picture_format, input_image_shape, batch_size = ):
   """
   从本地批量的读取图片
   :param filename: 图片路径(包括图片的文件名),[]
   :param picture_format: 图片的格式,如 bmp,jpg,png等; string
   :param input_image_shape: 输入图像的大小; (h,w,c)或[]
   :param batch_size: 每次从文件队列中加载图片的数量; int
   :return: batch_size张图片数据, Tensor
   """
   global new_img
   # 创建文件队列
   file_queue = tf.train.string_input_producer(filename, num_epochs = 1 , shuffle = True )
   # 创建文件读取器
   reader = tf.WholeFileReader()
   # 读取文件队列中的文件
   _, img_bytes = reader.read(file_queue)
   # print(img_bytes)  # Tensor("ReaderReadV2_19:1", shape=(), dtype=string)
   # 对图片进行解码
   if picture_format = = ".bmp" :
     new_img = tf.image.decode_bmp(img_bytes, channels = 1 )
   elif picture_format = = ".jpg" :
     new_img = tf.image.decode_jpeg(img_bytes, channels = 3 )
   else :
     pass
   # 重新设置图片的大小
   # new_img = tf.image.resize_images(new_img, input_image_shape)
   new_img = tf.reshape(new_img, input_image_shape)
   # 设置图片的数据类型
   new_img = tf.image.convert_image_dtype(new_img, tf.uint)
 
   # return new_img
   return tf.train.batch([new_img], batch_size)
 
 
def main():
   image_path = glob.glob(r 'F:\demo\FaceRecognition\人脸库\ORL\*.bmp' )
   image_batch = read_imgs(image_path, ".bmp" , ( 112 , 92 , 1 ), 5 )
   print ( type (image_batch))
   # image_path = glob.glob(r'.\*.jpg')
   # image_batch = read_imgs(image_path, ".jpg", (313, 500, 3), 1)
 
   sess = tf.Session()
   sess.run(tf.local_variables_initializer())
   tf.train.start_queue_runners(sess = sess)
 
   image_batch = sess.run(image_batch)
   print ( type (image_batch))  # <class 'numpy.ndarray'>
 
   for i in range (image_batch.__len__()):
     cv.imshow( "win_" + str (i), image_batch[i])
   cv.waitKey()
   cv.destroyAllWindows()
 
def start():
   image_path = glob.glob(r 'F:\demo\FaceRecognition\人脸库\ORL\*.bmp' )
   image_batch = read_imgs(image_path, ".bmp" , ( 112 , 92 , 1 ), 5 )
   print ( type (image_batch))  # <class 'tensorflow.python.framework.ops.Tensor'>
 
 
   with tf.Session() as sess:
     sess.run(tf.local_variables_initializer())
     coord = tf.train.Coordinator()   # 线程的协调器
     threads = tf.train.start_queue_runners(sess, coord)   # 开始在图表中收集队列运行器
     image_batch = sess.run(image_batch)
     print ( type (image_batch))  # <class 'numpy.ndarray'>
 
     for i in range (image_batch.__len__()):
       cv.imshow( "win_" + str (i), image_batch[i])
     cv.waitKey()
     cv.destroyAllWindows()
 
     # 若使用 with 方式打开 Session,且没加如下行语句,则会出错
     # ERROR:tensorflow:Exception in QueueRunner: Enqueue operation was cancelled;
     # 原因:文件队列线程还处于工作状态(队列中还有图片数据),而加载完batch_size张图片会话就会自动关闭,同时关闭文件队列线程
     coord.request_stop()
     coord.join(threads)
 
 
if __name__ = = "__main__" :
   # main()
   start()

案列6:TFRecord文件打包与读取 。

 TFRecord文件打包案列 。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def write_TFRecord(filename, data, labels, is_shuffler = True ):
   """
   将数据打包成TFRecord格式
   :param filename: 打包后路径名,默认在工程目录下创建该文件;String
   :param data: 需要打包的文件路径名;list
   :param labels: 对应文件的标签;list
   :param is_shuffler:是否随机初始化打包后的数据,默认:True;Bool
   :return: None
   """
   im_data = list (data)
   im_labels = list (labels)
 
   index = [i for i in range (im_data.__len__())]
   if is_shuffler:
     np.random.shuffle(index)
 
   # 创建写入器,然后使用该对象写入样本example
   writer = tf.python_io.TFRecordWriter(filename)
   for i in range (im_data.__len__()):
     im_d = im_data[index[i]]  # im_d:存放着第index[i]张图片的路径信息
     im_l = im_labels[index[i]] # im_l:存放着对应图片的标签信息
 
     # # 获取当前的图片数据 方式一:
     # data = cv2.imread(im_d)
     # # 创建样本
     # ex = tf.train.Example(
     #   features=tf.train.Features(
     #     feature={
     #       "image": tf.train.Feature(
     #         bytes_list=tf.train.BytesList(
     #           value=[data.tobytes()])), # 需要打包成bytes类型
     #       "label": tf.train.Feature(
     #         int64_list=tf.train.Int64List(
     #           value=[im_l])),
     #     }
     #   )
     # )
     # 获取当前的图片数据 方式二:相对于方式一,打包文件占用空间小了一半多
     data = tf.gfile.FastGFile(im_d, "rb" ).read()
     ex = tf.train.Example(
       features = tf.train.Features(
         feature = {
           "image" : tf.train.Feature(
             bytes_list = tf.train.BytesList(
               value = [data])), # 此时的data已经是bytes类型
           "label" : tf.train.Feature(
             int_list = tf.train.IntList(
               value = [im_l])),
         }
       )
     )
 
     # 写入将序列化之后的样本
     writer.write(ex.SerializeToString())
   # 关闭写入器
   writer.close()

TFReord文件的读取案列 。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import tensorflow as tf
import cv2
 
def read_TFRecord(file_list, batch_size = ):
   """
   读取TFRecord文件
   :param file_list: 存放TFRecord的文件名,List
   :param batch_size: 每次读取图片的数量
   :return: 解析后图片及对应的标签
   """
   file_queue = tf.train.string_input_producer(file_list, num_epochs = None , shuffle = True )
   reader = tf.TFRecordReader()
   _, ex = reader.read(file_queue)
   batch = tf.train.shuffle_batch([ex], batch_size, capacity = batch_size * 10 , min_after_dequeue = batch_size * 5 )
 
   feature = {
     'image' : tf.FixedLenFeature([], tf.string),
     'label' : tf.FixedLenFeature([], tf.int64)
   }
   example = tf.parse_example(batch, features = feature)
 
   images = tf.decode_raw(example[ 'image' ], tf.uint)
   images = tf.reshape(images, [ - 1 , 32 , 32 , 3 ])
 
   return images, example[ 'label' ]
 
 
 
def main():
   # filelist = ['data/train.tfrecord']
   filelist = [ 'data/test.tfrecord' ]
   images, labels = read_TFRecord(filelist, 2 )
   with tf.Session() as sess:
     sess.run(tf.local_variables_initializer())
     coord = tf.train.Coordinator()
     threads = tf.train.start_queue_runners(sess = sess, coord = coord)
 
     try :
       while not coord.should_stop():
         for i in range ():
           image_bth, _ = sess.run([images, labels])
           print (_)
 
           cv2.imshow( "image_0" , image_bth[ 0 ])
           cv2.imshow( "image_1" , image_bth[ 1 ])
         break
     except tf.errors.OutOfRangeError:
       print ( 'read done' )
     finally :
       coord.request_stop()
     coord.join(threads)
     cv2.waitKey( 0 )
     cv2.destroyAllWindows()
 
if __name__ = = "__main__" :
   main()

到此这篇关于Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取的文章就介绍到这了,更多相关Tensorflow TFRecord打包与读取内容请搜索我以前的文章或继续浏览下面的相关文章希望大家以后多多支持我! 。

原文链接:https://www.cnblogs.com/nbk-zyc/p/13159986.html 。

最后此篇关于Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取的文章就讲到这里了,如果你想了解更多关于Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com