gpt4 book ai didi

python - 在 tensorflow 中运行 Iterator.get_next() 后如何获取前一批?

转载 作者:行者123 更新时间:2023-11-28 18:17:24 25 4
gpt4 key购买 nike

最近想实现GAN模型,使用tf.Dataset和Iterator读取人脸图片作为训练数据。

数据集和迭代器对象的代码是:

self.dataset = tf.data.Dataset.from_tensor_slices(convert_to_tensor(self.data_ob.train_data_list, dtype=tf.string))
self.dataset = self.dataset.map(self._parse_function)
#self.dataset = self.dataset.shuffle(buffer_size=10000)
self.dataset = self.dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))

self.iterator = tf.data.Iterator.from_structure(self.dataset.output_types, self.dataset.output_shapes)
self.next_x = self.iterator.get_next()

我的新 GAN 模型是:

self.z_mean, self.z_sigm = self.Encode(self.next_x)
self.z_x = tf.add(self.z_mean, tf.sqrt(tf.exp(self.z_sigm))*self.ep)
self.x_tilde = self.generate(self.z_x, reuse=False)
#the feature
self.l_x_tilde, self.De_pro_tilde = self.discriminate(self.x_tilde)

#for Gan generator
self.x_p = self.generate(self.zp, reuse=True)
# the loss of dis network
self.l_x, self.D_pro_logits = self.discriminate(self.next_x, True)

所以,问题是我两次使用 self.next_x 作为输入张量。每次的数据集都是不同的。那么,如何解决这个问题才能保留首批处理以供重复使用呢?

最佳答案

我在代码中使用的是以下内容,其中 xy_true 是占位符。不确定是否有更有效的实现。

images, labels = session.run(next_element)
batch_accuracy = session.run(accuracy, feed_dict={x: images, y_true: labels, keep_prop: 1.0})
batch_predicted_probabilities = session.run(y_pred, feed_dict={x: images, y_true: labels, keep_prop: 1.0})

我目前正在尝试使用 tf.placeholder_with_default 而不是 x 和 y_true 的普通占位符来检查它是否能在我的项目中提供更好的性能。如果我很快就能得到任何结果,我会编辑我的答案让你知道:)。

编辑:我切换到 placeholder_with_default 并且它没有给每批处理带来明显的速度提升,至少在我测量它的方式上是这样。

关于python - 在 tensorflow 中运行 Iterator.get_next() 后如何获取前一批?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47485023/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com