
tensorflow - How to read (decode) tfrecords with the tf.data API

Reposted. Author: 行者123. Updated: 2023-12-04 01:52:50

I have a custom dataset, which I then stored as a tfrecord by doing:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

# toy example data
label = np.asarray([[1, 2, 3],
                    [4, 5, 6]]).reshape(2, 3, -1)

sample = np.stack((label + 200).reshape(2, 3, -1))

def bytes_feature(values):
    """Returns a TF-Feature of bytes.
    Args:
      values: A string.
    Returns:
      A TF-Feature.
    """
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def labeled_image_to_tfexample(sample_binary_string, label_binary_string):
    return tf.train.Example(features=tf.train.Features(feature={
        'sample/image': bytes_feature(sample_binary_string),
        'sample/label': bytes_feature(label_binary_string)
    }))


def _write_to_tf_record():
    with tf.Graph().as_default():
        image_placeholder = tf.placeholder(dtype=tf.uint16)
        encoded_image = tf.image.encode_png(image_placeholder)

        label_placeholder = tf.placeholder(dtype=tf.uint16)
        # Encode the label from its own placeholder (not the image placeholder)
        encoded_label = tf.image.encode_png(label_placeholder)

        with tf.python_io.TFRecordWriter("./toy.tfrecord") as writer:
            with tf.Session() as sess:
                feed_dict = {image_placeholder: sample,
                             label_placeholder: label}

                # Encode image and label as binary strings to be written to tf_record
                image_string, label_string = sess.run(fetches=(encoded_image, encoded_label),
                                                      feed_dict=feed_dict)

                # Define structure of what is going to be written
                file_structure = labeled_image_to_tfexample(image_string, label_string)

                writer.write(file_structure.SerializeToString())
                return

But I cannot read it back. First I tried the following (based on http://www.machinelearninguru.com/deep_learning/tensorflow/basics/tfrecord/tfrecord.html, https://medium.com/coinmonks/storage-efficient-tfrecord-for-images-6dc322b81db4 and https://medium.com/mostly-ai/tensorflow-records-what-they-are-and-how-to-use-them-c46bc4bbb564):
def read_tfrecord_low_level():
    data_path = "./toy.tfrecord"
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
    reader = tf.TFRecordReader()
    _, raw_records = reader.read(filename_queue)

    decode_protocol = {
        'sample/image': tf.FixedLenFeature((), tf.int64),
        'sample/label': tf.FixedLenFeature((), tf.int64)
    }
    enc_example = tf.parse_single_example(raw_records, features=decode_protocol)
    recovered_image = enc_example["sample/image"]
    recovered_label = enc_example["sample/label"]

    return recovered_image, recovered_label

I also tried variations that cast enc_example and decode it, such as in "Unable to read from Tensorflow tfrecord file". However, when I try to evaluate them, my Python session just freezes and gives no output or traceback.

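Then I tried using eager execution to see what is happening, but apparently it is only compatible with the tf.data API. However, as far as I understand, transformations in the tf.data API are applied to the whole dataset. https://www.tensorflow.org/api_guides/python/reading_data mentions that a decode function has to be written, but gives no example of how to do that. All the tutorials I have found are written for TFRecordReader (which does not work for me).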

Any help is much appreciated (pointing out what I am doing wrong, explaining what is happening, or showing how to decode tfrecords with the tf.data API).

According to https://www.youtube.com/watch?v=4oNdaQk0Qv4 and https://www.youtube.com/watch?v=uIcqeP7MFH0, tf.data is the best way to create input pipelines, so I am very interested in learning to do it that way.

Thanks in advance!

Best answer

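I am not sure why storing the encoded png makes the evaluation fail, but here is a possible way of working around the problem. Since you mentioned that you would like to use the tf.data way of creating input pipelines, I will show how to use it with your toy example: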

label = np.asarray([[1, 2, 3],
                    [4, 5, 6]]).reshape(2, 3, -1)

sample = np.stack((label + 200).reshape(2, 3, -1))

First, the data has to be saved to the TFRecord file. The difference from what you did is that the image is not encoded as png.
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

writer = tf.python_io.TFRecordWriter("toy.tfrecord")

# tobytes() replaces the deprecated ndarray.tostring(); both return the raw bytes
example = tf.train.Example(features=tf.train.Features(feature={
    'label_raw': _bytes_feature(tf.compat.as_bytes(label.tobytes())),
    'sample_raw': _bytes_feature(tf.compat.as_bytes(sample.tobytes()))}))

writer.write(example.SerializeToString())

writer.close()

What happens in the code above is that the arrays are converted to byte strings (1-D objects) and then stored as bytes features.
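
To make the byte-string round trip concrete, here is a minimal NumPy-only sketch (no TensorFlow involved). It assumes the array is cast explicitly to int64, because NumPy's default integer type depends on the platform; np.frombuffer plays the role here that tf.decode_raw plays on the reading side.

```python
import numpy as np

# Toy label array with an explicit dtype (an assumption: the original code
# relies on NumPy's platform-dependent default integer type).
label = np.asarray([[1, 2, 3],
                    [4, 5, 6]], dtype=np.int64).reshape(2, 3, -1)

# Serialising drops the shape and dtype: only the raw bytes remain.
raw = label.tobytes()

# Decoding yields a flat 1-D array; the dtype must match the one used when writing.
flat = np.frombuffer(raw, dtype=np.int64)

# The original shape has to be restored by hand.
restored = flat.reshape(2, 3, -1)
```

Decoding the same bytes with a different dtype (for example np.int32) would yield 12 values instead of 6 without raising an error, which is why the dtype passed to tf.decode_raw has to match the dtype of the array that was written.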

Then, to read the data back, use the tf.data.TFRecordDataset and tf.data.Iterator classes:
filename = 'toy.tfrecord'

# Create a placeholder that will contain the name of the TFRecord file to use
data_path = tf.placeholder(dtype=tf.string, name="tfrecord_file")

# Create the dataset from the TFRecord file
dataset = tf.data.TFRecordDataset(data_path)

# Use the map function to read every sample from the TFRecord file (_read_from_tfrecord is shown below)
dataset = dataset.map(_read_from_tfrecord)

# Create an iterator object that enables you to access all the samples in the dataset
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
label_tf, sample_tf = iterator.get_next()

# Similarly to tf.Variables, the iterators have to be initialised
iterator_init = iterator.make_initializer(dataset, name="dataset_init")

with tf.Session() as sess:
    # Initialise the iterator passing the name of the TFRecord file to the placeholder
    sess.run(iterator_init, feed_dict={data_path: filename})

    # Obtain the images and labels back
    read_label, read_sample = sess.run([label_tf, sample_tf])

The function _read_from_tfrecord() is:
def _read_from_tfrecord(example_proto):
    feature = {
        'label_raw': tf.FixedLenFeature([], tf.string),
        'sample_raw': tf.FixedLenFeature([], tf.string)
    }

    features = tf.parse_example([example_proto], features=feature)

    # Since the arrays were stored as strings, they are now 1d
    # (tf.int64 must match the dtype of the arrays that were written)
    label_1d = tf.decode_raw(features['label_raw'], tf.int64)
    sample_1d = tf.decode_raw(features['sample_raw'], tf.int64)

    # In order to make the arrays in their original shape, they have to be reshaped.
    label_restored = tf.reshape(label_1d, tf.stack([2, 3, -1]))
    sample_restored = tf.reshape(sample_1d, tf.stack([2, 3, -1]))

    return label_restored, sample_restored

Instead of hard-coding the shape [2, 3, -1], you could also store it in the TFRecord file, but for simplicity I did not do that.
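
As a rough sketch of that alternative, the shape can be written as an extra int64 feature next to the raw bytes and used in the reshape when reading. The example below is an assumption-laden illustration rather than part of the original answer: it assumes TensorFlow 2.x with eager execution (hence the tf.io.* names instead of tf.python_io and tf.parse_example), and the feature name "shape", the helper _int64_list_feature and the temporary file path are hypothetical choices.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution enabled)

# Explicit dtype so that the bytes written match tf.int64 when decoding.
label = np.asarray([[1, 2, 3], [4, 5, 6]], dtype=np.int64).reshape(2, 3, -1)
sample = label + 200

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_list_feature(values):  # hypothetical helper for the shape
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

# Illustrative location; any writable path works.
path = os.path.join(tempfile.mkdtemp(), "toy_with_shape.tfrecord")

with tf.io.TFRecordWriter(path) as writer:
    example = tf.train.Example(features=tf.train.Features(feature={
        # Both arrays share one shape here, so a single feature is enough.
        "shape": _int64_list_feature(label.shape),
        "label_raw": _bytes_feature(label.tobytes()),
        "sample_raw": _bytes_feature(sample.tobytes()),
    }))
    writer.write(example.SerializeToString())

def _parse(example_proto):
    spec = {
        "shape": tf.io.FixedLenFeature([3], tf.int64),  # rank 3 is assumed
        "label_raw": tf.io.FixedLenFeature([], tf.string),
        "sample_raw": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example_proto, spec)
    shape = parsed["shape"]
    # decode_raw returns flat tensors; the stored shape restores the layout.
    label_restored = tf.reshape(tf.io.decode_raw(parsed["label_raw"], tf.int64), shape)
    sample_restored = tf.reshape(tf.io.decode_raw(parsed["sample_raw"], tf.int64), shape)
    return label_restored, sample_restored

dataset = tf.data.TFRecordDataset(path).map(_parse)

# In eager mode the dataset can be iterated directly, without a Session.
read_label, read_sample = next(iter(dataset))
read_label, read_sample = read_label.numpy(), read_sample.numpy()
```

The rank still has to be known up front in this sketch (FixedLenFeature([3], ...)); arrays of varying rank would need a variable-length feature instead.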

I made a small gist with a working example.

Hope this helps!

Regarding "tensorflow - How to read (decode) tfrecords with the tf.data API", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/52099863/
