
object-detection-api - Error "indices[0] = 0 not in [0,0)" when training CenterNet MobileNetV2 FPN 512x512 with the TensorFlow Object Detection API

Reposted. Author: 行者123. Updated: 2023-12-04 08:01:10

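I am running the TensorFlow Object Detection API following this guide: https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/training.html#configuring-a-training-job. However, I slightly modified the code that creates the record files, and I am using the following system: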

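System information:

  • OS platform and distribution: Ubuntu 20.04.1 LTS
  • Python version: (not given; the traceback paths show python3.8)
  • TensorFlow version: 2.4.0
  • CUDA/cuDNN version: 11.0/8.0.5
  • GPU model and memory: GeForce RTX 3090, 24268 MiB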
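
Since the Python version was left blank above, here is a minimal sketch of how such a report could be collected with the standard library only. The distribution names `tensorflow` and `object_detection` are assumptions about how the packages were installed; adjust them to your environment.

```python
import platform
from importlib import metadata


def collect_versions(packages=("tensorflow", "object_detection")):
    """Return OS and Python info plus the installed version of each package.

    The package names are assumptions; a package that is not installed
    is reported as "not installed" instead of raising.
    """
    info = {
        "os": platform.platform(),
        "python": platform.python_version(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print(f"{key}: {value}")
```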

I want to use the CenterNet MobileNetV2 FPN 512x512 model for boxes (not the keypoints variant) from the TensorFlow 2 Detection Model Zoo (https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md).

I set up the training job according to the guide and then ran:

python model_main_tf2.py --model_dir=models/my_centernet_mn_fpn --pipeline_config_path=models/my_centernet_mn_fpn/pipeline.config

Doing so produces the following error:

Traceback (most recent call last):
  File "model_main_tf2.py", line 115, in <module>
    tf.compat.v1.app.run()
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "model_main_tf2.py", line 106, in main
    model_lib_v2.train_loop(
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/object_detection/model_lib_v2.py", line 636, in train_loop
    loss = _dist_train_step(train_input_iter)
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 888, in _call
    return self._stateless_fn(*args, **kwds)
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2942, in __call__
    return graph_function._call_flat(
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
    outputs = execute.execute(
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: indices[0] = 0 is not in [0, 0)
     [[{{node GatherV2_7}}]]
     [[MultiDeviceIteratorGetNextFromShard]]
     [[RemoteCall]]
     [[IteratorGetNext]]
  (1) Invalid argument: indices[0] = 0 is not in [0, 0)
     [[{{node GatherV2_7}}]]
     [[MultiDeviceIteratorGetNextFromShard]]
     [[RemoteCall]]
     [[IteratorGetNext]]
     [[ToAbsoluteCoordinates_42/Assert/AssertGuard/branch_executed/_386/_1231]]
0 successful operations.
0 derived errors ignored. [Op:__inference__dist_train_step_54439]

Function call stack:
_dist_train_step -> _dist_train_step

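Googling this error turns up answers saying the problem lies in how the TFRecord files were created and that include_masks needs to be added when creating them. However, I do not get this error when running other CenterNet models from the model zoo, so it seems odd that this would be the cause.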
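
One way to test the TFRecord hypothesis is to look at which feature keys the record files actually mention. The sketch below is a rough check using only the standard library. It assumes the standard TFRecord framing (8-byte length, 4-byte CRC, payload, 4-byte CRC) and the feature key names `image/object/bbox/xmin`, `image/object/mask`, and `image/object/keypoint/x`, which are believed to be the ones the Object Detection API's example decoder reads; verify them against your own record-creation script. It only searches the raw bytes for the key strings, so a key that is present with an empty value list is still counted.

```python
import struct
import sys


def iter_tfrecord_payloads(path):
    """Yield the raw payload bytes of every record in a TFRecord file.

    Assumed framing per record: uint64 length, uint32 CRC of the length,
    payload, uint32 CRC of the payload. The CRCs are skipped, not verified.
    """
    with open(path, "rb") as f:
        while True:
            header = f.read(12)  # 8-byte length + 4-byte length CRC
            if len(header) < 12:
                return
            (length,) = struct.unpack("<Q", header[:8])
            payload = f.read(length)
            f.read(4)  # skip the payload CRC
            yield payload


def count_records_with_keys(path, keys):
    """Return (total records, {key: number of records mentioning the key})."""
    counts = {key: 0 for key in keys}
    total = 0
    for payload in iter_tfrecord_payloads(path):
        total += 1
        for key in keys:
            if key.encode("utf-8") in payload:
                counts[key] += 1
    return total, counts


if __name__ == "__main__":
    # Assumed key names; pass record files as command-line arguments.
    wanted = ["image/object/bbox/xmin", "image/object/mask",
              "image/object/keypoint/x"]
    for record_path in sys.argv[1:]:
        total, counts = count_records_with_keys(record_path, wanted)
        print(record_path, "records:", total, counts)
```

If no record mentions the keypoint key while the pipeline config still asks for keypoints, the mismatch is in the config rather than in the masks.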

Any ideas?

Best answer

Thanks to Alexandra. The fix is to remove the keypoint-related settings from the config file.
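
A plausible reading of the error message, offered as an assumption rather than a confirmed diagnosis: the sample config declares 17 keypoints and a flip permutation over them, but a boxes-only dataset contains no keypoints, so a gather over the keypoint axis has nothing to index. The pure-Python `gather` below is a hypothetical stand-in for a gather op, not TensorFlow's implementation; it only illustrates how the range "[0, 0)" arises from an empty input.

```python
def gather(params, indices):
    """Pick params[i] for each i in indices, with an explicit bounds check."""
    size = len(params)
    out = []
    for pos, i in enumerate(indices):
        if not 0 <= i < size:
            # Same wording as the TensorFlow error, for comparison only.
            raise IndexError(f"indices[{pos}] = {i} is not in [0, {size})")
        out.append(params[i])
    return out


# The flip permutation listed in the sample config (17 keypoints).
flip = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]

# With 17 keypoints present, the permutation swaps left/right pairs.
keypoints = [f"kp{i}" for i in range(17)]
swapped = gather(keypoints, flip)

# With no keypoints in the data, the very first index is out of range.
try:
    gather([], flip)
    message = None
except IndexError as err:
    message = str(err)

print(message)
```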

Original config (delete the parts marked "// delete"):

model {
  center_net {
    num_classes: 90
    feature_extractor {
      type: "mobilenet_v2_fpn_sep_conv"
    }
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 512
        max_dimension: 512
        pad_to_max_dimension: true
      }
    }
    use_depthwise: true
    object_detection_task {
      task_loss_weight: 1.0
      offset_loss_weight: 1.0
      scale_loss_weight: 0.1
      localization_loss {
        l1_localization_loss {
        }
      }
    }
    object_center_params {
      object_center_loss_weight: 1.0
      classification_loss {
        penalty_reduced_logistic_focal_loss {
          alpha: 2.0
          beta: 4.0
        }
      }
      min_box_overlap_iou: 0.7
      max_box_predictions: 20
    }
  }
}
train_config {
  batch_size: 512
  data_augmentation_options { // delete
    random_horizontal_flip {
      keypoint_flip_permutation: 0
      keypoint_flip_permutation: 2
      keypoint_flip_permutation: 1
      keypoint_flip_permutation: 4
      keypoint_flip_permutation: 3
      keypoint_flip_permutation: 6
      keypoint_flip_permutation: 5
      keypoint_flip_permutation: 8
      keypoint_flip_permutation: 7
      keypoint_flip_permutation: 10
      keypoint_flip_permutation: 9
      keypoint_flip_permutation: 12
      keypoint_flip_permutation: 11
      keypoint_flip_permutation: 14
      keypoint_flip_permutation: 13
      keypoint_flip_permutation: 16
      keypoint_flip_permutation: 15
    }
  } // delete
  data_augmentation_options {
    random_patch_gaussian {
    }
  }
  data_augmentation_options {
    random_crop_image {
      min_aspect_ratio: 0.5
      max_aspect_ratio: 1.7
      random_coef: 0.25
    }
  }
  data_augmentation_options {
    random_adjust_hue {
    }
  }
  data_augmentation_options {
    random_adjust_contrast {
    }
  }
  data_augmentation_options {
    random_adjust_saturation {
    }
  }
  data_augmentation_options {
    random_adjust_brightness {
    }
  }
  data_augmentation_options {
    random_absolute_pad_image {
      max_height_padding: 200
      max_width_padding: 200
      pad_color: 0.0
      pad_color: 0.0
      pad_color: 0.0
    }
  }
  optimizer {
    adam_optimizer {
      learning_rate {
        cosine_decay_learning_rate {
          learning_rate_base: 5e-3
          total_steps: 300000
          warmup_learning_rate: 1e-4
          warmup_steps: 5000
        }
      }
    }
    use_moving_average: false
  }
  num_steps: 300000
  max_number_of_boxes: 100
  unpad_groundtruth_tensors: false
  fine_tune_checkpoint_type: ""
}
train_input_reader {
  label_map_path: "PATH_TO_BE_CONFIGURED/label_map.txt"
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/train2017-?????-of-00256.tfrecord"
  }
  filenames_shuffle_buffer_size: 256
  num_keypoints: 17 // delete
}
eval_config {
  num_visualizations: 10
  metrics_set: "coco_detection_metrics"
  use_moving_averages: false
  min_score_threshold: 0.20000000298023224
  max_num_boxes_to_visualize: 20
  batch_size: 1
}
eval_input_reader {
  label_map_path: "PATH_TO_BE_CONFIGURED/label_map.txt"
  shuffle: false
  num_epochs: 1
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/val2017-?????-of-00032.tfrecord"
  }
  num_keypoints: 17 // delete
}
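
The deletions marked above can also be scripted. The following is a minimal text-based sketch, not the Object Detection API's own config tooling: the hypothetical helper `strip_keypoint_settings` drops every `num_keypoints:` line and every `data_augmentation_options` block that contains a `keypoint_flip_permutation` entry, tracking brace depth to find where each block ends. It assumes the config is laid out one field per line, as in the listing above.

```python
import sys


def strip_keypoint_settings(config_text):
    """Return pipeline.config text with keypoint-related settings removed."""
    lines = config_text.splitlines()
    kept = []
    i = 0
    while i < len(lines):
        stripped = lines[i].strip()
        if stripped.startswith("num_keypoints:"):
            i += 1  # drop the line entirely
            continue
        if stripped.startswith("data_augmentation_options"):
            # Collect the whole block by tracking brace depth.
            depth, opened, j = 0, False, i
            while j < len(lines):
                opens = lines[j].count("{")
                depth += opens - lines[j].count("}")
                opened = opened or opens > 0
                j += 1
                if opened and depth == 0:
                    break
            block = lines[i:j]
            # Keep the block only if it has no keypoint permutation.
            if not any("keypoint_flip_permutation" in b for b in block):
                kept.extend(block)
            i = j
            continue
        kept.append(lines[i])
        i += 1
    return "\n".join(kept) + "\n"


if __name__ == "__main__" and len(sys.argv) == 3:
    # Usage (paths are yours to choose): script.py in.config out.config
    with open(sys.argv[1], encoding="utf-8") as src:
        cleaned = strip_keypoint_settings(src.read())
    with open(sys.argv[2], "w", encoding="utf-8") as dst:
        dst.write(cleaned)
```

Note that this removes the random_horizontal_flip augmentation altogether, which matches what the answer's marked deletions do.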

The config below is the modified version that works for me:

model {
  center_net {
    num_classes: 4
    feature_extractor {
      type: "mobilenet_v2_fpn_sep_conv"
    }
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 512
        max_dimension: 512
        pad_to_max_dimension: true
      }
    }
    use_depthwise: true
    object_detection_task {
      task_loss_weight: 1.0
      offset_loss_weight: 1.0
      scale_loss_weight: 0.1
      localization_loss {
        l1_localization_loss {
        }
      }
    }
    object_center_params {
      object_center_loss_weight: 1.0
      classification_loss {
        penalty_reduced_logistic_focal_loss {
          alpha: 2.0
          beta: 4.0
        }
      }
      min_box_overlap_iou: 0.7
      max_box_predictions: 20
    }
  }
}
train_config {
  batch_size: 8
  data_augmentation_options {
    random_patch_gaussian {
    }
  }
  data_augmentation_options {
    random_crop_image {
      min_aspect_ratio: 0.5
      max_aspect_ratio: 1.7
      random_coef: 0.25
    }
  }
  data_augmentation_options {
    random_adjust_hue {
    }
  }
  data_augmentation_options {
    random_adjust_contrast {
    }
  }
  data_augmentation_options {
    random_adjust_saturation {
    }
  }
  data_augmentation_options {
    random_adjust_brightness {
    }
  }
  data_augmentation_options {
    random_absolute_pad_image {
      max_height_padding: 200
      max_width_padding: 200
      pad_color: 0.0
      pad_color: 0.0
      pad_color: 0.0
    }
  }
  optimizer {
    adam_optimizer {
      learning_rate {
        cosine_decay_learning_rate {
          learning_rate_base: 5e-3
          total_steps: 50000
          warmup_learning_rate: 1e-4
          warmup_steps: 2000
        }
      }
    }
    use_moving_average: false
  }
  num_steps: 6000
  max_number_of_boxes: 4
  unpad_groundtruth_tensors: false
  fine_tune_checkpoint_version: V2
  fine_tune_checkpoint: "pre-trained-models/centernet_mobilenetv2_fpn_od/checkpoint/ckpt-301"
  fine_tune_checkpoint_type: "detection"
}
train_input_reader {
  label_map_path: "annotations/label_map.pbtxt"
  tf_record_input_reader {
    input_path: "annotations/train_c.record"
  }
  filenames_shuffle_buffer_size: 256
}
eval_config {
  num_visualizations: 10
  metrics_set: "coco_detection_metrics"
  use_moving_averages: false
  min_score_threshold: 0.20000000298023224
  max_num_boxes_to_visualize: 20
  batch_size: 1
}
eval_input_reader {
  label_map_path: "annotations/label_map.pbtxt"
  shuffle: false
  num_epochs: 1
  tf_record_input_reader {
    input_path: "annotations/test_c.record"
  }
}

Source: a similar question on Stack Overflow: https://stackoverflow.com/questions/66457432/
