gpt4 book ai didi

python - 如何使用 TensorFlow Lite 的 object_detector.EfficientDetLite4Spec 从检查点继续训练

转载 作者:行者123 更新时间:2023-12-02 16:05:47 27 4
gpt4 key购买 nike

具体来说,我在 config.yaml 中为 EfficientDetLite4 模型设置了 "grad_checkpoint=true",训练也成功生成了一些检查点。但是,当我想基于这些检查点继续训练时,却不知道该如何使用它们。

每次我训练模型时,它都是从头开始,而不是从我的检查点开始。

下图是我的colab文件系统结构:

my colab file system structure

下图显示了我的检查点存储位置:

model file system here

以下代码展示了我如何配置模型以及如何使用模型进行训练。

import numpy as np
import os

from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

# Load the train / validation / test splits from a single CSV file.
# (The original snippet broke this statement after `=`, which is a SyntaxError.)
train_data, validation_data, test_data = object_detector.DataLoader.from_csv(
    'csv_path')

# NOTE(review): `spec_strategy` was never defined in the original snippet
# (NameError). None is what the answer below passes; set it to e.g. 'gpus'
# if a distribution strategy is actually wanted — confirm against your setup.
spec_strategy = None

# `model_dir` is where checkpoints end up being written, but
# `object_detector.create` does not restore from it, so each run below
# starts training from scratch.
spec = object_detector.EfficientDetLite4Spec(
    uri='/content/model',
    model_dir='/content/drive/MyDrive/MathSymbolRecognition/CheckPoints/',
    hparams='grad_checkpoint=true,strategy=gpus',
    epochs=50,
    batch_size=3,
    steps_per_execution=1,
    moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25,
    strategy=spec_strategy,
)

model = object_detector.create(
    train_data,
    model_spec=spec,
    batch_size=3,
    train_whole_model=True,
    validation_data=validation_data,
)

最佳答案

源码就是答案!

我遇到了同样的问题,并发现:我们传递给 TFLite Model Maker 中 object_detector API 的 model_dir 仅用于保存模型的权重,这就是该 API 从不从检查点恢复训练的原因。

查看此 API 的源代码,我注意到它在内部使用标准的 model.compile 和 model.fit 函数,并通过 model.fit 的 callbacks 参数保存模型的权重。
这意味着,只要我们能拿到内部的 Keras 模型,就可以使用 model.load_weights 恢复检查点!

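如果您想更深入地了解下面用到的某些函数(如 _get_dataset_and_steps、create_model、train.setup_model、train_lib.get_callbacks)的作用,可以查阅 TFLite Model Maker 的 object_detector 源码以及 third_party/efficientdet/keras 下的 train.py 和 train_lib.py(原文中的源码链接在转载时已丢失)。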

这是代码:

# Useful imports
import os

import tensorflow as tf
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
from tflite_model_maker.object_detector import DataLoader

# Import the same libs that TFLite Model Maker internally uses
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib


# Setup variables
batch_size = 6  # or whatever batch size you want
epochs = 50
checkpoint_dir = "/content/..."  # whatever your checkpoint directory is


# Create whichever object detector's spec you want.
# `model_dir` is only used to save checkpoints; the API never restores from it,
# which is why the weights are restored manually further down.
spec = object_detector.EfficientDetLite4Spec(
    model_name='efficientdet-lite4',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite4/feature-vector/2',
    hparams='',  # enable grad_checkpoint=True if you want
    model_dir=checkpoint_dir,
    epochs=epochs,
    batch_size=batch_size,
    steps_per_execution=1,
    moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25,
    strategy=None,
    tpu=None,
    gcp_project=None,
    tpu_zone=None,
    use_xla=False,
    profile=False,
    debug=False,
    tf_random_seed=111111,
    verbose=1,
)


# Load your datasets
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('/path/to/csv.csv')


# Create the object detector without training it (do_train=False): we only
# need the wrapper so we can drive training of its internal Keras model ourselves.
detector = object_detector.create(
    train_data,
    model_spec=spec,
    batch_size=batch_size,
    train_whole_model=True,
    validation_data=validation_data,
    epochs=epochs,
    do_train=False,
)


# From here on we use internal/"private" functions of the API,
# you can tell because the methods' names begin with an underscore.

# Convert the datasets for training
train_ds, steps_per_epoch, _ = detector._get_dataset_and_steps(
    train_data, batch_size, is_training=True)
validation_ds, validation_steps, val_json_file = detector._get_dataset_and_steps(
    validation_data, batch_size, is_training=False)


# Get the internal keras model
model = detector.create_model()


# Copy what the API internally does as setup
config = spec.config
config.update(
    dict(
        steps_per_epoch=steps_per_epoch,
        eval_samples=batch_size * validation_steps,
        val_json_file=val_json_file,
        batch_size=batch_size,
    )
)
train.setup_model(model, config)  # This is the model.compile call basically
model.summary()


# Restore the weights.
# In my case:
# checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/"
# specific_checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/ckpt-35"

# Option A: load the weights from the last successfully completed epoch.
# tf.train.latest_checkpoint returns None when the directory has no checkpoint.
latest = tf.train.latest_checkpoint(checkpoint_dir)

# Option B: load the weights from a specific checkpoint instead:
# latest = specific_checkpoint_dir

# Epoch to resume from. It stays 0 (train from scratch) unless a checkpoint is
# actually restored, so model.fit below never sees an undefined name and never
# "resumes" at epoch N with freshly initialised weights.
completed_epochs = 0
if latest is None:
    print("Checkpoint not found in {}; training from scratch".format(checkpoint_dir))
else:
    try:
        # Checkpoints are named '<prefix>-<epoch>' (e.g. 'ckpt-35'): the suffix
        # after the last '-' is the epoch training had reached when interrupted.
        resume_epoch = int(os.path.basename(latest).rsplit("-", 1)[-1])
        model.load_weights(latest)
    except (ValueError, OSError, tf.errors.OpError) as e:
        print("Could not restore checkpoint {}; training from scratch: {}".format(latest, e))
    else:
        completed_epochs = resume_epoch
        print("Checkpoint found {}".format(latest))


# Retrieve the needed default callbacks (they save a checkpoint every epoch).
all_callbacks = train_lib.get_callbacks(config.as_dict(), validation_ds)


# Optional step.
# Add callbacks that get executed at the end of every N epochs: in this case
# I want to log the training results to tensorboard.
# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, histogram_freq=1)
# all_callbacks.append(tensorboard_callback)


# Train the model, continuing from `completed_epochs`.
model.fit(
    train_ds,
    epochs=epochs,
    initial_epoch=completed_epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_ds,
    validation_steps=validation_steps,
    callbacks=all_callbacks,  # saves checkpoints every epoch + runs any callbacks added above
)


# Save/export the trained model.
# Tip: for integer quantization you simply have to NOT SPECIFY
# the quantization_config parameter of the detector.export method.
# In this case it would be:
# detector.export(export_dir=export_dir, tflite_filename='model.tflite')
export_dir = "/content/..."  # save the tflite wherever you want
quant_config = QuantizationConfig.for_float16()  # or whatever quantization you want
detector.model = model  # inject our trained model into the object detector
detector.export(export_dir=export_dir, tflite_filename='model.tflite', quantization_config=quant_config)

关于python - 如何使用 object_detector.EfficientDetLite4Spec tensorflow lite 继续训练检查点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69444878/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com