
python - Tensorflow tfrecords won't let me set the shape?

Reposted. Author: 行者123. Updated: 2023-11-30 09:09:01

Situation

I am trying to store image data in tfrecords.

Details

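The images have shape (256, 256, 4) and the labels have shape (17,). The tfrecords appear to be saved correctly (the height and width attributes can be decoded successfully).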

Problem

When I test extracting the images and labels from the tfrecords in a session, an error is raised. Something seems to be wrong with the label shape.

Error message

```
INFO:tensorflow:Error reported to Coordinator: <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>, Input to reshape is a tensor with 34 values, but the requested shape has 17
     [[Node: Reshape_4 = Reshape[T=DT_INT32, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](DecodeRaw_5, Reshape_4/shape)]]
```
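The numbers in this message are consistent with a label dtype that uses 2 bytes per element: 17 elements × 2 bytes = 34 bytes, and decoding raw bytes as an 8-bit type produces one value per byte. The sketch below reproduces that arithmetic with NumPy. It assumes the label array is int16 (the question does not state its dtype) and uses np.frombuffer as a stand-in for tf.decode_raw; it does not run TensorFlow itself.

```python
import numpy as np

# Hypothetical 17-element label stored as int16 (2 bytes per element).
label = np.zeros(17, dtype=np.int16)
raw = label.tobytes()  # the byte string that would be written into the tfrecord

# np.frombuffer plays the role of tf.decode_raw: one output value per itemsize bytes.
as_int8 = np.frombuffer(raw, dtype=np.int8)    # 1 byte per value
as_int16 = np.frombuffer(raw, dtype=np.int16)  # 2 bytes per value

print(len(raw))       # 34 bytes
print(as_int8.size)   # 34 values, so reshaping to [17] fails
print(as_int16.size)  # 17 values, so reshaping to [17] works

try:
    as_int8.reshape(17)
except ValueError as err:
    print("reshape failed:", err)
```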

Code

Note: I am fairly confident about the first part, because it was copied directly from an example in the TensorFlow documentation.

```python
import os

import tensorflow as tf  # TensorFlow 1.x API (tf.python_io, tf.train.*)
from skimage import io   # assumed: io.imread comes from scikit-image

# y (the label array) and TRAINING_IMAGES_DIR are defined elsewhere.

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# Converts a dataset to tfrecords.
# Open files
train_filename = os.path.join('./data/train.tfrecords')
validation_filename = os.path.join('./data/validation.tfrecords')

# Create writers
train_writer = tf.python_io.TFRecordWriter(train_filename)
# validation_writer = tf.python_io.TFRecordWriter(validation_filename)

for i in range(200):
    label = y[i]
    img = io.imread(TRAINING_IMAGES_DIR + '/train_' + str(i) + '.tif')

    example = tf.train.Example(features=tf.train.Features(feature={
        'width': _int64_feature([img.shape[0]]),
        'height': _int64_feature([img.shape[1]]),
        'channels': _int64_feature([img.shape[2]]),
        'label': _bytes_feature(label.tostring()),
        'image': _bytes_feature(img.tostring())
    }))

    # if i in validation_indices:
    #     validation_writer.write(example.SerializeToString())
    # else:
    train_writer.write(example.SerializeToString())

train_writer.close()
# validation_writer.close()
```
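One way to catch this kind of mismatch before anything is written is to check, on the NumPy side, that each array's raw bytes will decode to exactly the number of elements the reader expects. The helper check_raw_size below is hypothetical (it is not a TensorFlow function); it assumes the reader calls tf.decode_raw with decode_dtype and then reshapes to expected_shape, as the reading code in this question does. The same rule applies to the image, which is decoded as tf.float32; what dtype io.imread returns for these .tif files is not stated in the question, so uint16 below is only an assumption.

```python
import numpy as np

def check_raw_size(arr, expected_shape, decode_dtype):
    """Hypothetical helper: True if arr's raw bytes decode to exactly
    prod(expected_shape) elements of decode_dtype, i.e. if decoding the
    bytes as decode_dtype and reshaping to expected_shape would succeed."""
    n_bytes = arr.nbytes  # same as len(arr.tobytes())
    itemsize = np.dtype(decode_dtype).itemsize
    expected = int(np.prod(expected_shape))
    return n_bytes == expected * itemsize

# Label: assumed int16 on the writer side, read back as int8.
label = np.ones(17, dtype=np.int16)
print(check_raw_size(label, [17], np.int8))                   # False (34 bytes)
print(check_raw_size(label.astype(np.int8), [17], np.int8))   # True  (17 bytes)

# Image: assumed uint16 on the writer side, read back as float32.
img = np.zeros((256, 256, 4), dtype=np.uint16)
print(check_raw_size(img, [256, 256, 4], np.float32))                     # False
print(check_raw_size(img.astype(np.float32), [256, 256, 4], np.float32))  # True
```

Inside the writer loop, an assert on check_raw_size(label, [17], np.int8) and check_raw_size(img, [256, 256, 4], np.float32) before each write would make a dtype mismatch fail immediately, instead of surfacing later as a reshape error in the input queue.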

The part that raises the error. Note that, strangely, I still get the same error even if I change the reshape to [34].

```python
import tensorflow as tf  # TensorFlow 1.x API (queue runners, tf.Session)

data_path = './data/train.tfrecords'

with tf.Session() as sess:
    feature = {'image': tf.FixedLenFeature([], tf.string),
               'label': tf.FixedLenFeature([], tf.string)}

    # Create a list of filenames and pass it to a queue
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)

    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)

    # Convert the image data from string back to the numbers
    image = tf.decode_raw(features['image'], tf.float32)

    # Cast label data into int32
    label = tf.decode_raw(features['label'], tf.int8)

    # Reshape image data into the original shape
    image = tf.reshape(image, [256, 256, 4])
    label = tf.reshape(label, [17])

    # Any preprocessing here ...

    # Creates batches by randomly shuffling tensors
    images, labels = tf.train.shuffle_batch([image, label], batch_size=1, capacity=20,
                                            num_threads=1, min_after_dequeue=10)

    # Initialize all global and local variables
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)

    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    img, lbl = sess.run([images, labels])
    img  # displays the array when run in a notebook

    # Stop the threads
    coord.request_stop()

    # Wait for threads to stop
    coord.join(threads)

    sess.close()
```

Best answer

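This problem can occur when your labels were int16 before being saved as bytes in the tfrecords. Each int16 element takes 2 bytes, so 17 labels become 34 bytes, and when you decode them as tf.int8 you get twice as many numbers as you expect. You can make sure the labels are written correctly by converting them to 8-bit integers in the tfrecords conversion code, e.g. label = tf.cast(y[i], tf.int8). Note that tf.cast returns a TensorFlow tensor rather than a NumPy array, so its result has no .tostring() method; in the writer loop, where y[i] is a NumPy array, the equivalent conversion is label = y[i].astype(np.int8).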
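A minimal sketch of the writer-side fix with NumPy, assuming y has shape (200, 17), holds 0/1 tags, and is stored as int16 (the question does not state its dtype). np.frombuffer again stands in for tf.decode_raw(..., tf.int8) on the reading side.

```python
import numpy as np

# Hypothetical labels: 200 samples x 17 binary tags, stored as int16.
rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=(200, 17), dtype=np.int16)

label = y[0].astype(np.int8)  # one byte per element
raw = label.tobytes()         # tostring() is an older, deprecated alias of tobytes()

# Reading side: decoding as int8 now yields exactly 17 values.
decoded = np.frombuffer(raw, dtype=np.int8)
print(len(raw), decoded.shape)  # 17 (17,)
assert np.array_equal(decoded, y[0])
```

Narrowing to int8 is only safe while every label value fits in the range -128 to 127, which holds for 0/1 tags; the alternative is to keep the original dtype and decode with the matching type on the reading side.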

Regarding "python - Tensorflow tfrecords won't let me set the shape?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/45048081/
