gpt4 book ai didi

python - 确定 tf.data.Dataset Tensorflow 中的记录数

转载 作者:太空宇宙 更新时间:2023-11-03 11:59:53 25 4
gpt4 key购买 nike

<分区>

我想将数据集迭代器传递给函数,但该函数需要知道数据集的长度。在下面的示例中,我可以将 len(datafiles) 传递给 my_custom_fn() 函数,但我想知道我是否能够从中提取数据集的长度iteratorbatch_xbatch_y 类,这样我就不必将其添加为输入。

dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff

谢谢!

编辑:此解决方案不适用于我的情况:tf.data.Dataset: how to get the dataset size (number of elements in a epoch)?

运行后

tf.data.Dataset.list_files('{}/*.dat')
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0')[0])

返回

<tf.Tensor 'Shape_3:0' shape=(0,) dtype=int32>

我确实找到了适合我的解决方案。将 iterator_scope 添加到我的代码中,例如:

with tf.name_scope('iter'):
dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff

然后从 my_custom_fn 调用:

def my_custom_fn(batch_x, batch_y):
filenames = batch_x.graph.get_operation_by_name(
'iter/InputDataSet/filenames').outputs[0]
n_epoch = sess.run(sess.graph.get_operation_by_name(
'iter/Iterator/count').outputs)[0]
batch_size = sess.run(sess.graph.get_operation_by_name(
'iter/Iterator/batch_size').outputs)[0]
# lots of other stuff

不确定这是否是最好的方法,但它似乎有效。很高兴就此提出任何建议,因为它看起来有点老套。

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com