由于目标检测模型多样,因此,在训练前对于数据集的构建方法会有所差异。对于torchvision包中提供的所有目标检测模型已经对训练数据的格式进行了统一,因此,只需要把数据按照统一的方式进行构建后,torchvision包内的其它目标检测模型也可以使用。
由于通用目标检测数据集通常较大,不便于进行原理的演示。在这里使用一个样本量较小,类别数较小的目标检测数据集——螺丝螺母检测数据集。螺丝螺母检测数据集是一个开源目标检测数据集,下载地址为:
https://aistudio.baidu.com/aistudio/datasetdetail/6045
螺丝螺母数据集包括413张训练集和10张测试集两部分。图10.10显示了一个带有标注的训练集中的样本,在样本中螺丝和螺母放置于一个白色托盘中,托盘放置在一灰色平台上,螺丝螺母使用矩形框进行标注。以下就以该数据集为例构造用于训练torchvision中模型的数据集。
使用torchvision中的目标检测模型训练自定义数据同样需要对数据集进行封装,并在得到样本的__getitem__()方法中返回一个表示样本元组的数据和标签,以(x, y)表示,其中x是一个范围为0-1的3×H×W的图像张量,y表示图像x的标签,是一个包含‘label’和‘boxes’两个键的字典,‘label’键里以整数张量的形式存储了图像中K个目标的标签值,‘boxes’键里存储了图像中K个对应目标外边矩形框的左上和右下共4个坐标值组成的一个K×4的数字张量,格式如下所示:
{'labels': tensor([1, 1, 1, 2, 2, 2, 2, 2]),
'boxes': tensor([[ 711, 233, 844, 506],
[1036, 194, 1206, 459],
[ 958, 406, 1239, 573],
[1142, 194, 1275, 320],
[ 780, 478, 908, 614],
[ 766, 612, 914, 742],
[ 972, 542, 1120, 678],
[ 986, 684, 1120, 820]])}
#以上表明样本中包含8个目标,3个目标的类别为1,5个的目标的类别为2。
按照上述要求,通过继承torch.utils.data.Dataset类创建一个自定义的数据集,实现螺丝螺母数据集的构造:
from pathlib import Path
from torchvision.io import read_image,ImageReadMode
import json
import torch
class BNDataset(torch.utils.data.Dataset):
def __init__(self, istrain=True,datapath='D:/data/lslm'): #注意修改数据集路径
self.datadir=Path(datapath)/('train' if istrain else 'test')
self.idxfile=self.datadir/('train.txt' if istrain else 'test.txt')
self.labelnames=['background','bolt','nut']
self.data=self.parseidxfile()
def parseidxfile(self):
lines=open(self.idxfile).readlines()
return [line for line in lines if len(line)>5]
def __getitem__(self, idx):
data=self.data[idx].split('\t')
x = read_image(str(self.datadir/data[0]),ImageReadMode.RGB)/255.0
labels=[]
boxes=[]
for i in data[2:]:
if len(i)<5: continue
r=json.loads(i)
labels.append(self.labelnames.index( r['value']))
cords=r['coordinate']
xyxy=cords[0][0],cords[0][1],cords[1][0],cords[1][1]
boxes.append(xyxy)
y = {
'labels': torch.LongTensor(labels),
'boxes': torch.tensor(boxes).long()
}
return x, y
def __len__(self):
return len(self.data)
以上代码对螺丝螺母数据集以BNDataset为类名进行封装,主要涉及的难点就是标签文件的解析,具体解析过程要结合上述代码和标签文件进行理解。在模型进行训练时,还需要使用把数据集进一步使用DataLoader封装:
def collate_fn(data):
x = [i[0] for i in data]
y = [i[1] for i in data]
return x, y
train_loader = torch.utils.data.DataLoader(dataset=BNDataset(istrain=True), batch_size=4, shuffle=True, drop_last=True, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(dataset=BNDataset(istrain=False), batch_size=1, shuffle=True, drop_last=True, collate_fn=collate_fn)
在封装完成后,就可以使用上一节提到的可视化方法,进行样本和标签的可视化,以检查数据集构造的正确性,得到如图10.10所示的结果:
for i, (x, y) in enumerate(train_loader):
labels=[loader.dataset.labelnames[i] for i in y[0]['labels']]
colors=[ ('red' if i=='nut' else 'blue') for i in labels]
image=draw_bounding_boxes(x[0], y[0]['boxes'],labels=labels,colors=colors,width=5,font_size=50,outtype='CHW')
vis.image(image)
以上完成了螺丝螺母数据集的构造,能够用于FCOS模型的训练。下面介绍FCOS模型在该数据集上的训练以及模型的评估。
由于torchvision对FCOS模型进行了很好的封装,在准备好数据集后,训练方法与分类和分割网络的训练模式并无太大差异:创建优化器,构造损失函数,对数据集进行多次循环并根据反向传播的梯度进行参数的修正。将训练过程封装为train()函数,调用train()函数进行模型的训练,代码如下:
def train():
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001,momentum=0.98)
for epoch in range(5):
for i, (x, y) in enumerate(train_loader):
outs = model(x, y)
loss = outs['classification']+ outs['bbox_ctrness']+outs['bbox_regression']
loss.backward()
optimizer.step()
optimizer.zero_grad()
if i % 10 == 0:
print(epoch, i, loss.item())
torch.save(model, f'./models/tvs{epoch}.model')
train()
#输出结果:
0 0 2.1866644620895386
...
...
4 100 0.9854792356491089
以上就是FCOS模型的训练代码,其中train()函数实现了模型的训练,在该函数中,将模型切换到训练模式,创建SGD优化器,总共训练5轮(可根据情况训练更多轮数),每10批打印损失值,经过4轮训练后,损失值从2.18降为了0.98,并且在每轮训练完成后都保存模型。
在模型训练完成后,就可以在测试集上查看和评估模型的检测效果。评估方法实质上与之前介绍的模型的使用方法是相同的,可以参考上一节的内容以便于理解。对模型在测试集上进行运行,并可视化结果,测试代码如下:
def test():
model_load = torch.load('./models/tvs4.model')
model_load.eval()
loader_test = torch.utils.data.DataLoader(dataset=BNDataset(istrain=False), batch_size=1, shuffle=False, drop_last=True, collate_fn=collate_fn)