gpt4 book ai didi

python - 如何将 tfrecord 数据读入张量/numpy 数组?

转载 作者:太空宇宙 更新时间:2023-11-04 07:56:34 26 4
gpt4 key购买 nike

我有一个 tfrecord 文件,我在其中存储了一个数据列表,每个元素都有 2d 坐标和 3d 坐标。坐标是 dtype float64 的 2d numpy 数组。

这些是我用来存储它们的特征。

feature = {'train/coord2d': _floats_feature(projC),
'train/coord3d': _floats_feature(sChair)}

我通过将它们展平成一个 float 列表来存储它们。

def _floats_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value.flatten()))

现在我正在尝试恢复它们,以便将它们输入到我的网络中进行训练。我希望 2d 坐标作为输入,3d 作为输出来训练我的网络。

def read_and_decode(filename):
filename_queue = tf.train.string_input_producer(filename, name='queue')
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,

features= {'train/coord2d': tf.FixedLenFeature([], tf.float32),
'train/coord3d': tf.FixedLenFeature([], tf.float32)})

coord2d = tf.cast(features['train/coord2d'], tf.float32)
coord3d = tf.cast(features['train/coord3d'], tf.float32)

return coord2d, coord3d



with tf.Session() as sess:
filename = ["train.tfrecords"]
dataset = tf.data.TFRecordDataset(filename)
c2d, c3d = read_and_decode(filename)
print(sess.run(c2d))
print(sess.run(c3d))

这是我的代码,但我并不真正理解它,因为我是从教程等那里得到的,所以我试图打印出 c2d 和 c3d 以查看它们的格式,但我的程序一直在运行,根本没有打印任何东西,从来没有终止。 c2d 和 c3d 是否包含数据集中每个元素的 2d 和 3d 坐标?在训练网络时可以直接使用它们作为输入和输出吗?

我也不知道它们应该是什么格式才能用作网络的输入。我应该将它们转换回 2d numpy 数组还是 2d 张量?在这种情况下我该怎么做?总的来说,我只是非常迷茫,所以任何指导都会非常有帮助!谢谢

最佳答案

您在 tf.data.TFRecordDataset(filename) 的正确行上,但问题是 dataset 未连接到您传递的张量到 sess.run()

这是一个应该产生一些输出的简单示例程序:

def decode(serialized_example):
# NOTE: You might get an error here, because it seems unlikely that the features
# called 'coord2d' and 'coord3d', and produced using `ndarray.flatten()`, will
# have a scalar shape. You might need to change the shape passed to
# `tf.FixedLenFeature()`.
features = tf.parse_single_example(
serialized_example,
features={'train/coord2d': tf.FixedLenFeature([], tf.float32),
'train/coord3d': tf.FixedLenFeature([], tf.float32)})

# NOTE: No need to cast these features, as they are already `tf.float32` values.
return features['train/coord2d'], features['train/coord3d']

filename = ["train.tfrecords"]
dataset = tf.data.TFRecordDataset(filename).map(decode)
iterator = dataset.make_one_shot_iterator()
c2d, c3d = iterator.get_next()

with tf.Session() as sess:

try:

while True:
print(sess.run((c2d, c3d)))

except tf.errors.OutOfRangeError:
# Raised when we reach the end of the file.
pass

关于python - 如何将 tfrecord 数据读入张量/numpy 数组?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47878207/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com