gpt4 book ai didi

python - TensorFlow 整个数据集存储在图中

转载 作者:行者123 更新时间:2023-12-01 09:31:46 25 4
gpt4 key购买 nike

我正在致力于使用 Cifar-10 数据集开发 CNN,并将数据提供给网络,我正在使用数据集 API 来使用带有句柄占位符的可输入迭代器:https://www.tensorflow.org/programmers_guide/datasets#creating_an_iterator 。就我个人而言,我真的很喜欢这种方法,因为它提供了一种清晰而简单的方法来将数据馈送到网络并在测试集和验证集之间切换。但是,当我在训练结束时保存图表时,创建的 .meta 文件与我开始时的测试数据一样大。我使用这些操作来提供对输入占位符和输出运算符的访问:

tf.get_collection("validation_nodes")
tf.add_to_collection("validation_nodes", input_data)
tf.add_to_collection("validation_nodes", input_labels)
tf.add_to_collection("validation_nodes", predict)

然后使用以下命令保存图表:训练前:

saver = tf.train.Saver()

训练后:

save_path = saver.save(sess, "./my_model")

有没有办法阻止 TensorFlow 存储图中的所有数据?提前致谢!

最佳答案

您正在为数据集创建一个tf.constant,这就是将其添加到图形定义中的原因。解决方案是使用可初始化迭代器并定义占位符。在开始对图表运行操作之前要做的第一件事就是向其提供数据集。有关示例,请参阅“创建迭代器”部分下的程序员指南。

https://www.tensorflow.org/programmers_guide/datasets#creating_an_iterator

我做的完全一样,所以这里是我用来准确实现您的描述的代码相关部分的复制/粘贴(使用可初始化迭代器训练/测试 cifar10 集):

  def build_datasets(self):
""" Creates a train_iterator and test_iterator from the two datasets. """
self.imgs_4d_uint8_placeholder = tf.placeholder(tf.uint8, [None, 32, 32, 3], 'load_images_placeholder')
self.imgs_4d_float32_placeholder = tf.placeholder(tf.float32, [None, 32, 32, 3], 'load_images_float32_placeholder')
self.labels_1d_uint8_placeholder = tf.placeholder(tf.uint8, [None], 'load_labels_placeholder')
self.load_data_train = tf.data.Dataset.from_tensor_slices({
'data': self.imgs_4d_uint8_placeholder,
'labels': self.labels_1d_uint8_placeholder
})
self.load_data_test = tf.data.Dataset.from_tensor_slices({
'data': self.imgs_4d_uint8_placeholder,
'labels': self.labels_1d_uint8_placeholder
})
self.load_data_adversarial = tf.data.Dataset.from_tensor_slices({
'data': self.imgs_4d_float32_placeholder,
'labels': self.labels_1d_uint8_placeholder
})

# Train dataset pipeline
dataset_train = self.load_data_train
dataset_train = dataset_train.shuffle(buffer_size=50000)
dataset_train = dataset_train.repeat()
dataset_train = dataset_train.map(self._img_augmentation, num_parallel_calls=8)
dataset_train = dataset_train.map(self._img_preprocessing, num_parallel_calls=8)
dataset_train = dataset_train.batch(self.hyperparams['batch_size'])
dataset_train = dataset_train.prefetch(2)
self.iterator_train = dataset_train.make_initializable_iterator()

# Test dataset pipeline
dataset_test = self.load_data_test
dataset_test = dataset_test.map(self._img_preprocessing, num_parallel_calls=8)
dataset_test = dataset_test.batch(self.hyperparams['batch_size'])
self.iterator_test = dataset_test.make_initializable_iterator()



def init(self, sess):
self.cifar10 = Cifar10() # a class I wrote for loading cifar10
self.handle_train = sess.run(self.iterator_train.string_handle())
self.handle_test = sess.run(self.iterator_test.string_handle())
sess.run(self.iterator_train.initializer, feed_dict={self.handle: self.handle_train,
self.imgs_4d_uint8_placeholder: self.cifar10.train_data,
self.labels_1d_uint8_placeholder: self.cifar10.train_labels})

关于python - TensorFlow 整个数据集存储在图中,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49903653/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com