gpt4 book ai didi

python - Tensorflow Dataset API 将图形 protobuff 文件大小加倍

转载 作者:太空宇宙 更新时间:2023-11-03 10:53:34 28 4
gpt4 key购买 nike

总结:使用新的 tf.contrib.data.Dataset 使我的图形 protobuff 文件的大小加倍,我无法在 Tensorboard 中可视化图形。

详情:

我正在试用新的 TensorFlow tf.contrib.data.Dataset功能与 tf.contrib.learn.Experiment 一起框架。我的输入数据定义为 input functions它返回特征和标签的张量。

如果我用 tf.train.slice_input_producer 创建我的输入函数功能类似于以下代码块(完整代码 here),那么我生成的 graph.pbtxt 文件大小为 620M,.meta 文件大小约为 165M。

def train_inputs():
with tf.name_scope('Training_data'):
x = tf.constant(mnist.train.images.reshape([-1, 28, 28, 1]))
y = tf.constant(mnist.train.labels)
sliced_input = tf.train.slice_input_producer(
tensor_list=[x, y], shuffle=True)
return tf.train.shuffle_batch(
sliced_input, batch_size=batch_size,
capacity=10000, min_after_dequeue=batch_size*10)

现在,如果我使用新的 tf.contrib.data.Dataset.from_tensor_slices 创建我的输入函数就像在下面的代码块(完整代码 here )中一样,然后我生成的 graph.pbtxt 文件的大小加倍到 1.3G,.meta 文件的大小加倍到 330M .

def train_inputs():
with tf.name_scope('Training_data'):
images = mnist.train.images.reshape([-1, 28, 28, 1])
labels = mnist.train.labels
dataset = tf.contrib.data.Dataset.from_tensor_slices(
(images, labels))
dataset = dataset.repeat(None) # Infinite
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
next_example, next_label = iterator.get_next()
return next_example, next_label

现在因为 graph.pbtxt 文件太大,TensorBoard 需要很长时间来解析这个文件,我无法直观地调试我的模型图。我在 Dataset documentation 中找到这种大小的增加来自:“数组的内容将被复制多次”solution将是使用占位符。但是,在这种情况下,我需要将 numpy 数组输入到具有事件 session 的占位符中以初始化迭代器:

sess.run(iterator.initializer, feed_dict={features_placeholder: features, labels_placeholder: labels})

然而,在使用 tf.contrib.learn.Experiment 框架时,这似乎超出了我的控制范围。

如何使用 Experiment 框架初始化迭代器的初始化程序?或者在不增加图形大小的情况下找到使用数据集 API 的解决方法?

最佳答案

我使用 tf.train.SessionRunHook 找到了解决我的问题的方法.我创建了一个 SessionRunHook 对象,它在创建 session 后初始化迭代器:

class IteratorInitializerHook(tf.train.SessionRunHook):
def __init__(self):
super(IteratorInitializerHook, self).__init__()
self.iterator_initiliser_func = None

def after_create_session(self, session, coord):
self.iterator_initiliser_func(session)

初始化函数在创建数据集迭代器时设置:

iterator_initiliser_hook.iterator_initiliser_func = \
lambda sess: sess.run(
iterator.initializer,
feed_dict={images_placeholder: images,
labels_placeholder: labels})

我将 Hook 对象传递给 tf.contrib.learn.Experimenttrain_monitorseval_hooks 参数。

生成的 graph.pbtxt 文件现在只有 500K,而 .meta 文件只有 244K。

Full example here.

关于python - Tensorflow Dataset API 将图形 protobuff 文件大小加倍,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45549251/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com