- ubuntu12.04环境下使用kvm ioctl接口实现最简单的虚拟机
- Ubuntu 通过无线网络安装Ubuntu Server启动系统后连接无线网络的方法
- 在Ubuntu上搭建网桥的方法
- ubuntu 虚拟机上网方式及相关配置详解
CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.
这篇CFSDN的博客文章浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.
batch很好理解,就是batch size。注意在一个epoch中最后一个batch大小可能小于等于batch size 。
dataset.repeat就是俗称epoch,但在tf中与dataset.shuffle的使用顺序可能会导致个epoch的混合 。
dataset.shuffle就是说维持一个buffer size 大小的 shuffle buffer,图中所需的每个样本从shuffle buffer中获取,取得一个样本后,就从源数据集中加入一个样本到shuffle buffer中.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
import
os
os.environ[
'CUDA_VISIBLE_DEVICES'
]
=
""
import
numpy as np
import
tensorflow as tf
np.random.seed(
0
)
x
=
np.random.sample((
11
,
2
))
# make a dataset from a numpy array
print
(x)
print
()
dataset
=
tf.data.Dataset.from_tensor_slices(x)
dataset
=
dataset.shuffle(
3
)
dataset
=
dataset.batch(
4
)
dataset
=
dataset.repeat(
2
)
# create the iterator
iter
=
dataset.make_one_shot_iterator()
el
=
iter
.get_next()
with tf.Session() as sess:
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
#源数据集
[[
0.5488135
0.71518937
]
[
0.60276338
0.54488318
]
[
0.4236548
0.64589411
]
[
0.43758721
0.891773
]
[
0.96366276
0.38344152
]
[
0.79172504
0.52889492
]
[
0.56804456
0.92559664
]
[
0.07103606
0.0871293
]
[
0.0202184
0.83261985
]
[
0.77815675
0.87001215
]
[
0.97861834
0.79915856
]]
# 通过shuffle batch后取得的样本
[[
0.4236548
0.64589411
]
[
0.60276338
0.54488318
]
[
0.43758721
0.891773
]
[
0.5488135
0.71518937
]]
[[
0.96366276
0.38344152
]
[
0.56804456
0.92559664
]
[
0.0202184
0.83261985
]
[
0.79172504
0.52889492
]]
[[
0.07103606
0.0871293
]
[
0.97861834
0.79915856
]
[
0.77815675
0.87001215
]]
#最后一个batch样本个数为3
[[
0.60276338
0.54488318
]
[
0.5488135
0.71518937
]
[
0.43758721
0.891773
]
[
0.79172504
0.52889492
]]
[[
0.4236548
0.64589411
]
[
0.56804456
0.92559664
]
[
0.0202184
0.83261985
]
[
0.07103606
0.0871293
]]
[[
0.77815675
0.87001215
]
[
0.96366276
0.38344152
]
[
0.97861834
0.79915856
]]
#最后一个batch样本个数为3
|
1、按照shuffle中设置的buffer size,首先从源数据集取得三个样本: shuffle buffer: [ 0.5488135 0.71518937] [ 0.60276338 0.54488318] [ 0.4236548 0.64589411] 2、从buffer中取一个样本到batch中得: shuffle buffer: [ 0.5488135 0.71518937] [ 0.60276338 0.54488318] batch: [ 0.4236548 0.64589411] 3、shuffle buffer不足三个样本,从源数据集提取一个样本: shuffle buffer: [ 0.5488135 0.71518937] [ 0.60276338 0.54488318] [ 0.43758721 0.891773 ] 4、从buffer中取一个样本到batch中得: shuffle buffer: [ 0.5488135 0.71518937] [ 0.43758721 0.891773 ] batch: [ 0.4236548 0.64589411] [ 0.60276338 0.54488318] 5、如此反复。这就意味中如果shuffle 的buffer size=1,数据集不打乱。如果shuffle 的buffer size=数据集样本数量,随机打乱整个数据集 。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
|
import
os
os.environ[
'CUDA_VISIBLE_DEVICES'
]
=
""
import
numpy as np
import
tensorflow as tf
np.random.seed(
0
)
x
=
np.random.sample((
11
,
2
))
# make a dataset from a numpy array
print
(x)
print
()
dataset
=
tf.data.Dataset.from_tensor_slices(x)
dataset
=
dataset.shuffle(
1
)
dataset
=
dataset.batch(
4
)
dataset
=
dataset.repeat(
2
)
# create the iterator
iter
=
dataset.make_one_shot_iterator()
el
=
iter
.get_next()
with tf.Session() as sess:
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
[[
0.5488135
0.71518937
]
[
0.60276338
0.54488318
]
[
0.4236548
0.64589411
]
[
0.43758721
0.891773
]
[
0.96366276
0.38344152
]
[
0.79172504
0.52889492
]
[
0.56804456
0.92559664
]
[
0.07103606
0.0871293
]
[
0.0202184
0.83261985
]
[
0.77815675
0.87001215
]
[
0.97861834
0.79915856
]]
[[
0.5488135
0.71518937
]
[
0.60276338
0.54488318
]
[
0.4236548
0.64589411
]
[
0.43758721
0.891773
]]
[[
0.96366276
0.38344152
]
[
0.79172504
0.52889492
]
[
0.56804456
0.92559664
]
[
0.07103606
0.0871293
]]
[[
0.0202184
0.83261985
]
[
0.77815675
0.87001215
]
[
0.97861834
0.79915856
]]
[[
0.5488135
0.71518937
]
[
0.60276338
0.54488318
]
[
0.4236548
0.64589411
]
[
0.43758721
0.891773
]]
[[
0.96366276
0.38344152
]
[
0.79172504
0.52889492
]
[
0.56804456
0.92559664
]
[
0.07103606
0.0871293
]]
[[
0.0202184
0.83261985
]
[
0.77815675
0.87001215
]
[
0.97861834
0.79915856
]]
|
注意如果repeat在shuffle之前使用:
官方说repeat在shuffle之前使用能提高性能,但模糊了数据样本的epoch关系 。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
|
import
os
os.environ[
'CUDA_VISIBLE_DEVICES'
]
=
""
import
numpy as np
import
tensorflow as tf
np.random.seed(
0
)
x
=
np.random.sample((
11
,
2
))
# make a dataset from a numpy array
print
(x)
print
()
dataset
=
tf.data.Dataset.from_tensor_slices(x)
dataset
=
dataset.repeat(
2
)
dataset
=
dataset.shuffle(
11
)
dataset
=
dataset.batch(
4
)
# create the iterator
iter
=
dataset.make_one_shot_iterator()
el
=
iter
.get_next()
with tf.Session() as sess:
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
print
(sess.run(el))
[[
0.5488135
0.71518937
]
[
0.60276338
0.54488318
]
[
0.4236548
0.64589411
]
[
0.43758721
0.891773
]
[
0.96366276
0.38344152
]
[
0.79172504
0.52889492
]
[
0.56804456
0.92559664
]
[
0.07103606
0.0871293
]
[
0.0202184
0.83261985
]
[
0.77815675
0.87001215
]
[
0.97861834
0.79915856
]]
[[
0.56804456
0.92559664
]
[
0.5488135
0.71518937
]
[
0.60276338
0.54488318
]
[
0.07103606
0.0871293
]]
[[
0.96366276
0.38344152
]
[
0.43758721
0.891773
]
[
0.43758721
0.891773
]
[
0.77815675
0.87001215
]]
[[
0.79172504
0.52889492
]
#出现相同样本出现在同一个batch中
[
0.79172504
0.52889492
]
[
0.60276338
0.54488318
]
[
0.4236548
0.64589411
]]
[[
0.07103606
0.0871293
]
[
0.4236548
0.64589411
]
[
0.96366276
0.38344152
]
[
0.5488135
0.71518937
]]
[[
0.97861834
0.79915856
]
[
0.0202184
0.83261985
]
[
0.77815675
0.87001215
]
[
0.56804456
0.92559664
]]
[[
0.0202184
0.83261985
]
[
0.97861834
0.79915856
]]
#可以看到最后个batch为2,而前面都是4
|
使用案例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
def
input_fn(filenames, batch_size
=
32
, num_epochs
=
1
, perform_shuffle
=
False
):
print
(
'Parsing'
, filenames)
def
decode_libsvm(line):
#columns = tf.decode_csv(value, record_defaults=CSV_COLUMN_DEFAULTS)
#features = dict(zip(CSV_COLUMNS, columns))
#labels = features.pop(LABEL_COLUMN)
columns
=
tf.string_split([line],
' '
)
labels
=
tf.string_to_number(columns.values[
0
], out_type
=
tf.float32)
splits
=
tf.string_split(columns.values[
1
:],
':'
)
id_vals
=
tf.reshape(splits.values,splits.dense_shape)
feat_ids, feat_vals
=
tf.split(id_vals,num_or_size_splits
=
2
,axis
=
1
)
feat_ids
=
tf.string_to_number(feat_ids, out_type
=
tf.int32)
feat_vals
=
tf.string_to_number(feat_vals, out_type
=
tf.float32)
#feat_ids = tf.reshape(feat_ids,shape=[-1,FLAGS.field_size])
#for i in range(splits.dense_shape.eval()[0]):
# feat_ids.append(tf.string_to_number(splits.values[2*i], out_type=tf.int32))
# feat_vals.append(tf.string_to_number(splits.values[2*i+1]))
#return tf.reshape(feat_ids,shape=[-1,field_size]), tf.reshape(feat_vals,shape=[-1,field_size]), labels
return
{
"feat_ids"
: feat_ids,
"feat_vals"
: feat_vals}, labels
# Extract lines from input files using the Dataset API, can pass one filename or filename list
dataset
=
tf.data.TextLineDataset(filenames).
map
(decode_libsvm, num_parallel_calls
=
10
).prefetch(
500000
)
# multi-thread pre-process then prefetch
# Randomizes input using a window of 256 elements (read into memory)
if
perform_shuffle:
dataset
=
dataset.shuffle(buffer_size
=
256
)
# epochs from blending together.
dataset
=
dataset.repeat(num_epochs)
dataset
=
dataset.batch(batch_size)
# Batch size to use
#return dataset.make_one_shot_iterator()
iterator
=
dataset.make_one_shot_iterator()
batch_features, batch_labels
=
iterator.get_next()
#return tf.reshape(batch_ids,shape=[-1,field_size]), tf.reshape(batch_vals,shape=[-1,field_size]), batch_labels
return
batch_features, batch_labels
|
到此这篇关于浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点的文章就介绍到这了,更多相关tensorflow中dataset.shuffle和dataset.batch dataset.repeat内容请搜索我以前的文章或继续浏览下面的相关文章希望大家以后多多支持我! 。
原文链接:https://blog.csdn.net/qq_16234613/article/details/81703228 。
最后此篇关于浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点的文章就讲到这里了,如果你想了解更多关于浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
我正在运行PHP脚本,并继续收到如下错误: 注意:未定义的变量:第10行的C:\ wamp \ www \ mypath \ index.php中的my_variable_name 注意
我正在运行PHP脚本,并继续收到如下错误: 注意:未定义的变量:第10行的C:\ wamp \ www \ mypath \ index.php中的my_variable_name 注意
我正在运行PHP脚本,并继续收到如下错误: 注意:未定义的变量:第10行的C:\ wamp \ www \ mypath \ index.php中的my_variable_name 注意
我正在运行一个PHP脚本,并且继续收到如下错误:。第10行和第11行如下所示:。这些错误消息的含义是什么?。为什么他们突然出现了?我多年来一直使用这个脚本,从来没有遇到过任何问题。。我该怎么修理它们呢
当我在 flutter clean 之后运行 flutter run 或 debug my code 时显示此错误 Note: C:\src\flutter.pub-cache\hosted\pub.
My Goal: To fix this error and be able to run my app without an error. Error Message: Note:D:\Learni
前言:今天在解决一个问题时,程序总是不能输出正确值,分析逻辑思路没问题后,发现原来是由于函数传递导致了这个情况。 LeetCode 113 问题:给你二叉树的根节点
我正在 R 中开发一个包,当我运行时 devtools::check()我收到以下说明。 checking DESCRIPTION meta-information ... NOTE Malforme
获得通知和警告波纹管 Notice: Use of undefined constant GLOB_BRACE - assumed 'GLOB_BRACE' in /var/www/html/open
我正在准备一个 R 包以提交给 CRAN。 R CMD 检查给了我以下注意: Foreign function calls to a different package: .Fortran("cinc
我正在尝试从以下位置获取数据: http://api.convoytrucking.net/api.php?api_key=public&show=player&player_name=Mick_Gi
我有这段代码,但我不明白为什么我仍然有这个错误,我已经尝试了所有解决方案,但无法解决这个问题:-注意:未定义索引:product_price-注意:未定义索引:product_quantity-注意:
This question already has answers here: “Notice: Undefined variable”, “Notice: Undefined index”, and
我正在尝试从以下位置获取数据: http://api.convoytrucking.net/api.php?api_key=public&show=player&player_name=Mick_Gi
切记,在PHP 7中不要做的10件事 1. 不要使用 mysql_ 函数 这一天终于来了,从此你不仅仅“不应该”使用mysql_函数。PHP 7 已经把它们从核心中全部移除了,也就是说你需要迁移
前几天安装了dedecms系统,当在后台安全退出的时候,后台出现空白,先前只分析其他功能去了,也没太注意安全,看了一下安全退出的代码,是这样写的: 复制代码 代码如下: function ex
我使用此代码来检查变量$n0、$n1、$n2是否未定义。 但每次未定义时我都会收到通知。我的代码是一种不好的做法吗?还有什么替代方案吗?或者只是删除通知,代码就可以了? if
编写代码时处理所有警告是否重要?在我公司中具有较高资历的开发人员坚持认为警告是无害的。诚然,其中一些是: Warning: Division by zero Notice: Undefined ind
我有一个搜索查询,执行搜索查询后,我将$ result放入数组中。 我的PHP代码- $contents = $client->search($params); // executing the se
This question already has answers here: “Notice: Undefined variable”, “Notice: Undefined index”, and
我是一名优秀的程序员,十分优秀!