gpt4 book ai didi

python - 从 TFRecordDataset 获取数据集作为 numpy 数组

转载 作者:太空狗 更新时间:2023-10-30 01:11:39 25 4
gpt4 key购买 nike

我正在使用新的 tf.data API 为 CIFAR10 数据集创建迭代器。我正在从两个 .tfrecord 文件中读取数据。一个保存训练数据 (train.tfrecords),另一个保存测试数据 (test.tfrecords)。这一切都很好。然而,在某些时候,我需要两个数据集(训练数据和测试数据)作为 numpy 数组

是否可以从 tf.data.TFRecordDataset 对象中检索数据集作为 numpy 数组?

最佳答案

您可以使用 tf.data.Dataset.batch()转型与tf.contrib.data.get_single_element()去做这个。作为回顾,dataset.batch(n)将占用 datasetn 个连续元素,并通过连接每个组件将它们转换为一个元素。这要求所有元素的每个组件都具有固定的形状。如果 n 大于 dataset 中的元素个数(或者如果 n 没有整除元素个数),那么最后一个批处理可以更小。因此,您可以为 n 选择一个较大的值并执行以下操作:

import numpy as np
import tensorflow as tf

# Insert your own code for building `dataset`. For example:
dataset = tf.data.TFRecordDataset(...) # A dataset of tf.string records.
dataset = dataset.map(...) # Extract components from each tf.string record.

# Choose a value of `max_elems` that is at least as large as the dataset.
max_elems = np.iinfo(np.int64).max
dataset = dataset.batch(max_elems)

# Extracts the single element of a dataset as one or more `tf.Tensor` objects.
# No iterator needed in this case!
whole_dataset_tensors = tf.contrib.data.get_single_element(dataset)

# Create a session and evaluate `whole_dataset_tensors` to get arrays.
with tf.Session() as sess:
whole_dataset_arrays = sess.run(whole_dataset_tensors)

关于python - 从 TFRecordDataset 获取数据集作为 numpy 数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48871438/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com