gpt4 book ai didi

python - 如何使用 DataSet API 在 Tensorflow 中为 tf.train.SequenceExample 数据创建填充批处理?

转载 作者:太空狗 更新时间:2023-10-29 17:47:57 25 4
gpt4 key购买 nike

为了在 Tensorflow 中训练 LSTM 模型,我将我的数据结构化为 tf.train.SequenceExample 格式并将其存储到TFRecord 文件。我现在想使用新的 DataSet API 来生成用于训练的填充批处理。在 the documentation有一个使用 padded_batch 的示例,但对于我的数据,我无法弄清楚 padded_shapes 的值应该是多少。

为了将 TFrecord 文件读取到批处理中,我编写了以下 Python 代码:

import math
import tensorflow as tf
import numpy as np
import struct
import sys
import array

if(len(sys.argv) != 2):
print "Usage: createbatches.py [RFRecord file]"
sys.exit(0)


vectorSize = 40
inFile = sys.argv[1]

def parse_function_dataset(example_proto):
sequence_features = {
'inputs': tf.FixedLenSequenceFeature(shape=[vectorSize],
dtype=tf.float32),
'labels': tf.FixedLenSequenceFeature(shape=[],
dtype=tf.int64)}

_, sequence = tf.parse_single_sequence_example(example_proto, sequence_features=sequence_features)

length = tf.shape(sequence['inputs'])[0]
return sequence['inputs'], sequence['labels']

sess = tf.InteractiveSession()

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parse_function_dataset)
# dataset = dataset.batch(1)
dataset = dataset.padded_batch(4, padded_shapes=[None])
iterator = dataset.make_initializable_iterator()

batch = iterator.get_next()

# Initialize `iterator` with training data.
training_filenames = [inFile]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

print(sess.run(batch))

如果我使用 dataset = dataset.batch(1)(在这种情况下不需要填充),代码运行良好,但是当我使用 padded_batch 变体时,我得到以下错误:

TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: .

你能帮我弄清楚我应该为 padded_shapes 参数传递什么吗?

(我知道有很多使用线程和队列的示例代码,但我更愿意为这个项目使用新的 DataSet API)

最佳答案

您需要传递一个形状元组。在你的情况下你应该通过

dataset = dataset.padded_batch(4, padded_shapes=([vectorSize],[None]))

或者试试

dataset = dataset.padded_batch(4, padded_shapes=([None],[None]))

检查这个code更多细节。我不得不调试此方法以弄清楚为什么它对我不起作用。

关于python - 如何使用 DataSet API 在 Tensorflow 中为 tf.train.SequenceExample 数据创建填充批处理?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45955241/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com