gpt4 book ai didi

python - Tensorflow:具有任意维度张量的批量 TFRecord 数据集

转载 作者:行者123 更新时间:2023-12-04 15:53:20 32 4
gpt4 key购买 nike

如何使用 TFRecordsDataset 对任意形状的张量进行批处理?

我目前正在研究对象检测网络的输入管道,并且正在为标签的批处理而苦苦挣扎。标签由边界框坐标和图像中对象的类别组成。由于图像中可能有多个对象,因此标签尺寸是任意的


使用 tf.train.batch 时,可以设置 dynamic_padding=True 使形状适合相同的尺寸。但是 data.TFRecordDataset.batch() 中没有这样的选项。

我想要批处理的所需形状是 [batch_size, arbitrary , 4] 用于我的 Boxes 和 [batch_size, arbitrary, 1] 用于类。

def decode(serialized_example):
"""
Decodes the information of the TFRecords to image, label_coord, label_classes
Later on will also contain the Image Sequence!

:param serialized_example: Serialized Example read from the TFRecords
:return: image, label_coordinates list, label_classes list
"""
features = {'image/shape': tf.FixedLenFeature([], tf.string),
'train/image': tf.FixedLenFeature([], tf.string),
'label/coordinates': tf.VarLenFeature(tf.float32),
'label/classes': tf.VarLenFeature(tf.string)}

features = tf.parse_single_example(serialized_example, features=features)

image_shape = tf.decode_raw(features['image/shape'], tf.int64)
image = tf.decode_raw(features['train/image'], tf.float32)
image = tf.reshape(image, image_shape)

# Contains the Bounding Box coordinates in a flattened tensor
label_coord = features['label/coordinates']
label_coord = label_coord.values
label_coord = tf.reshape(label_coord, [1, -1, 4])

# Contains the Classes of the BBox in a flattened Tensor
label_classes = features['label/classes']
label_classes = label_classes.values
label_classes = tf.reshape(label_classes, [1, -1, 1])


return image, label_coord, label_classes

    dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(decode)
dataset = dataset.map(augment)
dataset = dataset.map(normalize)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)

dataset = dataset.batch(batch_size)

抛出的错误是无法在分量 1 中批量处理具有不同形状的张量。第一个元素的形状为 [1,1,4],元素 1 的形状为 [1,7,4]。

此外,目前 augmentnormalize 函数只是占位符。

最佳答案

事实证明 tf.data.TFRecordDataset 有一个叫做 padded_batch 的函数,它基本上是做 tf.train.batch(dynamic_pad=True) 确实如此。这很容易解决问题......

dataset = tf.data.TFRecordDataset(filename)

dataset = dataset.map(decode)
dataset = dataset.map(augment)
dataset = dataset.map(normalize)

dataset = dataset.shuffle(1000+3*batch_size)
dataset = dataset.repeat(num_epochs)
dataset = dataset.padded_batch(batch_size,
drop_remainder=False,
padded_shapes=([None, None, None],
[None, 4],
[None, 1])
)

关于python - Tensorflow:具有任意维度张量的批量 TFRecord 数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52965004/

32 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com