专栏名称: AI开发者
AI研习社,雷锋网旗下关注AI开发技巧及技术教程订阅号。
目录
相关文章推荐
AIbase基地  ·  最好的 Manus 复刻项目?GAIA ... ·  21 小时前  
AIbase基地  ·  最好的 Manus 复刻项目?GAIA ... ·  21 小时前  
数据何规  ·  中国AI司法案例报告:纠纷如何奠基AI规则 ·  22 小时前  
数据何规  ·  中国AI司法案例报告:纠纷如何奠基AI规则 ·  22 小时前  
爱可可-爱生活  ·  [LG]《Self-Evolved ... ·  2 天前  
爱可可-爱生活  ·  扩散模型训练动态的功率谱偏置分析理论 ... ·  3 天前  
51好读  ›  专栏  ›  AI开发者

专栏 | 【从零开始学习YOLOv3】3. YOLOv3的数据加载机制和增强方法

AI开发者  · 公众号  · AI  · 2020-02-22 17:12

正文


点击上方“蓝字”关注“AI开发者”


本文来自 @BBuf 的社区专栏 GiantPandaCV ,文末扫码即可订阅专栏。

前言:本文主要讲YOLOv3中数据加载部分,主要解析的代码在utils/datasets.py文件中。通过对数据组织、加载、处理部分代码进行解读,能帮助我们更快地理解YOLOv3所要求的数据输出要求,也将有利于对之后训练部分代码进行理解。

1. 标注格式

在上一篇 【从零开始学习YOLOv3】2. YOLOv3中的代码配置和数据集构建 中,使用到了 voc_label.py ,其作用是将xml文件转成txt文件格式,具体文件如下:

# class id, x, y, w, h
00.86041666666666660.54038997214484690.0583333333333333340.055710306406685235

其中的x,y 的意义是归一化以后的框的中心坐标,w,h是归一化后的框的宽和高。

具体的归一化方式为:

def convert(size, box):
   '''
   size是图片的长和宽
   box是xmin,xmax,ymin,ymax坐标值
   '''

   dw = 1. / (size[0])
   dh = 1. / (size[1])
   # 得到长和宽的缩放比
   x = (box[0] + box[1])/2.0  
   y = (box[2] + box[3])/2.0  
   w = box[1] - box[0]
   h = box[3] - box[2]
   # 分别计算中心点坐标,框的宽和高
   x = x * dw
   w = w * dw
   y = y * dh
   h = h * dh
   # 按照图片长和宽进行归一化
   return (x,y,w,h)

可以看出,归一化都是相对于图片的宽和高进行归一化的。

2. 调用

下边是train.py文件中的有关数据的调用:

# Dataset
dataset = LoadImagesAndLabels(train_path, img_size, batch_size,
                             augment=True,
                             hyp=hyp,  # augmentation hyperparameters
                             rect=opt.rect,  # rectangular training
                             cache_labels=True,
                             cache_images=opt.cache_images)

batch_size = min(batch_size, len(dataset))

# 使用多少个线程加载数据集
nw = min([os.cpu_count(), batch_size if batch_size > 1else0, 1])

dataloader = DataLoader(dataset,
                       batch_size=batch_size,
                       num_workers=nw,
                       shuffle=not opt.rect,
                       # Shuffle=True
                       #unless rectangular training is used
                       pin_memory=True,
                       collate_fn=dataset.collate_fn)

在pytorch中,数据集加载主要是重构datasets类,然后再使用dataloader中加载dataset,就构建好了数据部分。

下面是一个简单的使用模板:

import os
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# 根据自己的数据集格式进行重构
class MyDataset(Dataset):
   def __init__(self):
       #下载数据、初始化数据,都可以在这里完成
       xy = np.loadtxt('label.txt', delimiter=',', dtype=np.float32)
       # 使用numpy读取数据
       self.x_data = torch.from_numpy(xy[:, 0:-1])
       self.y_data = torch.from_numpy(xy[:, [-1]])
       self.len = xy.shape[0]
   
   def __getitem__(self, index):
       # dataloader中使用该方法,通过index进行访问
       return self.x_data[index], self.y_data[index]

   def __len__(self):
       # 查询数据集中数量,可以通过len(mydataset)得到
       return self.len

# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。
myDataset = MyDataset()

# 构建dataloader
train_loader = DataLoader(dataset=myDataset,
                         batch_size=32,
                         shuffle=True)

for epoch in range(2):
   for i, data in enumerate(train_loader2):
       # 将数据从 train_loader 中读出来,一次读取的样本数是32个
       inputs, labels = data
       # 将这些数据转换成Variable类型
       inputs, labels = Variable(inputs), Variable(labels)
# 模型训练...

通过以上模板就能大致了解pytorch中的数据加载机制,下面开始介绍YOLOv3中的数据加载。

3. YOLOv3中的数据加载

下面解析的是LoadImagesAndLabels类中的几个主要的函数:

3.1 init函数

init函数中包含了大部分需要处理的数据

class LoadImagesAndLabels(Dataset):  # for training/testing
   def __init__(self,
                path,
                img_size=416,
                batch_size=16,
                augment=False,
                hyp=None,
                rect=False,
                image_weights=False,
                cache_labels=False,
                cache_images=False)
:

       path = str(Path(path))  # os-agnostic
       assert os.path.isfile(path), 'File not found %s. See %s' % (path,
                                                                   help_url)
       with open(path, 'r') as f:
           self.img_files = [
               x.replace('/', os.sep)
               for x in f.read().splitlines()  # os-agnostic
               if os.path.splitext(x)[-1].lower() in img_formats
           ]
       # img_files是一个list,保存的是图片的路径

       n = len(self.img_files)
       assert n > 0, 'No images found in %s. See %s' % (path, help_url)
       bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index
       # 如果n=10, batch=2, bi=[0,0,1,1,2,2,3,3,4,4]
       nb = bi[-1] + 1  # 最多有多少个batch

       self.n = n
       self.batch = bi  # 图片的batch索引,代表第几个batch的图片
       self.img_size = img_size
       self.augment = augment
       self.hyp = hyp
       self.image_weights = image_weights # 是否选择根据权重进行采样
       self.rect = Falseif image_weights else rect
       # 如果选择根据权重进行采样,将无法使用矩形训练:
       # 具体内容见下文

       # 标签文件是通过images替换为labels, .jpg替换为.txt得到的。
       self.label_files = [
           x.replace('images',
                     'labels').replace(os.path.splitext(x)[-1], '.txt')
           for x in self.img_files
       ]

       # 矩形训练具体内容见下文解析
       if self.rect:
           # 获取图片的长和宽 (wh)
           sp = path.replace('.txt', '.shapes')
           # 字符串替换
           # shapefile path
           try:
               with open(sp, 'r') as f:  # 读取shape文件
                   s = [x.split() for x in f.read().splitlines()]
                   assert len(s) == n, 'Shapefile out of sync'
           except:
               s = [
                   exif_size(Image.open(f))
                   for f in tqdm(self.img_files, desc='Reading image shapes')
               ]
               np.savetxt(sp, s, fmt='%g')  # overwrites existing (if any)

           # 根据长宽比进行排序
           s = np.array(s, dtype=np.float64)
           ar = s[:, 1] / s[:, 0]  # aspect ratio
           i = ar.argsort()

           # 根据顺序重排顺序
           self.img_files = [self.img_files[i] for i in i]
           self.label_files = [self.label_files[i] for i in i]
           self.shapes = s[i]  # wh
           ar = ar[i]

           # 设置训练的图片形状
           shapes = [[1, 1]] * nb
           for i in range(nb):
               ari = ar[bi == i]
               mini, maxi = ari.min(), ari.max()
               if maxi < 1:
                   shapes[i] = [maxi, 1]
               elif mini > 1:
                   shapes[i] = [1, 1 / mini]

           self.batch_shapes = np.ceil(
               np.array(shapes) * img_size / 32.).astype(np.int) * 32

       # 预载标签
       # weighted CE 训练时需要这个步骤
       # 否则无法按照权重进行采样
       self.imgs = [None] * n
       self.labels = [None] * n
       if cache_labels or image_weights:  # cache labels for faster training
           self.labels = [np.zeros((0, 5))] * n
           extract_bounding_boxes = False
           create_datasubset = False
           pbar = tqdm(self.label_files, desc='Caching labels')
           nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
           for i, file in enumerate(pbar):
               try:
                   # 读取每个文件内容
                   with open(file, 'r') as f:
                       l = np.array(
                           [x.split() for x in f.read().splitlines()],
                           dtype=np.float32)
               except:
                   nm += 1  # print('missing labels for image %s' % self.img_files[i])  # file missing
                   continue

               if l.shape[0]:
                   # 判断文件内容是否符合要求
                   # 所有的值需要>0, <1, 一共5列
                   assert l.shape[1] == 5, '> 5 label columns: %s' % file
                   assert (l >= 0).all(), 'negative labels: %s' % file
                   assert (l[:, 1:] <= 1).all(
                   ), 'non-normalized or out of bounds coordinate labels: %s' % file
                   if np.unique(
                           l, axis=0).shape[0] < l.shape[0]:  # duplicate rows
                       nd += 1  # print('WARNING: duplicate rows in %s' % self.label_files[i])  # duplicate rows

                   self.labels[i] = l
                   nf += 1  # file found

                   # 创建一个小型的数据集进行试验
                   if create_datasubset and ns < 1E4:
                       if ns == 0:
                           create_folder(path='./datasubset')
                           os.makedirs('./datasubset/images')
                       exclude_classes = 43
                       if exclude_classes notin l[:, 0]:
                           ns += 1
                           # shutil.copy(src=self.img_files[i], dst='./datasubset/images/')  # copy image
                           with open('./datasubset/images.txt', 'a') as f:
                               f.write(self.img_files[i] + '\n')

                   # 为两阶段分类器提取目标检测的检测框
                   # 默认开关是关掉的,不是很理解
                   if extract_bounding_boxes:
                       p = Path(self.img_files[i])
                       img = cv2.imread(str(p))
                       h, w = img.shape[:2]
                       for j, x in enumerate(l):
                           f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent,
                                                             os.sep, os.sep,
                                                             x[0], j, p.name)
                           ifnot os.path.exists(Path(f).parent):
                               os.makedirs(Path(f).parent)
                               # make new output folder

                           b = x[1:] * np.array([w, h, w, h])  # box
                           b[2:] = b[2:].max()  # rectangle to square
                           b[2:] = b[2:] * 1.3 + 30  # pad

                           b = xywh2xyxy(b.reshape(-1,4)).ravel().astype(np.int)

                           b[[0,2]] = np.clip(b[[0, 2]], 0,w)  # clip boxes outside of image
                           b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                           assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes'
               else:
                   ne += 1

               pbar.desc = 'Caching labels (%g found, %g missing, %g empty, %g duplicate, for %g images)'
               % (nf, nm, ne, nd, n) # 统计发现,丢失,空,重复标签的数量。






请到「今天看啥」查看全文