
python - Avoiding tf.data.Dataset.from_tensor_slices with the Estimator API

Reposted. Author: 太空宇宙. Updated: 2023-11-03 15:41:38

I'm trying to figure out the recommended way to use the Dataset API together with the Estimator API. Everything I've seen online is some variation of this:

import tensorflow as tf

def train_input_fn():
    # features and labels are in-memory NumPy arrays defined elsewhere
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset

This can then be passed to the estimator's train function:

classifier.train(
    input_fn=train_input_fn,
    # ...
)
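Conceptually, `from_tensor_slices((features, labels))` slices both arrays along their first dimension and yields one `(feature_row, label)` element per row, which is why the two arrays must have the same leading size. The plain-Python sketch below illustrates only that slicing semantics; `slice_pairs` is a hypothetical helper, and it does not reproduce TensorFlow's implementation or its constant-embedding behaviour.

```python
from typing import List, Sequence, Tuple, TypeVar

F = TypeVar("F")
L = TypeVar("L")


def slice_pairs(features: Sequence[F], labels: Sequence[L]) -> List[Tuple[F, L]]:
    """Hypothetical helper: mimics how from_tensor_slices((features, labels))
    produces one (feature_row, label) element per index of the first dimension."""
    if len(features) != len(labels):
        # from_tensor_slices likewise requires matching first dimensions.
        raise ValueError("features and labels must have the same first dimension")
    return list(zip(features, labels))


features = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # 3 examples, 2 features each
labels = [0, 1, 0]                                 # one label per example
print(slice_pairs(features, labels))
# [([0.1, 0.2], 0), ([0.3, 0.4], 1), ([0.5, 0.6], 0)]
```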

But the Dataset guide warns:

the above code snippet will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory---because the contents of the array will be copied multiple times---and can run into the 2GB limit for the tf.GraphDef protocol buffer.
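To judge whether a given array is at risk, a rough back-of-the-envelope check is to multiply the element count by the bytes per element and compare the result with the 2 GB protocol-buffer limit. This sketch assumes the limit is 2 GiB (2**31 bytes) and counts only one raw copy of the data; the serialized graph adds overhead and, as the guide notes, the contents may be copied more than than once, so treat the result as a lower bound. The helper names are hypothetical.

```python
from math import prod

# Assumption: treat the GraphDef limit as 2 GiB (2**31 bytes).
GRAPHDEF_LIMIT_BYTES = 2 ** 31


def array_bytes(shape, bytes_per_element):
    """Raw size of one dense array: number of elements times element size."""
    return prod(shape) * bytes_per_element


def embedded_size(arrays):
    """Hypothetical helper: total raw bytes of (shape, bytes_per_element) pairs
    embedded once as constants, and whether that single copy stays under the limit."""
    total = sum(array_bytes(shape, nbytes) for shape, nbytes in arrays)
    return total, total < GRAPHDEF_LIMIT_BYTES


# 60,000 float32 28x28 images (4 bytes each) plus 60,000 int64 labels (8 bytes each)
small = [((60_000, 28, 28), 4), ((60_000,), 8)]
# 1,000,000 float32 vectors of length 784 plus 1,000,000 int64 labels
large = [((1_000_000, 784), 4), ((1_000_000,), 8)]

print(embedded_size(small))  # (188640000, True)
print(embedded_size(large))  # (3144000000, False)
```

Even the smaller case puts about 189 MB of raw data into the graph as constants, which is the memory waste the guide describes; the larger case cannot be serialized at all under the assumed limit.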

It then describes an approach that defines placeholders and fills them in via feed_dict:

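import tensorflow as tf

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:
    # The data is supplied only when the iterator is initialized, not stored in the graph
    sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                              labels_placeholder: labels})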

But if you are using the Estimator API, you don't run the session yourself. So how do you use the Dataset API with an estimator while avoiding the problems associated with from_tensor_slices()?

Best answer

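To use an initializable or reinitializable iterator, you have to create a class that inherits from tf.train.SessionRunHook, which gets access to the session at several points during training and evaluation.

You can then use this new class to initialize the iterator just as you would in a classic session-based setup. You only need to pass the newly created hook to the train/evaluate functions, or to the appropriate train spec.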

Here is a simple example you can adapt to your needs:

import tensorflow as tf


class IteratorInitializerHook(tf.train.SessionRunHook):
    """Runs the iterator's initializer once the session has been created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None  # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Initialize the iterator with the data feed_dict
        self.iterator_initializer_func(session)


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        # Placeholders keep X and y out of the serialized graph
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        # dataset = ...  (further transformations, adapt as needed)
        # ...

        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()

        # Deferred: the hook calls this once the Estimator has created its session
        iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(
            iterator.initializer,
            feed_dict={X_pl: X, y_pl: y})

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook])  # Don't forget to pass the hook!
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])
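The key trick in the answer is ordering: the hook object exists before `input_fn` runs, `input_fn` fills in the hook's callback once the placeholders and iterator exist, and only after the session is created does `after_create_session` run that callback with the real data. The framework-free simulation below sketches just that ordering. `FakeSession`, `FakeEstimator`, and the string stand-ins for the placeholders and iterator are hypothetical and are not part of TensorFlow.

```python
class IteratorInitializerHook:
    """Same shape as the answer's hook, minus the TensorFlow base class."""

    def __init__(self):
        self.iterator_initializer_func = None  # set later, inside input_fn

    def after_create_session(self, session, coord):
        self.iterator_initializer_func(session)


class FakeSession:
    """Hypothetical stand-in that records what was run and with which feed."""

    def __init__(self):
        self.runs = []

    def run(self, fetch, feed_dict=None):
        self.runs.append((fetch, feed_dict))


class FakeEstimator:
    """Hypothetical stand-in for the order of operations in Estimator.train()."""

    def train(self, input_fn, hooks):
        input_fn()               # 1. graph is built; input_fn wires up the hook
        session = FakeSession()  # 2. a session is created
        for hook in hooks:       # 3. each hook is handed the new session
            hook.after_create_session(session, coord=None)
        return session


def get_inputs(X, y):
    hook = IteratorInitializerHook()

    def input_fn():
        X_pl, y_pl = "X_placeholder", "y_placeholder"  # stand-ins for tf.placeholder
        initializer = "iterator.initializer"           # stand-in for the iterator op
        hook.iterator_initializer_func = lambda sess: sess.run(
            initializer, feed_dict={X_pl: X, y_pl: y})
        return "next_example", "next_label"

    return input_fn, hook


train_input_fn, train_hook = get_inputs([[1.0], [2.0]], [0, 1])
session = FakeEstimator().train(train_input_fn, hooks=[train_hook])
print(session.runs)
# [('iterator.initializer', {'X_placeholder': [[1.0], [2.0]], 'y_placeholder': [0, 1]})]
```

Because the callback is a closure created inside `input_fn`, it captures the very placeholder and iterator objects that belong to the graph the Estimator builds, which is why the hook has to be wired up there rather than outside.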

For more on python - Avoiding tf.data.Dataset.from_tensor_slices with the Estimator API, see the similar question on Stack Overflow: https://stackoverflow.com/questions/52266000/
