gpt4 book ai didi

python - Tensorflow - 来自 tf.train.shuffle_batch 的下一批数据

转载 作者:太空宇宙 更新时间:2023-11-03 15:43:08 26 4
gpt4 key购买 nike

我有一个 tfrecords 文件,我希望从中创建批量数据。我正在使用 tf.train.shuffle_batch() 来创建单个批处理。在我的训练中,我想调用批处理并通过它们。这就是我被困住的地方。我读到,TFRecordReader() 的位置被保存在图形的状态中,并且从后续位置读取下一个示例。问题是我不知道如何加载下一批。我使用下面的代码来创建批处理。

def read_and_decode_single_example(filename):
filename_queue = tf.train.string_input_producer([filename], num_epochs=1)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'context': tf.FixedLenFeature([160], tf.int64),
'context_len': tf.FixedLenFeature([1], tf.int64),
'utterance': tf.FixedLenFeature([160], tf.int64),
'utterance_len': tf.FixedLenFeature([1], tf.int64),
'label': tf.FixedLenFeature([1], tf.int64)
})

contexts = features['context']
context_lens = features['context_len']
utterances = features['utterance']
utterance_lens = features['utterance_len']
labels = features['label']

return contexts, context_lens, utterances, utterance_lens, labels

contexts, context_lens, utterances, utterance_lens, labels = \
read_and_decode_single_example('data/train.tfrecords')

contexts_batch, context_lens_batch, \
utterances_batch, utterance_lens_batch, \
labels_batch = tf.train.shuffle_batch([contexts, context_lens, utterances,
utterance_lens, labels],
batch_size=batch_size,
capacity=3*batch_size,
min_after_dequeue=batch_size)

这给了我一批数据。我想使用 feed_dict 范例来传递训练批处理,其中每次迭代时都会传入一个新批处理。如何加载这些批处理?调用 read_and_decodetf.train.shuffle_batch 是否会再次调用下一个批处理?

最佳答案

read_and_decode_single_example() 函数为用于加载数据的网络创建一个(子)图;你只调用一次。它可能更合适地称为 build_read_and_decode_single_example_graph(),但这有点长。

“魔力”在于多次评估(即使用)_batch 张量,例如

batch_size = 100
# ...

with tf.Session() as sess:
# get the first batch of 100 values
first_batch = sess.run([contexts_batch, context_lens_batch,
utterances_batch, utterance_lens_batch,
labels_batch])

# second batch of different 100 values
second_batch = sess.run([contexts_batch, context_lens_batch,
utterances_batch, utterance_lens_batch,
labels_batch])
# etc.

当然,您可以将它们输入到网络的其他部分,而不是手动从 session 中获取这些值。机制是相同的:每当直接或间接获取这些张量之一时,批处理机制将负责每次为您提供一个新批处理(不同值)。

关于python - Tensorflow - 来自 tf.train.shuffle_batch 的下一批数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41978221/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com