在树莓派上配置基于tf的Object Detection应用

/ 0评 / 0

所用镜像是这个.

进系统第一时间先把来自冥古宙版本的pip3升级一下,再更新一下系统,再根据TF空间大小,申请一块更大的SWAP(我设置了2G的额外SWAP).

curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py

sudo python3 get-pip.py

sudo pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

sudo pip3 install --upgrade setuptools

sudo apt update

sudo apt dist-upgrade

因为Tensorflow在墙外,所以自己先下载,然后安装,由于piwheels没有镜像,而且包数量比较多,速度又比较慢,所以安装过程有点慢,到安装通用包时候就快了.(负载极高,注意散热.)

安装完Tensorflow,可以继续安装其他支持组件.

sudo apt install libatlas-base-dev libxslt-dev gawk vim

sudo pip3 install pillow lxml jupyter matplotlib cython

sudo apt install python-tk

sudo apt install libjpeg-dev libtiff5-dev libjasper-dev libpng12-dev

sudo apt install libavcodec-dev libavformat-dev libswscale-dev libv4l-dev

sudo apt install libxvidcore-dev libx264-dev

sudo apt install qt4-dev-tools

sudo pip3 install opencv-python

sudo apt install autoconf automake libtool curl

编译安装Protobuf,因为apt源里面的Protobuf老掉牙了,不用他,我们自己编译.

具体链接:https://github.com/protocolbuffers/protobuf/releases

编译方法很简单.(就是比较吃时间,大约花了2小时,包括make check检查时间,如果中途中断,则需要make clean后重新开始,并且这个代码不能包含多线程编译参数.)

./configure

make

make check

sudo make install

拉Tensorflow的一些测试模型下来,建议使用代理,比如我指向了局域网的某特殊代理服务器.

git config --global http.proxy 'socks5://127.0.0.1:1080'

git config --global https.proxy 'socks5://127.0.0.1:1080'

git clone --recurse-submodules https://github.com/tensorflow/models.git

然后给bashrc添加环境变量声明.

export PYTHONPATH=$PYTHONPATH:/home/pi/models/research:/home/pi/models/research/slim

再模型目录内编译proto文件.

protoc object_detection/protos/*.proto --python_out=.

下载一些已训练好的模型.

参考链接:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md

我选了一个精简模型.

我写了个脚本,实时摄像头跟踪的.(Python3)

import os

import cv2

import numpy as np

import tensorflow as tf

import argparse

import sys



IM_WIDTH = 1280

IM_HEIGHT = 720



# 工作目录:object_detection

sys.path.append('..')



from utils import label_map_util

from utils import visualization_utils as vis_util



# 已训练的模型文件

MODEL_NAME = 'ssdlite_mobilenet_v2_coco_2018_05_09'



# 获取当前的路径

CWD_PATH = os.getcwd()



# 模型 + 模型查找表

PATH_TO_CKPT = os.path.join(CWD_PATH,MODEL_NAME,'frozen_inference_graph.pb')

PATH_TO_LABELS = os.path.join(CWD_PATH,'data','mscoco_label_map.pbtxt')



# 最大扫描层数

NUM_CLASSES = 90



# 加载表

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)

categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)

category_index = label_map_util.create_category_index(categories)



# 读入图,创建会话.

detection_graph = tf.Graph()

with detection_graph.as_default():

    od_graph_def = tf.GraphDef()

    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:

        serialized_graph = fid.read()

        od_graph_def.ParseFromString(serialized_graph)

        tf.import_graph_def(od_graph_def, name='')



    sess = tf.Session(graph=detection_graph)





# 输入类型是标准图

image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')



# 输出类型是分布矩阵

detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')



# 层次关系是权重优先

detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')

detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')



# 检测数量

num_detections = detection_graph.get_tensor_by_name('num_detections:0')



# 初始化摄像头

camera = cv2.VideoCapture(0)

ret = camera.set(3,IM_WIDTH)

ret = camera.set(4,IM_HEIGHT)



while(True):

    # 数据为一阶张量: [1, None, None, 3]

    ret, frame = camera.read()

    frame_expanded = np.expand_dims(frame, axis=0)



    # 重入会话

    (boxes, scores, classes, num) = sess.run(

        [detection_boxes, detection_scores, detection_classes, num_detections],

        feed_dict={image_tensor: frame_expanded})



    # 在被检测物体上绘图

    vis_util.visualize_boxes_and_labels_on_image_array(

        frame,

        np.squeeze(boxes),

        np.squeeze(classes).astype(np.int32),

        np.squeeze(scores),

        category_index,

        use_normalized_coordinates=True,

        line_thickness=8,

        min_score_thresh=0.85)



    cv2.imshow('Pi ML-NN', frame)    



    # 按X退出

    if cv2.waitKey(1) == ord('x'):

        break



camera.release()

cv2.destroyAllWindows()

在图形界面上运行这个程序,接入USB摄像头,即可进行实时检测.

打开这个程序需要的时间也长一些,大功告成.

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注