gpt4 book ai didi

python - 如何在 'make_initializable_iterator' 中使用 tensorflow 的迭代器 'input_fn'?

转载 作者:太空宇宙 更新时间:2023-11-04 07:56:08 27 4
gpt4 key购买 nike

我想用 tf.estimator.Estimator 训练我的模式并通过数据集 API 加载我的数据。因为我的数据,例如“mnist”,是一个数组(张量),所以我尝试用“tf”加载它.data.Dataset.from_tensor_slices'。但我不知道如何在“input_fn”中初始化“make_initializable_iterator”。

如果我可以使用“make_one_shot_iterator”来成功训练,但它在训练前加载缓慢。和《 Higher-Level APIs in TensorFlow 》是在'input_fn'中'make_initializable_iterator'的一个很好的例子,但它需要从'input_fn'返回一个'iterator_initializer_hook'给其他函数。我想知道还有其他更好或更优雅的方式吗?

    def input_fn():

mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
images = mnist_data.train.images.reshape([-1, 28, 28, 1])
labels = np.asarray(mnist_data.train.labels, dtype=np.int64)

# Build dataset iterator
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.repeat(None) # Infinite iterations
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(100)
iterator = dataset.make_one_shot_iterator()
next_example = iterator.get_next()
# Set runhook to initialize iterator

return next_example

最佳答案

在 TensorFlow 1.5 及更高版本中,tf.estimator.Estimator 将在您从 input_fn。这使您能够编写以下代码,而不必担心初始化或 Hook :

def input_fn():
mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
images = mnist_data.train.images.reshape([-1, 28, 28, 1])
labels = np.asarray(mnist_data.train.labels, dtype=np.int64)

# Build dataset.
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.repeat(None) # Infinite iterations
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(100)
return dataset

关于python - 如何在 'make_initializable_iterator' 中使用 tensorflow 的迭代器 'input_fn'?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48614529/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com