gpt4 book ai didi

python - 用于对象检测和分割的 Mask R-CNN [训练自定义数据集]

转载 作者:行者123 更新时间:2023-11-30 09:25:13 24 4
gpt4 key购买 nike

我正在研究“用于对象检测和分割的Mask R-CNN”。因此,我阅读了原始研究论文,其中介绍了用于对象检测的Mask R-CNN,并且我还发现了Mask R-CNN的一些实现,herehere (由 Facebook AI 研究团队称为 detectorron)。但他们都使用了coco数据集进行测试。

但是我对使用自定义数据集训练上述实现感到相当困惑,该数据集具有大量图像,并且对于每个图像都有一个掩码图像子集用于标记相应图像中的对象。

因此,如果有人可以为此任务发布有用的资源或代码示例,我会很高兴。

注意:我的数据集具有以下结构,

It consists with a large number of images and for each image, there are separate image files highlighting the object as a white patch in a black image.

这是一个示例图像及其蒙版:

图像;

enter image description here

面具;

enter image description here enter image description here

最佳答案

我已经训练了 https://github.com/matterport/Mask_RCNN 的实例分割模型以在我的数据集上运行。

我的假设是您已完成所有基本设置,并且模型已经使用默认数据集(在存储库中提供)运行,现在您希望它针对自定义数据集运行。

以下是步骤

  1. 您需要拥有所有注释。
  2. 所有这些都需要转换为 VGG 多边形模式(是的,我的意思是多边形,即使您需要绑定(bind)框)。我在这个答案的末尾添加了一个示例 VGG 多边形格式。
  3. 您需要将自定义数据集分为训练集、测试集和验证集
  4. 默认情况下,注释会在各个数据集文件夹内使用文件名 via_region_data.json 进行查找。例如,对于训练图像,它将查看train\via_region_data.json。如果您愿意,也可以更改它。
  5. 在 Samples 文件夹中,您可以找到 Balloon、Nucleus、Shapes 等文件夹。复制其中一个文件夹。最好是气球。我们现在将尝试为我们的自定义数据集修改这个新文件夹。
  6. 在复制的文件夹中,您将有一个 .py 文件(对于气球,它将是气球.py),更改以下变量
    • ROOT_DIR :克隆项目的绝对路径
    • DEFAULT_LOGS_DIR:此文件夹的大小将变得更大,因此请相应地更改此路径(如果您在低磁盘存储虚拟机中运行代码)。它还将存储 .h5 文件。它将在日志文件夹内创建子文件夹,并附加时间戳。
    • 每个时期的
    • .h5 文件大约为 200 - 300 MB。但猜猜这个日志目录与 Tensorboard 兼容。您可以在运行tensorboard时将带时间戳的子文件夹作为--logdir参数传递。
  7. .py 文件还包含两个类 - 一个后缀为 Config 的类,另一个后缀为 Dataset 的类。
  8. 在 Config 类中覆盖所需的内容,例如
    • NAME:您的项目的名称。
    • NUM_CLASSES:它应该比您的标签类别多一个,因为背景也被视为一个标签
    • DETECTION_MIN_CONFIDENCE:默认为 0.9(如果您的训练图像质量不是很高或者没有太多训练数据,请降低该值)
    • STEPS_PER_EPOCH
  9. 在 Dataset 类中重写以下方法。所有这些功能都已经有很好的注释,因此您可以按照注释根据您的需要进行覆盖。
    • load_(name_of_the_sample_project) 例如 load_balloon
    • load_mask(请参阅示例答案的最后一个)
    • 图片引用
  10. 训练函数(数据集类之外):如果您必须更改纪元数或学习率等

您现在可以直接从终端运行它

python samples\your_folder_name\your_python_file_name.py train --dataset="location_of_custom_dataset" --weights=coco

有关上述行命令行参数的完整信息,您可以将其视为此 .py 文件顶部的注释。

这些是我能记得的事情,我想补充更多我记得的步骤。如果您在任何特定步骤中遇到困难,也许您可​​以告诉我,我会详细说明该特定步骤。

VGG 多边形架构

宽度和高度是可选的

[{
"filename": "000dfce9-f14c-4a25-89b6-226316f557f3.jpeg",
"regions": {
"0": {
"region_attributes": {
"object_name": "Cat"
},
"shape_attributes": {
"all_points_x": [75.30864197530865, 80.0925925925926, 80.0925925925926, 75.30864197530865],
"all_points_y": [11.672189112257607, 11.672189112257607, 17.72093488703078, 17.72093488703078],
"name": "polygon"
}
},
"1": {
"region_attributes": {
"object_name": "Cat"
},
"shape_attributes": {
"all_points_x": [80.40123456790124, 84.64506172839506, 84.64506172839506, 80.40123456790124],
"all_points_y": [8.114103362391036, 8.114103362391036, 12.205901974737595, 12.205901974737595],
"name": "polygon"
}
}
},
"width": 504,
"height": 495
}]

load_mask 函数示例

def load_mask(self, image_id):
"""Generate instance masks for an image.
Returns:
masks: A bool array of shape [height, width, instance count] with
one mask per instance.
class_ids: a 1D array of class IDs of the instance masks.
"""
# If not your dataset image, delegate to parent class.
image_info = self.image_info[image_id]
if image_info["source"] != "name_of_your_project": //change your project name
return super(self.__class__, self).load_mask(image_id)

# Convert polygons to a bitmap mask of shape
# [height, width, instance_count]
info = self.image_info[image_id]
mask = np.zeros([info["height"], info["width"], len(info["polygons"])], dtype=np.uint8)
class_id = np.zeros([mask.shape[-1]], dtype=np.int32)

for i, p in enumerate(info["polygons"]):
# Get indexes of pixels inside the polygon and set them to 1
rr, cc = skimage.draw.polygon(p['all_points_y'], p['all_points_x'])
# print(rr.shape, cc.shape, i, np.ones([mask.shape[-1]], dtype=np.int32).shape, info['classes'][i])

class_id[i] = self.class_dict[info['classes'][i]]
mask[rr, cc, i] = 1


# Return mask, and array of class IDs of each instance. Since we have
# one class ID only, we return an array of 1s
return mask.astype(np.bool), class_id

关于python - 用于对象检测和分割的 Mask R-CNN [训练自定义数据集],我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49684468/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com