gpt4 book ai didi

python - 如何使用 tf.data 可初始化迭代器和可初始化迭代器并将数据提供给估算器 api?

转载 作者:太空狗 更新时间:2023-10-30 01:04:28 24 4
gpt4 key购买 nike

所有官方 google 教程都对所有估算器 api 实现使用一次性迭代器,我找不到任何关于如何使用 tf.data 的可初始化迭代器和可重新初始化迭代器而不是一次性迭代器的文档。

有人可以告诉我如何使用 tf.data 的可初始化迭代器和可重新初始化的迭代器在 train_data 和 test_data 之间切换。我们需要运行一个 session 来使用 feed dict 并在可初始化迭代器中切换数据集,它是一个低级 api 并且它令人困惑如何使用它作为 estimator api 架构的一部分

PS:我确实发现谷歌提到“注意:目前,一次性迭代器是唯一可以轻松用于 Estimator 的类型。”

但是社区内有什么解决办法吗?还是我们应该出于某种充分的理由坚持使用一次性迭代器

最佳答案

要使用可初始化或可重新初始化的迭代器,您必须创建一个继承自 tf.train.SessionRunHook 的类。然后,此类可以访问 tf.estimator 函数使用的 session 。

这是您可以根据需要进行调整的简单示例:

class IteratorInitializerHook(tf.train.SessionRunHook):
def __init__(self):
super(IteratorInitializerHook, self).__init__()
self.iterator_initializer_func = None # Will be set in the input_fn

def after_create_session(self, session, coord):
self.iterator_initializer_func(session)


def get_inputs(X, y):
iterator_initializer_hook = IteratorInitializerHook()

def input_fn():
X_pl = tf.placeholder(X.dtype, X.shape)
y_pl = tf.placeholder(y.dtype, y.shape)

dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
dataset = ...
...

iterator = dataset.make_initializable_iterator()
next_example, next_label = iterator.get_next()


iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(iterator.initializer,
feed_dict={X_pl: X, y_pl: y})

return next_example, next_label

return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
hooks=[train_iterator_initializer_hook])
estimator.evaluate(input_fn=test_input_fn,
hooks=[test_iterator_initializer_hook])

这是我在 blogpost 中找到的代码的修改版本通过 Sebastian Pölsterl .查看“通过数据集 API 将数据提供给估算器”部分。

关于python - 如何使用 tf.data 可初始化迭代器和可初始化迭代器并将数据提供给估算器 api?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51625529/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com