写这篇文章的目的就是单纯记录下我对faster_rcnn的理解,主要分为原理和代码两部分,写这篇文章之前首先要感谢下@白裳的文章,这里贴出链接地址:
白裳:一文读懂Faster RCNN9612 赞同 · 411 评论文章
原理
下面贴出原论文中的结构图:
图1
从图1可以看出Faster_rcnn主要由四个部分组成,分别为:特征提取网络,RPN网络、ROI Pooling以及最终的分类和定位网络。
其整体流程为:图像预处理---特征提取网络---利用RPN网络获得调整参数(先验框与真实框之间的差:用网络预测的值表示)、每一个先验框的得分(前景还是背景)、以及RPN网络的建议框(该建议框按照每一个先验框的概率提取得到、按照得分进行排序,排序完之后进行NMS抑制最后得到一定数量的建议框)---计算先验框的回归损失和分类损失---从RPN网络中生成的建议框选择一定的数量的建议框送入ROI Pooling层---ROI Pooling层选出建议框在特征提取网络得到的base_feature中对应位置的特征(涉及到尺寸变换),利用该特征得到分类和回归网络的输出值---计算分类和回归网络的损失值。
主要流程就是以上部分,我认为该网络主要就是围绕先验框、建议框和真实框三个框之间做文章,先验框的生成如下图所示:
图2
其中,先验框是在图像特征提取网络得到的特征图(base_feature)上面生成的,也就是上图中的conv feature map,为该特征图上面的每一个点生成k个先验框(anchor boxes),这里的k一般设置为9,分别有三个尺寸、每个尺寸有三个图,尺寸分别为1:1、1:2和2:1。
生成先验框的样式为:
图3
其中每一个先验框的大小和类别由该特征图上面的点通过特征提取得到,如图2所示,其中cls layer和reg layer分别为得到先验框分数(判断该先验框为前景还是背景)和大小(包括四个点分别为左上角和右下角xy坐标)的层。
那么RPN网络做的事情是什么呢?
很显然,我们的先验框数量是冗余的,与实际真实框的数量相差甚远,RPN网络的功能就是从这么多数量的先验框中挑选出合适的建议框,这里选择的操作主要是:过滤不符合条件的先验框(超出图像边缘、宽高不能超过scale size),过滤完之后按照先验框的得分即softmax(results of cls layer)(包含物体的概率)排序取出一定数量的先验框,之后对取出的先验框进行NMS(非极大值抑制)操作,得到合适的建议框。
我们怎么才能让RPN网络工作的越来越好呢?
答案显然是使用损失函数来约束网络的优化方向,那么针对RPN网络的损失函数分为两部分,即分类损失和回归损失。分类损失是针对先验框是前景还是背景设计的,回归损失是针对先验框的位置设计的。
分类的label来源:首先计算每个先验框最对应的真实框的序号、每个先验框最对应的真实框的iou值、每一个真实框最对应的先验框的序号,之后与将每个先验框最对应的真实框的iou值与设置的门限进行比较,小于门限值设置为负样本(背景),大于门限值设置为正样本(前景),并将正负样本数量限制在一定范围内,每个先验框的分类标签就这么得到了(主要就是利用先验框与真实框的IOU值来确定该先验框的种类:前景还是背景?)回归损失的计算对象为:每个先验框最对应的真实框之间的差距与RPN网络得到的每个先验框的坐标的预测值,使得先验框与真实框之间的差距逐渐减小。据此,我们就得到置信度比较高的先验框以及RPN网络建议的建议框!!!!!!
得到PRN网络建议的建议框之后还需要对其进行采样,以便最后的分类和回归网络进行处理,采样的方式为:计算每一个建议框最对应的真实框,每一个建议框最对应的真实框的IOU,每一个建议框的标签,将建议框与真实框的IOU值满足一定阈值的建议框作为正样本(具体的物体类别),不满足该条件的建议框作为负样本(背景),据此得出采样后的建议框(记为Sample建议框)的坐标值和标签(具体的物体类别)。
得到Sample建议框之后,我们的ROIPooling开始发挥作用了,利用Sample建议框的坐标值去截取base_feature上对应的特征,利用该特征可以得到Faster_rcnn网络的预测结果(包括结果框的坐标预测结果和结果框内的物体的类别的预测结果)之后对该结果计算相应的回归损失和分类损失,分类损失的label同Sample建议框的标签,回归损失的label即真实框的坐标。
至此,整个Faster_rcnn的原理以及代码流程就写完了,好累!!!!!!!!!!!!
测试的代码还没看完,按照上述流程得到的模型测试的时候会有很多结果框,我认为网络最终的预测结果是最对应真实框的结果框,并给出其物体分类的概率!!!!
下面放上一张预训练好的模型在网图上的结果:
猫狗都分不出来,哈哈!
测试代码
"""
Faster rcnn实现目标检测
"""
import os
import time
import torch
import torchvision.transforms as transforms
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
# 获取当前路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# classes_coco类别信息
COCO_INSTANCE_CATEGORY_NAMES = [
__background__, person, bicycle, car, motorcycle, airplane, bus,
train, truck, boat, traffic light, fire hydrant, N/A, stop sign,
parking meter, bench, bird, cat, dog, horse, sheep, cow,
elephant, bear, zebra, giraffe, N/A, backpack, umbrella, N/A, N/A,
handbag, tie, suitcase, frisbee, skis, snowboard, sports ball,
kite, baseball bat, baseball glove, skateboard, surfboard, tennis racket,
bottle, N/A, wine glass, cup, fork, knife, spoon, bowl,
banana, apple, sandwich, orange, broccoli, carrot, hot dog, pizza,
donut, cake, chair, couch, potted plant, bed, N/A, dining table,
N/A, N/A, toilet, N/A, tv, laptop, mouse, remote, keyboard, cell phone,
microwave, oven, toaster, sink, refrigerator, N/A, book,
clock, vase, scissors, teddy bear, hair drier, toothbrush
]
print(len(COCO_INSTANCE_CATEGORY_NAMES))
if __name__ == "__main__":
# 检测图片路径
path_img = os.path.join(BASE_DIR, "R-C.jpg")
# 预处理
preprocess = transforms.Compose([
transforms.ToTensor(),
])
input_image = Image.open(path_img).convert("RGB")
img_chw = preprocess(input_image)
# 加载预训练模型
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
if torch.cuda.is_available():
img_chw = img_chw.to(cuda)
model.to(cuda)
# 前向传播
input_list = [img_chw]
with torch.no_grad():
tic = time.time()
print("input img tensor shape:{}".format(input_list[0].shape))
output_list = model(input_list)
output_dict = output_list[0]
print("pass: {:.3f}s".format(time.time() - tic))
# 打印输出信息
for k, v in output_dict.items():
print("key:{}, value:{}".format(k, v))
# 取得相应结果
out_boxes = output_dict["boxes"].cpu()
out_scores = output_dict["scores"].cpu()
out_labels = output_dict["labels"].cpu()
# 可视化
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(input_image, aspect=equal)
num_boxes = out_boxes.shape[0]
print(num_boses:,num_boxes)
max_vis = 400
thres = 0.6
# 循环描框
for idx in range(0, min(num_boxes, max_vis)):
score = out_scores[idx].numpy()
bbox = out_boxes[idx].numpy()
class_name = COCO_INSTANCE_CATEGORY_NAMES[out_labels[idx]]
if score < thres:
continue
ax.add_patch(plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False,
edgecolor=red, linewidth=3.5))
ax.text(bbox[0], bbox[1] - 2, {:s} {:.3f}.format(class_name, score), bbox=dict(facecolor=blue, alpha=0.5),
fontsize=14, color=white)
ax.set_title("just a simple try about Faster Rcnn", fontsize=28, color=blue)
plt.show()
plt.close()
注释丰富的完整代码
地址:
https://github.com/bubbliiiing/faster-rcnn-pytorchgithub.com/bubbliiiing/faster-rcnn-pytorch
B站讲解视频:
https://www.bilibili.com/video/BV1BK41157Vs?spm_id_from=333.999.0.0www.bilibili.com/video/BV1BK41157Vs?spm_id_from=333.999.0.0