gpt4 book ai didi

python - InvalidArgumentError : Key: label. 无法解析序列化示例 : How can I find a way to parse the one-hot encoded labels from TFRecords?

转载 作者:行者123 更新时间:2023-12-05 06:54:05 26 4
gpt4 key购买 nike

我有 12 个包含图像的文件夹(它们是我的数据类别)。此代码将图像及其相应标签转换为 tfrecord 数据并有效压缩:

import tensorflow as tf
from pathlib import Path
from tensorflow.keras.utils import to_categorical
import cv2
from tqdm import tqdm
from os import listdir
import numpy as np
import matplotlib.image as mpimg
from tqdm import tqdm

labels = {v:k for k, v in enumerate(listdir('train/'))}
labels

class GenerateTFRecord:
def __init__(self, path):
self.path = Path(path)
self.labels = {v:k for k, v in enumerate(listdir(path))}

def convert_image_folder(self, tfrecord_file_name):
# Get all file names of images present in folder
img_paths = list(self.path.rglob('*.jpg'))

with tf.io.TFRecordWriter(tfrecord_file_name) as writer:
for img_path in tqdm(img_paths, desc='images converted'):
example = self._convert_image(img_path)
writer.write(example.SerializeToString())

def _convert_image(self, img_path):
label = self.labels[img_path.parent.stem]
img_shape = mpimg.imread(img_path).shape

# Read image data in terms of bytes
with tf.io.gfile.GFile(img_path, 'rb') as fid:
image_data = fid.read()

example = tf.train.Example(features = tf.train.Features(feature = {
'rows': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[0]])),
'cols': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[1]])),
'channels': tf.train.Feature(int64_list = tf.train.Int64List(value = [3])),
'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_data])),
'label': tf.train.Feature(int64_list = tf.train.Int64List(value = tf.one_hot(label, depth=len(labels), on_value=1, off_value=0))),
}))
return example

t = GenerateTFRecord(path='train/')
t.convert_image_folder('data.tfrecord')

然后我在这里使用这段代码读取 tfrecord 数据并创建我的 tf.data.Dataset:

def _parse_function(tfrecord):
# Extract features using the keys set during creation
features = {
'rows': tf.io.FixedLenFeature([], tf.int64),
'cols': tf.io.FixedLenFeature([], tf.int64),
'channels': tf.io.FixedLenFeature([], tf.int64),
'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64)
}

# Extract the data record
sample = tf.io.parse_single_example(tfrecord, features)

image = tf.image.decode_image(sample['image'])
label = sample['label']
# label = tf.one_hot(label, depth=len(labels), on_value=1, off_value=0)
return image, label

def configure_for_performance(ds, buffer_size, batch_size):
ds = ds.cache()
ds = ds.batch(batch_size)
ds = ds.prefetch(buffer_size=buffer_size)
return ds


def generator(tfrecord_file, batch_size, n_data, validation_ratio, reshuffle_each_iteration=False):
reader = tf.data.TFRecordDataset(filenames=[tfrecord_file])
reader.shuffle(n_data, reshuffle_each_iteration=reshuffle_each_iteration)
AUTOTUNE = tf.data.experimental.AUTOTUNE

val_size = int(n_data * validation_ratio)
train_ds = reader.skip(val_size)
val_ds = reader.take(val_size)

train_ds = train_ds.map(_parse_function, num_parallel_calls=AUTOTUNE)
train_ds = configure_for_performance(train_ds, AUTOTUNE, batch_size)

val_ds = val_ds.map(_parse_function, num_parallel_calls=AUTOTUNE)
val_ds = configure_for_performance(val_ds, AUTOTUNE, batch_size)
return train_ds, val_ds

在这里我创建了我的模型:

from os.path import isdir, dirname, abspath, join
from os import makedirs

from tensorflow.keras import Sequential
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import SGD, Adam


def create_model(optimizer, freeze_layer=False):
densenet = DenseNet121(weights='imagenet',
include_top=False)

if freeze_layer:
for layer in densenet_model.layers:
if 'conv5' in layer.name:
layer.trainable = True
else:
layer.trainable = False

model = Sequential()
model.add(densenet)
model.add(GlobalAveragePooling2D())
model.add(Dense(12, activation='softmax'))

model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])

return model

if __name__ == '__main__':
optimizer = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.99, epsilon=1e-6)
densenet_model = create_model(optimizer)

tfrecord_file = 'data.tfrecord'
n_data = len(list(Path('train').rglob('*.jpg')))
train, val = generator(tfrecord_file, 2, n_data, validation_ratio, True)

validation_ratio = 0.2
val_size = int(n_data * validation_ratio)
train_size = n_data - val_size
batch_size = 32
n_epochs = 300
n_workers = 5

filename = '/content/drive/MyDrive/data.tfrecord'


train_ds, val_ds = generator(filename,
batch_size=batch_size,
n_data=n_data,
validation_ratio=validation_ratio,
reshuffle_each_iteration=True)


hist = densenet_model.fit(train_ds,
validation_data=val_ds,
epochs=n_epochs,
workers=n_workers,
steps_per_epoch=train_size//batch_size,
validation_steps=val_size)

这是我每次得到的错误:

InvalidArgumentError:键:标签。无法解析序列化示例。 [[{{node ParseSingleExample/ParseExample/ParseExampleV2}}]] [[IteratorGetNext]] [Op:__inference_train_function_343514]

显然我的 tfrecord 数据中的 label 有问题。

我真的需要知道,根据我的模型输出形状 (12,),我如何才能安全地将一个热编码标签存储在我的 tfrecord 中并在 tf.data.Dataset 中进行解析?

谢谢大家

最佳答案

如答案所示here数据数组应该是固定大小的,所以我认为它可以解决你的问题。

关于python - InvalidArgumentError : Key: label. 无法解析序列化示例 : How can I find a way to parse the one-hot encoded labels from TFRecords?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65620220/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com