
Notes on dataset.shuffle, dataset.batch, and dataset.repeat in TensorFlow


batch is easy to understand: it simply groups samples into batches of batch_size. Note that the last batch of an epoch may contain fewer than batch_size samples.
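As a side note: if every batch must contain exactly batch_size samples, the batch transformation in newer TF releases accepts a drop_remainder argument that discards the short final batch. A minimal sketch, assuming a TF version where this argument is available:

import numpy as np
import tensorflow as tf

x = np.random.sample((11, 2))
dataset = tf.data.Dataset.from_tensor_slices(x)
# drop_remainder=True throws away the final 3-sample batch,
# so every emitted batch has exactly 4 rows
dataset = dataset.batch(4, drop_remainder=True)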

dataset.repeat is what is commonly called the epoch count, but in TensorFlow the order in which it is combined with dataset.shuffle can cause samples from different epochs to be mixed together.

dataset.shuffle maintains a shuffle buffer of buffer_size elements. Each sample the graph needs is drawn from this buffer, and as soon as one sample is taken out, a new sample is pulled from the source dataset into the buffer.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.sample((11, 2))
# make a dataset from a numpy array
print(x)
print()

dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator (TF 1.x API)
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # 11 samples, batch_size=4, 2 epochs -> 6 batches in total
    try:
        while True:
            print(sess.run(el))
    except tf.errors.OutOfRangeError:
        pass
# source dataset
[[ 0.5488135  0.71518937 ]
  [ 0.60276338 0.54488318 ]
  [ 0.4236548  0.64589411 ]
  [ 0.43758721 0.891773 ]
  [ 0.96366276 0.38344152 ]
  [ 0.79172504 0.52889492 ]
  [ 0.56804456 0.92559664 ]
  [ 0.07103606 0.0871293 ]
  [ 0.0202184  0.83261985 ]
  [ 0.77815675 0.87001215 ]
  [ 0.97861834 0.79915856 ]]
 
# batches obtained after shuffle and batch
[[ 0.4236548  0.64589411 ]
  [ 0.60276338 0.54488318 ]
  [ 0.43758721 0.891773 ]
  [ 0.5488135  0.71518937 ]]
[[ 0.96366276 0.38344152 ]
  [ 0.56804456 0.92559664 ]
  [ 0.0202184  0.83261985 ]
  [ 0.79172504 0.52889492 ]]
[[ 0.07103606 0.0871293 ]
  [ 0.97861834 0.79915856 ]
  [ 0.77815675 0.87001215 ]] # the last batch of this epoch has only 3 samples
[[ 0.60276338 0.54488318 ]
  [ 0.5488135  0.71518937 ]
  [ 0.43758721 0.891773 ]
  [ 0.79172504 0.52889492 ]]
[[ 0.4236548  0.64589411 ]
  [ 0.56804456 0.92559664 ]
  [ 0.0202184  0.83261985 ]
  [ 0.07103606 0.0871293 ]]
[[ 0.77815675 0.87001215 ]
  [ 0.96366276 0.38344152 ]
  [ 0.97861834 0.79915856 ]] # the last batch of this epoch has only 3 samples

1. According to the buffer_size set in shuffle (3 here), three samples are first pulled from the source dataset into the shuffle buffer:
shuffle buffer: [ 0.5488135 0.71518937] [ 0.60276338 0.54488318] [ 0.4236548 0.64589411]
2. One sample is drawn at random from the buffer into the batch:
shuffle buffer: [ 0.5488135 0.71518937] [ 0.60276338 0.54488318]
batch: [ 0.4236548 0.64589411]
3. The buffer now holds fewer than three samples, so one more sample is pulled in from the source dataset:
shuffle buffer: [ 0.5488135 0.71518937] [ 0.60276338 0.54488318] [ 0.43758721 0.891773 ]
4. Another sample is drawn from the buffer into the batch:
shuffle buffer: [ 0.5488135 0.71518937] [ 0.43758721 0.891773 ]
batch: [ 0.4236548 0.64589411] [ 0.60276338 0.54488318]
5. And so on. This means that with buffer_size=1 the dataset is not shuffled at all, while with buffer_size equal to the number of samples the entire dataset is shuffled uniformly.
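To make the mechanism concrete, here is a small pure-Python sketch (my own illustration, not TensorFlow source code) that reproduces the fill-one / draw-one behaviour described above:

import random

def shuffle_buffer(source, buffer_size, seed=0):
    """Yield samples from `source` the way dataset.shuffle does: keep a buffer
    of buffer_size elements, emit a random one, then refill from the source."""
    rng = random.Random(seed)
    it = iter(source)
    buffer = []
    # initially fill the buffer
    for _ in range(buffer_size):
        try:
            buffer.append(next(it))
        except StopIteration:
            break
    while buffer:
        idx = rng.randrange(len(buffer))
        yield buffer.pop(idx)          # draw one sample at random
        try:
            buffer.append(next(it))    # refill from the source dataset
        except StopIteration:
            pass

data = list(range(11))
print(list(shuffle_buffer(data, 3)))   # partially shuffled, like shuffle(3)
print(list(shuffle_buffer(data, 1)))   # identical to the source order

The next example runs the real TensorFlow pipeline with shuffle(1) to confirm that buffer_size=1 leaves the order untouched.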

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.sample((11, 2))
# make a dataset from a numpy array
print(x)
print()

dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)   # buffer_size=1: no shuffling at all
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator (TF 1.x API)
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # again 6 batches; the order within each epoch matches the source data
    try:
        while True:
            print(sess.run(el))
    except tf.errors.OutOfRangeError:
        pass
 
[[ 0.5488135  0.71518937 ]
  [ 0.60276338 0.54488318 ]
  [ 0.4236548  0.64589411 ]
  [ 0.43758721 0.891773 ]
  [ 0.96366276 0.38344152 ]
  [ 0.79172504 0.52889492 ]
  [ 0.56804456 0.92559664 ]
  [ 0.07103606 0.0871293 ]
  [ 0.0202184  0.83261985 ]
  [ 0.77815675 0.87001215 ]
  [ 0.97861834 0.79915856 ]]
 
[[ 0.5488135  0.71518937 ]
  [ 0.60276338 0.54488318 ]
  [ 0.4236548  0.64589411 ]
  [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152 ]
  [ 0.79172504 0.52889492 ]
  [ 0.56804456 0.92559664 ]
  [ 0.07103606 0.0871293 ]]
[[ 0.0202184  0.83261985 ]
  [ 0.77815675 0.87001215 ]
  [ 0.97861834 0.79915856 ]]
[[ 0.5488135  0.71518937 ]
  [ 0.60276338 0.54488318 ]
  [ 0.4236548  0.64589411 ]
  [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152 ]
  [ 0.79172504 0.52889492 ]
  [ 0.56804456 0.92559664 ]
  [ 0.07103606 0.0871293 ]]
[[ 0.0202184  0.83261985 ]
  [ 0.77815675 0.87001215 ]
  [ 0.97861834 0.79915856 ]]

Note what happens if repeat is applied before shuffle:

The official guidance is that calling repeat before shuffle can improve performance, but it blurs the epoch boundaries between samples, as the example below shows (see also the fused-op sketch after its output).

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.sample((11, 2))
# make a dataset from a numpy array
print(x)
print()

dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)    # repeat BEFORE shuffle: the two epochs get mixed
dataset = dataset.shuffle(11)
dataset = dataset.batch(4)

# create the iterator (TF 1.x API)
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # 22 samples after repeat, batch_size=4 -> 6 batches (the last one has 2)
    try:
        while True:
            print(sess.run(el))
    except tf.errors.OutOfRangeError:
        pass
 
[[ 0.5488135  0.71518937 ]
  [ 0.60276338 0.54488318 ]
  [ 0.4236548  0.64589411 ]
  [ 0.43758721 0.891773 ]
  [ 0.96366276 0.38344152 ]
  [ 0.79172504 0.52889492 ]
  [ 0.56804456 0.92559664 ]
  [ 0.07103606 0.0871293 ]
  [ 0.0202184  0.83261985 ]
  [ 0.77815675 0.87001215 ]
  [ 0.97861834 0.79915856 ]]
 
[[ 0.56804456 0.92559664 ]
  [ 0.5488135  0.71518937 ]
  [ 0.60276338 0.54488318 ]
  [ 0.07103606 0.0871293 ]]
[[ 0.96366276 0.38344152 ]
  [ 0.43758721 0.891773 ]
  [ 0.43758721 0.891773 ]
  [ 0.77815675 0.87001215 ]]
[[ 0.79172504 0.52889492 ]  # the same sample appears twice in this batch
  [ 0.79172504 0.52889492 ]
  [ 0.60276338 0.54488318 ]
  [ 0.4236548  0.64589411 ]]
[[ 0.07103606 0.0871293 ]
  [ 0.4236548  0.64589411 ]
  [ 0.96366276 0.38344152 ]
  [ 0.5488135  0.71518937 ]]
[[ 0.97861834 0.79915856 ]
  [ 0.0202184  0.83261985 ]
  [ 0.77815675 0.87001215 ]
  [ 0.56804456 0.92559664 ]]
[[ 0.0202184  0.83261985 ]
  [ 0.97861834 0.79915856 ]]     # note the last batch has 2 samples while the earlier ones have 4
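At the time, the official input-pipeline guide also pointed to a fused shuffle-and-repeat transformation as a way to get the performance benefit of fusing without blurring epoch boundaries. A minimal sketch, assuming a TF 1.x build where tf.contrib.data.shuffle_and_repeat is available (later releases expose it as tf.data.experimental.shuffle_and_repeat):

dataset = tf.data.Dataset.from_tensor_slices(x)
# fused equivalent of shuffle(11) followed by repeat(2), reshuffling each epoch
dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=11, count=2))
dataset = dataset.batch(4)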

Usage example:

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    print('Parsing', filenames)

    def decode_libsvm(line):
        # each line looks like: "<label> <id>:<value> <id>:<value> ..."
        columns = tf.string_split([line], ' ')
        labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
        splits = tf.string_split(columns.values[1:], ':')
        id_vals = tf.reshape(splits.values, splits.dense_shape)
        feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
        feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
        feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
        return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

    # Extract lines from the input files using the Dataset API; accepts one filename or a list
    dataset = (tf.data.TextLineDataset(filenames)
               .map(decode_libsvm, num_parallel_calls=10)
               .prefetch(500000))  # multi-threaded pre-processing, then prefetch

    # Randomize input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # repeat after shuffle, so that samples from different epochs do not blend together
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)  # batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
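A hedged sketch of how the tensors returned by input_fn might be consumed in a TF 1.x session. The file name 'tr.libsvm' is a placeholder, not part of the original article, and this assumes every line carries the same number of id:value pairs so that batching produces dense tensors:

# 'tr.libsvm' is a hypothetical libsvm-format file used only for illustration
batch_features, batch_labels = input_fn(['tr.libsvm'], batch_size=256,
                                        num_epochs=1, perform_shuffle=True)

with tf.Session() as sess:
    # fetch one batch; sess.run accepts the dict of feature tensors directly
    feats, labels = sess.run([batch_features, batch_labels])
    print(feats['feat_ids'].shape, feats['feat_vals'].shape, labels.shape)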

That is all for this article on the caveats of dataset.shuffle, dataset.batch, and dataset.repeat in TensorFlow. For more on this topic, please search my earlier articles or browse the related articles, and I hope you will keep supporting me!

Original article: https://blog.csdn.net/qq_16234613/article/details/81703228

