gpt4 book ai didi

python - Tensorflow 对象检测 - 避免重叠框

转载 作者:太空宇宙 更新时间:2023-11-03 21:18:06 24 4
gpt4 key购买 nike

简介:我是机器学习的新手,我和一位同事必须实现一种检测交通灯的算法。我下载了一个预先训练的模型(更快的 rcnn)并运行了几个训练步骤(~10000)。现在,当使用 Tensorflow git 存储库中的对象检测算法时,会检测到一个区域中的多个交通灯。

我做了一些研究,发现了函数“tf.image.non_max_suppression”,但我无法让它按预期工作(说实话,我什至无法让它运行)。

我假设您知道 tf 对象检测示例代码,因此您也知道所有框都是使用字典 (output_dict) 返回的。

“清洁”我使用的盒子:

selected_indices = tf.image.non_max_suppression(
boxes = output_dict['detection_boxes'],
scores = output_dict['detection_scores'],
max_output_size = 1,
iou_threshold = 0.5,
score_threshold = float('-inf'),
name = None)

起初我想我可以使用 selected_indices 作为新的框列表,所以我尝试了这个:

vis_util.visualize_boxes_and_labels_on_image_array(
image = image_np,
boxes = selected_indices,
classes = output_dict['detection_classes'],
scores = output_dict['detection_scores'],
category_index = category_index,
instance_masks = output_dict.get('detection_masks'),
use_normalized_coordinates = True)

但是当我注意到这不起作用时,我找到了一个必需的方法:“tf.gather()”。然后我运行了以下代码:

boxes = output_dict['detection_boxes']
selected_indices = tf.image.non_max_suppression(
boxes = boxes,
scores = output_dict['detection_scores'],
max_output_size = 1,
iou_threshold = 0.5,
score_threshold = float('-inf'),
name = None)

selected_boxes = tf.gather(boxes, selected_indices)

vis_util.visualize_boxes_and_labels_on_image_array(
image = image_np,
boxes = selected_boxes,
classes = output_dict['detection_classes'],
scores = output_dict['detection_scores'],
category_index = category_index,
instance_masks = output_dict.get('detection_masks'),
use_normalized_coordinates = True)

但即使这样也不起作用。我在 Visualization_utils.py 的第 689 行收到 AttributeError(“Tensor”对象没有属性“tolist”)。

最佳答案

所以看起来要以正确的格式获取框,您需要创建一个 session 并按如下方式评估张量:

suppressed = tf.image.non_max_suppression(output_dict['detection_boxes'], output_dict['detection_scores'], 5) # Replace 5 with max num desired boxes

sboxes = tf.gather(output_dict['detection_boxes'], suppressed)
sscores = tf.gather(output_dict['detection_scores'], suppressed)
sclasses = tf.gather(output_dict['detection_classes'], suppressed)

sess = tf.Session()
with sess.as_default():
boxes = sboxes.eval()
scores =sscores.eval()
classes = sclasses.eval()

vis_util.visualize_boxes_and_labels_on_image_array(
image_np,
boxes,
classes,
scores,
category_index,
instance_masks=output_dict.get('detection_masks'),
use_normalized_coordinates=True,
line_thickness=8)

关于python - Tensorflow 对象检测 - 避免重叠框,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54538497/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com