python - Proper way to stop a TensorFlow dataset built with `from_generator`?

Reposted · Author: 太空宇宙 · Updated: 2023-11-04 00:17:49

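I want to use a TensorFlow dataset built with `from_generator` to read formatted files. Almost everything works, except that I don't know how to stop the dataset iterator when the generator runs out of data. Once you go past the end, the generator just keeps returning empty lists forever.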

My real code is quite complex, but I can reproduce the situation with this short program:

# Uses the TensorFlow 1.x graph API (tf.Session, make_one_shot_iterator)
import tensorflow as tf

def make_batch_generator_fn(batch_size=10, dset_size=100):
    feats, targs = range(dset_size), range(1, dset_size + 1)

    def batch_generator_fn():
        start_idx, stop_idx = 0, batch_size
        while True:
            # if stop_idx > dset_size: --- stop action?
            yield feats[start_idx: stop_idx], targs[start_idx: stop_idx]
            start_idx, stop_idx = start_idx + batch_size, stop_idx + batch_size

    return batch_generator_fn

def test(batch_size=10):
    dgen = make_batch_generator_fn(batch_size)
    features_shape, targets_shape = [None], [None]
    ds = tf.data.Dataset.from_generator(
        dgen, (tf.int32, tf.int32),
        (tf.TensorShape(features_shape), tf.TensorShape(targets_shape))
    )
    feats, targs = ds.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        counter = 0
        try:
            while True:
                f, t = sess.run([feats, targs])
                print(f, t)
                counter += 1
                # Manual cutoff: without it this loop never ends, because
                # the generator never signals the end of the data.
                if counter > 15:
                    break
        except tf.errors.OutOfRangeError:
            print('end of dataset at counter = {}'.format(counter))

if __name__ == '__main__':
    test()
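The underlying problem can be reproduced without TensorFlow. The sketch below copies the generator from `make_batch_generator_fn` above (with a smaller, illustrative `dset_size=30`) and pulls five batches with `itertools.islice`. Slicing a `range` past its end does not raise anything; it just returns an empty range, so the `while True` loop keeps yielding empty batches and never finishes:

```python
from itertools import islice


def make_batch_generator_fn(batch_size=10, dset_size=100):
    # Copy of the generator from the question, unchanged in behavior.
    feats, targs = range(dset_size), range(1, dset_size + 1)

    def batch_generator_fn():
        start_idx, stop_idx = 0, batch_size
        while True:  # no exit condition, so the generator is infinite
            yield feats[start_idx:stop_idx], targs[start_idx:stop_idx]
            start_idx, stop_idx = start_idx + batch_size, stop_idx + batch_size

    return batch_generator_fn


gen = make_batch_generator_fn(batch_size=10, dset_size=30)()
# Take 5 batches even though only the first 3 contain data.
batches = [(list(f), list(t)) for f, t in islice(gen, 5)]
for f, t in batches:
    print(f, t)
```

Because the generator never finishes, `from_generator` has no way to learn that the data has run out.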

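If I knew the number of records in advance, I could adjust the number of batches, but I don't always know it. I tried putting some code where the snippet above has the commented-out `stop action?` line. In particular, I tried raising `IndexError`, but TensorFlow doesn't like that, even when I explicitly catch it in my calling code. I also tried raising `tf.errors.OutOfRangeError`, but I'm not sure how to instantiate it: the constructor takes three arguments, `node_def`, `op`, and `message`, and I'm not sure what to pass for `node_def` and `op` in general.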

Any thoughts or suggestions on this problem would be much appreciated. Thanks!

Best answer

Return from the generator when the stop condition is met:

def make_batch_generator_fn(batch_size=10, dset_size=100):
    feats, targs = range(dset_size), range(1, dset_size + 1)

    def batch_generator_fn():
        start_idx, stop_idx = 0, batch_size
        while True:
            if stop_idx > dset_size:
                # Returning ends the generator (StopIteration is raised),
                # which tf.data treats as the end of the dataset.
                return
            else:
                yield feats[start_idx: stop_idx], targs[start_idx: stop_idx]
                start_idx, stop_idx = start_idx + batch_size, stop_idx + batch_size

    return batch_generator_fn
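The question mentions that the record count is not always known in advance, and the `if stop_idx > dset_size` check also skips a final partial batch when `dset_size` is not a multiple of `batch_size`. A minimal sketch that avoids both, assuming records can be produced one at a time by any Python iterable, batches with `itertools.islice` and returns as soon as the source is exhausted. The names `make_streaming_batch_generator_fn` and `read_records` are hypothetical, for illustration only:

```python
from itertools import islice


def make_streaming_batch_generator_fn(records, batch_size=10):
    """Return a zero-argument generator function that yields batches.

    `records` is a hypothetical zero-argument callable returning a fresh
    iterable of (feature, target) pairs; its length need not be known.
    """
    def batch_generator_fn():
        it = iter(records())
        while True:
            batch = list(islice(it, batch_size))
            if not batch:  # source exhausted: end the generator cleanly
                return
            feats, targs = zip(*batch)
            yield list(feats), list(targs)

    return batch_generator_fn


def read_records():
    # Hypothetical stand-in for parsing a formatted file of unknown length.
    return ((i, i + 1) for i in range(25))


gen_fn = make_streaming_batch_generator_fn(read_records, batch_size=10)
for f, t in gen_fn():
    print(len(f), f[0], t[0])
```

Taking a callable rather than a ready-made iterator means each call of `batch_generator_fn` starts from the beginning of the data, which matters if the dataset is iterated more than once. In TensorFlow 2.x, `gen_fn` could be passed as the `generator` argument of `tf.data.Dataset.from_generator` with an `output_signature` of two `tf.TensorSpec(shape=(None,), dtype=tf.int32)` entries.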

This matches the behavior specified in the Python 3 documentation:

In a generator function, the return statement indicates that the generator is done and will cause StopIteration to be raised. The returned value (if any) is used as an argument to construct StopIteration and becomes the StopIteration.value attribute.
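The quoted rule can be checked in plain Python. In this sketch, with illustrative function names, `return` ends the generator and its value shows up as `StopIteration.value`. By contrast, explicitly raising `StopIteration` inside a generator is converted to `RuntimeError` on Python 3.7+ (PEP 479), which is one more reason to signal the end of data with a plain `return`:

```python
def finite_gen():
    yield 1
    yield 2
    return "done"  # ends the generator; becomes StopIteration.value


def raising_gen():
    yield 1
    raise StopIteration  # PEP 479: turned into RuntimeError (Python 3.7+)


g = finite_gen()
values = [next(g), next(g)]
try:
    next(g)
except StopIteration as exc:
    stop_value = exc.value  # the value passed to `return`

try:
    list(raising_gen())
    pep479_error = None
except RuntimeError as exc:
    pep479_error = exc

print(values, stop_value, type(pep479_error).__name__)
```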

For a similar question on Stack Overflow about stopping a TensorFlow dataset built with `from_generator`, see: https://stackoverflow.com/questions/50275833/
