专栏名称: OpenCV学堂
一个致力于计算机视觉OpenCV原创技术传播的公众号!OpenCV计算机视觉与tensorflow深度学习相关算法原创文章分享、函数使用技巧、源码分析与讨论、,计算机视觉前沿技术介绍,技术专家经验分享,人才交流,学习交流。
目录
相关文章推荐
电商报Pro  ·  DeepSeek公布运营成本利润,行业大为震惊 ·  昨天  
高分子科学前沿  ·  加拿大阿尔伯塔大学曾宏波等ACS ... ·  昨天  
蛋先生工作室  ·  2025年3月2日最新蛋价(早报) ·  2 天前  
高分子科技  ·  广工邱学青/林绪亮团队 Adv. ... ·  4 天前  
高分子科学前沿  ·  江南大学刘天西/张龙生《AM》:配位组装-诱 ... ·  3 天前  
51好读  ›  专栏  ›  OpenCV学堂

Torchvision框架学习之FCOS模型及其训练

OpenCV学堂  · 公众号  ·  · 2024-11-20 22:34

正文

点击上方 蓝字 关注我们

微信公众号: OpenCV学堂

关注获取更多计算机视觉与深度学习知识

FCOS介绍

FCOS(Fully Convolutional One Stage)是于2019年由Zhi Tian等人提出的一种全卷积单阶段无锚框检测模型。该模型无需在训练时计算锚框的IoU值,极大的简化了模型的训练,使得训练方便,并且在推理时只有NMS(None Maximum Suppression,非及大值抑制),使得推理速度极快。在torchvision模型库中包含了以Resnet50作为骨干网络的FCOS,这样只需要创建模型进行训练即可。

模型介绍

FCOS模型的结构如图10.8所示,整个模型从左向右,可以分为三个部分:Backbone(骨干网络),Feature Pyramid(特征金字塔),以及由Classification、Center-ness和Regression构成的多尺度目标检测头作为输出。


Backbone主要用于提取提取不同尺度的特征,一般使用预训练的分类网络的特征提取部分作为BackBone,FCOS的Backbone,作为模型的入口,接受输入的图像,从Backbone中输出C3,C4和C5特征图。各特征图的高和宽标记在各特征图的左侧,分别为输入图像的1/8,1/16和1/32,对于输入为800×1024的图像,C3的尺寸为100×128,C4的尺寸为50×64,C5的尺寸为25×32。

Feature Pyramid接收来自骨干网络的C3,C4和C5特征图,一方面以C5特征图为基础构造尺寸为13×16的P6特征图和尺寸为7×8的P7特征图,另一方面将C5上采样和C4合并构造与C4同尺寸的P4,并将P4上采样和C3构造出与C4同尺寸的P3。这样通过Feature Pyramid总共构造出了P3,P4,P5,P6和P7共5个尺寸的特征图,目标检测就在这5个不同尺寸的特征图上完成。


多尺度目标检测头负责对从Feature Pyramid传来P3-7共5个特征图进行检测。在每个特征图上进行目标检测的头都具有相同的结构。每个头包含两个独立的部分,Classification结构和Center-ness结构共享一个部分,其中Classification结构输出大小为H×W×C形状的张量表示在H×W大小的特征图上检测总共C个类别的结果,Center-ness结构输出大小为H×W×1形状的张量表示在H×W的特征图上,是否为目标的中心;Regression独自为一个部分,输出大小为H×W×4形状的张量,表示以该元素为中心的目标到该元素的4个距离,如图10.9所示。图10.9显示了FCOS模型在Regression部分学习的4个距离,分别是中心到左边界的距离,中心到右边界的距离,中心到上边界的距离和中心到下边界的距离。这样FCOS模型就彻底抛弃了锚框IoU的计算,直接以这四个距离进行训练。

图10.8  FCOS目标检测模型

图10.9  FCOS目标外接矩形的表示


由于在torchvision中包含了以ResNet50为骨干网络的FCOS模型,并且还可以加载COCO上预训练的模型,这样创建一个FCOS模型就非常方便,创建方法与上一节创建预训练模型方法相同:
from torchvision.models import detectionmodel = detection.fcos_resnet50_fpn(progress=True,num_classes=3,          pretrained_backbone=True,          trainable_backbone_layers=4)
以上就可以创建一个具有检测3个类别,以带有预训练参数的ResNet50为骨干网络的FCOS模型,对于创建模型时其他参数及其含义可以参考API文档。

数据集制作

由于目标检测模型多样,因此,在训练前对于数据集的构建方法会有所差异。对于torchvision包中提供的所有目标检测模型已经对训练数据的格式进行了统一,因此,只需要把数据按照统一的方式进行构建后,torchvision包内的其它目标检测模型也可以使用。

由于通用目标检测数据集通常较大,不便于进行原理的演示。在这里使用一个样本量较小,类别数较小的目标检测数据集——螺丝螺母检测数据集。螺丝螺母检测数据集是一个开源目标检测数据集,下载地址为:
https://aistudio.baidu.com/aistudio/datasetdetail/6045
同时,该数据集也附于本书的电子资料中。

图10.10  螺丝螺母检测数据集中的样本

螺丝螺母数据集包括413张训练集和10张测试集两部分。图10.10显示了一个带有标注的训练集中的样本,在样本中螺丝和螺母放置于一个白色托盘中,托盘放置在一灰色平台上,螺丝螺母使用矩形框进行标注。以下就以该数据集为例构造用于训练torchvision中模型的数据集。

使用torchvision中的目标检测模型训练自定义数据同样需要对数据集进行封装,并在得到样本的__getitem__()方法中返回一个表示样本元组的数据和标签,以(x, y)表示,其中x是一个范围为0-1的3×H×W的图像张量,y表示图像x的标签,是一个包含‘label’和‘boxes’两个键的字典,‘label’键里以整数张量的形式存储了图像中K个目标的标签值,‘boxes’键里存储了图像中K个对应目标外边矩形框的左上和右下共4个坐标值组成的一个K×4的数字张量,格式如下所示:
#样本标签y的格式:
{'labels': tensor([1, 1, 1, 2, 2, 2, 2, 2]),  'boxes': tensor([[ 711,  233,  844,  506],          [1036,  194, 1206,  459],          [ 958,  406, 1239,  573],          [1142,  194, 1275,  320],          [ 780,  478,  908,  614],          [ 766,  612,  914,  742],          [ 972,  542, 1120,  678],          [ 986,  684, 1120,  820]])}
#以上表明样本中包含8个目标,3个目标的类别为1,5个的目标的类别为2。 按照上述要求,通过继承torch.utils.data.Dataset类创建一个自定义的数据集,实现螺丝螺母数据集的构造:
from pathlib import Pathfrom torchvision.io import read_image,ImageReadModeimport jsonimport torchclass BNDataset(torch.utils.data.Dataset):    def __init__(self, istrain=True,datapath='D:/data/lslm'):   #注意修改数据集路径        self.datadir=Path(datapath)/('train' if istrain else 'test')        self.idxfile=self.datadir/('train.txt' if istrain else 'test.txt')        self.labelnames=['background','bolt','nut']        self.data=self.parseidxfile()    def parseidxfile(self):        lines=open(self.idxfile).readlines()        return [line for line in lines if len(line)>5]

def __getitem__(self, idx): data=self.data[idx].split('\t') x = read_image(str(self.datadir/data[0]),ImageReadMode.RGB)/255.0 labels=[] boxes=[] for i in data[2:]: if len(i)<5: continue r=json.loads(i) labels.append(self.labelnames.index( r['value'])) cords=r['coordinate'] xyxy=cords[0][0],cords[0][1],cords[1][0],cords[1][1] boxes.append(xyxy) y = { 'labels': torch.LongTensor(labels), 'boxes': torch.tensor(boxes).long() } return x, y

def __len__(self): return len(self.data)
以上代码对螺丝螺母数据集以BNDataset为类名进行封装,主要涉及的难点就是标签文件的解析,具体解析过程要结合上述代码和标签文件进行理解。在模型进行训练时,还需要使用把数据集进一步使用DataLoader封装:
def collate_fn(data):    x = [i[0] for i in data]    y = [i[1] for i in data]    return x, y

train_loader = torch.utils.data.DataLoader(dataset=BNDataset(istrain=True), batch_size=4, shuffle=True, drop_last=True, collate_fn=collate_fn)test_loader = torch.utils.data.DataLoader(dataset=BNDataset(istrain=False), batch_size=1, shuffle=True, drop_last=True, collate_fn=collate_fn)
在封装完成后,就可以使用上一节提到的可视化方法,进行样本和标签的可视化,以检查数据集构造的正确性,得到如图10.10所示的结果:
for i, (x, y) in enumerate(train_loader):    labels=[loader.dataset.labelnames[i] for i in y[0]['labels']]colors=[ ('red' if i=='nut' else 'blue') for i in labels]image=draw_bounding_boxes(x[0], y[0]['boxes'],labels=labels,colors=colors,width=5,font_size=50,outtype='CHW')vis.image(image)
#图10.10所示
以上完成了螺丝螺母数据集的构造,能够用于FCOS模型的训练。下面介绍FCOS模型在该数据集上的训练以及模型的评估。

训练与预测

由于torchvision对FCOS模型进行了很好的封装,在准备好数据集后,训练方法与分类和分割网络的训练模式并无太大差异:创建优化器,构造损失函数,对数据集进行多次循环并根据反向传播的梯度进行参数的修正。将训练过程封装为train()函数,调用train()函数进行模型的训练,代码如下:
def train():    model.train()    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001,momentum=0.98)    for epoch in range(5):        for i, (x, y) in enumerate(train_loader):            outs = model(x, y)            loss = outs['classification']+ outs['bbox_ctrness']+outs['bbox_regression']            loss.backward()            optimizer.step()            optimizer.zero_grad()            if i % 10 == 0:                print(epoch, i, loss.item())        torch.save(model, f'./models/tvs{epoch}.model')
train()

#输出结果:

0 0 2.1866644620895386......4 100 0.9854792356491089
以上就是FCOS模型的训练代码,其中train()函数实现了模型的训练,在该函数中,将模型切换到训练模式,创建SGD优化器,总共训练5轮(可根据情况训练更多轮数),每10批打印损失值,经过4轮训练后,损失值从2.18降为了0.98,并且在每轮训练完成后都保存模型。


在模型训练完成后,就可以在测试集上查看和评估模型的检测效果。评估方法实质上与之前介绍的模型的使用方法是相同的,可以参考上一节的内容以便于理解。对模型在测试集上进行运行,并可视化结果,测试代码如下:
def test():    model_load = torch.load('./models/tvs4.model')    model_load.eval()    loader_test = torch.utils.data.DataLoader(dataset=BNDataset(istrain=False), batch_size=1, shuffle=False, drop_last=True, collate_fn=collate_fn)






请到「今天看啥」查看全文