
tensorflow - tf.decode_raw and tf.reshape give different image sizes

Reposted. Author: 行者123. Updated: 2023-12-04 01:40:10

I am using the following code to generate a tfrecords file.

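```python
import io
import os
import numpy as np
import PIL.Image as pil
import tensorflow as tf

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def generate_tfrecords(data_path, labels, name):
    """Converts a dataset to tfrecords."""
    # args.tfrecords_path comes from the script's command-line arguments
    filename = os.path.join(args.tfrecords_path, name + '.tfrecords')
    writer = tf.python_io.TFRecordWriter(filename)
    for index, data in enumerate(data_path):
        # Read the encoded (compressed) image file as bytes
        with tf.gfile.GFile(data, 'rb') as fid:
            encoded_jpg = fid.read()
        print(len(encoded_jpg))  # 17904
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        image = pil.open(encoded_jpg_io)
        image = np.asarray(image)
        print(image.shape)  # 112*112*3
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(int(image.shape[0])),
            'width': _int64_feature(int(image.shape[1])),
            'depth': _int64_feature(int(3)),
            'label': _int64_feature(int(labels[index])),
            # Stores the encoded file bytes, not the decoded pixel array
            'image_raw': _bytes_feature(encoded_jpg)}))
        writer.write(example.SerializeToString())
    writer.close()
```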

In the code above, the length of encoded_jpg is 17904, while the image shape is 112*112*3, and these two numbers are inconsistent.
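
To see why 17904 and 112*112*3 = 37632 need not agree, the sketch below (standard library only) hand-rolls a minimal, unfiltered 8-bit RGB PNG encoder with `zlib` and `struct` instead of using PIL, and encodes a hypothetical 112x112 gradient image. The helper names `png_chunk` and `encode_rgb_png` and the test image are illustrative assumptions; the point is only that the length of an encoded file depends on its compressed content and headers, not on height*width*depth.

```python
import struct
import zlib

def png_chunk(tag: bytes, data: bytes) -> bytes:
    # A PNG chunk is: 4-byte length, 4-byte type, data, CRC-32 over type + data
    crc = zlib.crc32(tag + data) & 0xFFFFFFFF
    return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", crc)

def encode_rgb_png(pixels: bytes, width: int, height: int) -> bytes:
    """Minimal 8-bit RGB PNG encoder without row filtering (illustration only)."""
    row_len = width * 3
    assert len(pixels) == row_len * height
    # Prefix each scanline with filter type 0 ("None"), then deflate everything
    raw = b"".join(b"\x00" + pixels[y * row_len:(y + 1) * row_len] for y in range(height))
    # IHDR: width, height, bit depth 8, color type 2 (RGB), compression, filter, interlace
    ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
    return (b"\x89PNG\r\n\x1a\n"
            + png_chunk(b"IHDR", ihdr)
            + png_chunk(b"IDAT", zlib.compress(raw))
            + png_chunk(b"IEND", b""))

height, width, depth = 112, 112, 3
# Hypothetical test image: a smooth gradient, which compresses well
pixels = bytes((x + y) % 256 for y in range(height) for x in range(width) for _ in range(depth))
png_bytes = encode_rgb_png(pixels, width, height)

print(len(pixels))     # 37632 decoded uint8 values = 112*112*3
print(len(png_bytes))  # much smaller, because the file stores compressed data
```

A production encoder may use row filters and different compression settings, so its byte count will differ from this sketch; only the decoded pixel count is fixed at height*width*depth.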

I parse the tfrecords with the following code:

```python
import tensorflow as tf

def _parse_function(example_proto):
    features = {'height': tf.FixedLenFeature((), tf.int64, default_value=0),
                'width': tf.FixedLenFeature((), tf.int64, default_value=0),
                'depth': tf.FixedLenFeature((), tf.int64, default_value=0),
                'label': tf.FixedLenFeature((), tf.int64, default_value=0),
                'image_raw': tf.FixedLenFeature((), tf.string, default_value="")}
    parsed_features = tf.parse_single_example(example_proto, features)
    height = tf.cast(parsed_features["height"], tf.int32)  # 112
    width = tf.cast(parsed_features["width"], tf.int32)  # 112
    depth = tf.cast(parsed_features["depth"], tf.int32)  # 3
    label = parsed_features['label']
    # decode_raw reinterprets the stored bytes as a flat uint8 vector
    img = tf.decode_raw(parsed_features['image_raw'], tf.uint8, little_endian=True)
    img = tf.reshape(img, [height, width, depth])
    return img, label
```

When I run this code, I get the following error:

```
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 17904 values, but the requested shape has 37632
[[Node: Reshape = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw, Reshape/shape)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,?,?,?], [?]], output_types=[DT_UINT8, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]
```

How can I fix this? The images are PNG files, and 37632 = 112*112*3. Thanks!

Best answer

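Use decode_jpeg instead of decode_raw. The image_raw feature stores the encoded (compressed) image file, 17904 bytes, rather than the 112*112*3 = 37632 decoded pixel values. decode_raw only reinterprets those bytes as uint8 values without decompressing them, so reshape receives 17904 values instead of the 37632 it needs.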
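
A minimal sketch of the fix, assuming TensorFlow 2.x with eager execution and the `tf.io` names (the question's `tf.parse_single_example` and `tf.FixedLenFeature` are the TF 1.x spellings). Because the question says the files are PNG, it swaps in `tf.io.decode_png` for the answer's `decode_jpeg`; `tf.io.decode_image`, which detects the format itself, is another option. The random test image, the label value 7 and the `_int64_feature` helper are hypothetical and only serve to round-trip one Example.

```python
import numpy as np
import tensorflow as tf

# Feature spec mirroring the question's writer; 'image_raw' holds encoded PNG bytes
FEATURES = {
    "height": tf.io.FixedLenFeature((), tf.int64, default_value=0),
    "width": tf.io.FixedLenFeature((), tf.int64, default_value=0),
    "depth": tf.io.FixedLenFeature((), tf.int64, default_value=0),
    "label": tf.io.FixedLenFeature((), tf.int64, default_value=0),
    "image_raw": tf.io.FixedLenFeature((), tf.string, default_value=""),
}

def _parse_function(example_proto):
    parsed = tf.io.parse_single_example(example_proto, FEATURES)
    height = tf.cast(parsed["height"], tf.int32)
    width = tf.cast(parsed["width"], tf.int32)
    depth = tf.cast(parsed["depth"], tf.int32)
    # Decompress the PNG file instead of reinterpreting its bytes with decode_raw
    img = tf.io.decode_png(parsed["image_raw"], channels=3)
    # The decoded tensor now holds height*width*depth values, so reshape succeeds
    img = tf.reshape(img, [height, width, depth])
    return img, parsed["label"]

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Usage: round-trip a hypothetical 112x112 RGB image through a serialized Example
pixels = np.random.randint(0, 256, size=(112, 112, 3), dtype=np.uint8)
png_bytes = tf.io.encode_png(pixels).numpy()
example = tf.train.Example(features=tf.train.Features(feature={
    "height": _int64_feature(112),
    "width": _int64_feature(112),
    "depth": _int64_feature(3),
    "label": _int64_feature(7),
    "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[png_bytes])),
}))
img, label = _parse_function(example.SerializeToString())
print(img.shape.as_list(), len(png_bytes))
```

In an input pipeline the same function would typically be applied with `tf.data.TFRecordDataset(filename).map(_parse_function)`. Since PNG is lossless, the decoded tensor matches the original pixel array exactly.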

A similar question about "tensorflow - tf.decode_raw and tf.reshape give different image sizes" can be found on Stack Overflow: https://stackoverflow.com/questions/47502981/
