gpt4 book ai didi

python - 使用opencv,tensorflow和python进行人体检测

转载 作者:行者123 更新时间:2023-12-02 16:45:21 24 4
gpt4 key购买 nike

我正在进行一个机器人项目,涉及检测人体,为此我正在使用 tensorflow 和预定义的数据集来创建训练模型。由于我是机器学习的新手,因此我无法从分类器中正确获取输出。我只需要检测“人”,并希望避免检测到球,笔记本电脑或其他物体。
现在,我的网络摄像头可以检测到所有物体,例如球,球棒,笔记本电脑,电视等。我需要的输出仅是得分达到80%或以上的人。

我用于创建模型的代码是

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image


from utils import label_map_util

from utils import visualization_utils as vis_util

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'


PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

NUM_CLASSES = 90

if not os.path.exists(MODEL_NAME + '/frozen_inference_graph.pb'):
print ('Downloading the model')
opener = urllib.request.URLopener()
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
file_name = os.path.basename(file.name)
if 'frozen_inference_graph.pb' in file_name:
tar_file.extract(file, os.getcwd())
print ('Download complete')
else:
print ('Model already exists')

detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

import cv2
cap = cv2.VideoCapture(1)


with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
ret = True
while (ret):
ret,image_np = cap.read()
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
scores = detection_graph.get_tensor_by_name('detection_scores:0')
classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')

(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
vis_util.visualize_boxes_and_labels_on_image_array(image_np,np.squeeze(boxes),np.squeeze(classes).astype(np.int32),np.squeeze(scores),category_index,use_normalized_coordinates=True,line_thickness=8)
cv2.imshow('image',cv2.resize(image_np,(1280,960)))
if cv2.waitKey(27) & 0xFF == ord('q'):
cv2.destroyAllWindows()
cap.release()
break

谁能解释我如何只检测准确度得分超过80%的人。

最佳答案

从文档here可以看到,您只需要检查person类。现在,vis_util检查所有类。您只需要为person类添加if条件。下面给出的是适当的标识符(来自docs)。
item {
name: "/m/01g317"
id: 1
display_name: "person"
}

关于python - 使用opencv,tensorflow和python进行人体检测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50642057/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com