gpt4 book ai didi

python - 使用 Tensorflow 进行文本输入

转载 作者:太空宇宙 更新时间:2023-11-03 14:04:33 25 4
gpt4 key购买 nike

我正在使用 Tensorflow 并尝试构建 RNN 语言模型。我正在努力解决如何读取原始文本输入文件的问题。

Tensorflow guide提到了一些方法,包括:

  1. tf.data.Dataset.from_tensor_slices() - 假设我的数据在内存中可用(np.array?)
  2. tf.data.TFRecordDataset(不知道如何使用它)
  3. tf.data.TextLineDataset(与2有什么区别?API页面几乎相同)

对2和3感到困惑,我只能尝试方法1,但面临以下问题:

  1. 如果我的数据太大而无法装入内存怎么办?
  2. TF 需要固定长度、填充格式,我该怎么做? -我是否:确定一个固定长度值(例如 30),将每一行读入列表,如果列表较长,则将列表截断为 30然后30,填充 '0' 以使每行至少 30 长,将列表附加到 numpy 数组/矩阵?

我确信这些都是常见问题,tensorflow 很多都提供了内置函数!

最佳答案

如果您的数据位于文本文件(csv、tsv 或只是行的集合)中,最好的方法是使用 tf.data.TextLineDataset 来处理它。 ; tf.data.TFRecordDataset有类似的 API,但它适用于 TFRecord 二进制格式(如果您需要一些详细信息,请查看 this nice post)。

通过数据集 API 处理文本行的一个很好的例子是 TensorFlow Wide & Deep Learning Tutorial (代码是 here )。这是此处使用的输入函数:

def input_fn(data_file, num_epochs, shuffle, batch_size):
"""Generate an input function for the Estimator."""
assert tf.gfile.Exists(data_file), (
'%s not found. Please make sure you have either run data_download.py or '
'set both arguments --train_data and --test_data.' % data_file)

def parse_csv(value):
print('Parsing', data_file)
columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
features = dict(zip(_CSV_COLUMNS, columns))
labels = features.pop('income_bracket')
return features, tf.equal(labels, '>50K')

# Extract lines from input files using the Dataset API.
dataset = tf.data.TextLineDataset(data_file)

if shuffle:
dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

dataset = dataset.map(parse_csv, num_parallel_calls=5)

# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)

iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels

以下是此代码片段中发生的事情:

  • tf.data.TextLineDataset(data_file) 行创建 Dataset对象,分配给数据集。它是一个包装器,而不是内容持有者,因此数据永远不会完全读入内存。

  • Dataset API 允许预处理数据,例如有shufflemapbatch等方法。请注意,API 是函数式的,这意味着当您调用 Dataset 方法时不会处理任何数据,它们只是定义当 session 实际启动并评估迭代器时将使用张量执行哪些转换(见下文)。

  • 最后,dataset.make_one_shot_iterator() 返回一个迭代器张量,可以从中读取值。您可以评估featureslabels,它们将在转换后获得数据批处理的值。

  • 另请注意,如果您在 GPU 上训练模型,数据将直接流式传输到设备,而无需在客户端(Python 脚本本身)中进行中间停止。

根据您的特定格式,您可能不需要解析 csv 列,只需逐行读取即可。

<小时/>

建议阅读:Importing Data指南。

关于python - 使用 Tensorflow 进行文本输入,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48998848/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com