点击上方
“
小白学视觉
”,选择加"
星标
"或“
置顶
”
重磅干货,第一时间送达
![](http://mmbiz.qpic.cn/sz_mmbiz_png/4AqSEnNUeribVWrGvYTZ11C8MiahOe1YRauVzKqzta2cOaUDltlRriaaSq7kR8S7pXibqqAn54knQ5HMDDORgVEBTA/0?wx_fmt=png&from=appmsg)
DINO模型输出的狗冲刺
无标签自蒸馏(DINO)
《
从几个“补丁”中重建完整图像 | 构建可扩展学习器的掩模自编码器
》这边文章讲了如何构建可扩展学习器,这是我对视觉变换器系列的继续,其中我解释了最重要的架构及其从零开始的实现。
自监督学习
自监督学习(SSL)是一种机器学习类型,模型通过无需手动标记的示例来学习理解数据。相反,它从数据本身生成其监督信号。当标记数据有限且获取成本高昂时,这种方法非常有益。
在SSL中,学习过程涉及创建任务,其中输入数据可以用来预测数据本身的某些部分。
常见的技术包括:
DINO模型
DINO(无标签蒸馏)模型是一种应用于视觉变换器(ViTs)的尖端自监督学习方法。它代表了计算机视觉领域的一个重大进步,使模型能够在不需要任何标记数据的情况下学习有效的图像表示。由Facebook AI Research(FAIR)的研究人员开发,DINO利用学生-教师框架和创新的训练技术,在各种视觉任务上取得了卓越的性能。
学生-教师网络
在DINO模型中,学生-教师网络是实现无需标记数据的自监督学习的核心机制。这个框架涉及两个网络:学生网络和教师网络。两个网络都是视觉变换器,它们被设计用来通过将图像处理为序列块来处理图像,类似于变换器处理文本序列的方式。
学生网络的任务是从输入图像中学习生成有意义的表示。另一方面,教师网络提供目标表示,学生网络旨在匹配这些表示。教师网络不是一个静态实体;它通过逐渐整合学生网络的参数随时间演变。这是通过一种称为指数移动平均的技术完成的,其中教师的参数被更新为其当前参数和学生参数的加权平均值。
目标是最小化学生表示和教师表示之间的差异,这些表示是针对相同增强图像视图的。这通常是通过使用一个损失函数来实现的,该函数鼓励学生和教师输出之间的对齐,同时确保不同图像的表示保持不同。
通过根据学生网络的学习进度不断更新教师网络,并训练学生网络以匹配教师的输出,DINO有效地利用了两个网络的优势。教师网络为学生提供了稳定和一致的目标,而学生网络推动了学习过程。这种协作设置允许模型在无需手动标签的情况下从数据中学习强大和不变的特征,从而实现有效的自监督学习。
学生和教师的增强输入
在DINO模型中,X1和X2(见上图)指的是同一原始图像X的不同增强视图。这些视图分别用作学生和教师网络的输入。目标是让学生网络学习在这些增强下产生一致的表示。
学生和教师模型根据以下策略接收不同的增强:
我们将如何为参数图像定义这些增强,这些图像包含我们在训练期间想要转换的一批图像。
def global_augment(images):
global_transform = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
return torch.stack([global_transform(img) for img in images])
def multiple_local_augments(images, num_crops=6):
size = 96
local_transform = transforms.Compose([
transforms.RandomResizedCrop(size, scale=(0.05, 0.4)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
return torch.stack([local_transform(img) for img in images])
蒸馏损失:
在这里,我们希望使用某种距离度量来计算学生输出和教师输出之间的损失。我们这样做:
-
获取教师预测输出的中心化Softmax,然后应用锐化。
-
获取学生的Softmax预测,然后应用锐化。
def
distillation_loss(student_output, teacher_output, center, tau_s, tau_t):
"""
Calculates distillation loss with centering and sharpening (function H in pseudocode).
"""
teacher_output = teacher_output.detach()
teacher_probs = F.softmax((teacher_output - center) / tau_t, dim=1)
student_probs = F.log_softmax(student_output / tau_s, dim=1)
loss = - (teacher_probs * student_probs).sum(dim=1).mean()
return loss
中心化:
中心化教师的输出确保学生模型更多地关注教师输出分布中最显著的特征或区别。通过中心化分布,鼓励学生更多地关注对准确预测至关重要的显著特征,而不是受数据中的变化或偏差的影响。这有助于更有效的知识传递,并可能导致学生模型的性能提高。
锐化:
锐化涉及放大数据分布中的特定特征,旨在强调教师模型突出的区分。这个过程使学生模型能够专注于学习教师预测中存在的复杂细节,这对于在数据集上准确复制其输出至关重要。
训练DINO模型:
阐明DINO伪代码的图像,取自官方论文
有3个重要的步骤需要强调:
1. 获取学生和教师架构的不同输入(x1,x2)的增强。
2. 我们之前讨论的蒸馏损失函数,注意它是如何计算不同增强输入的架构的蒸馏损失的,即gs({x1, x2})和gt({x1, x2})。
3. 更新(a)学生参数(b)教师参数和(c)中心。这里的关键是我们对更新教师参数执行指数移动平均更新。
DINO模型
class DINO(nn.Module):
def __init__(self, student_arch: Callable, teacher_arch: Callable, device: torch.device):
"""
Args:
student_arch (nn.Module): ViT Network for student_arch
teacher_arch (nn.Module): ViT Network for teacher_arch
device: torch.device ('cuda' or 'cpu')
"""
super(DINO, self).__init__()
self.student = student_arch().to(device)
self.teacher = teacher_arch().to(device)
self.teacher.load_state_dict(self.student.state_dict())
self.register_buffer('center', torch.zeros(1, student_arch().output_dim))
for param in self.teacher.parameters():
param.requires_grad = False
@staticmethod
def distillation_loss(student_output, teacher_output, center, tau_s, tau_t):
"""
Calculates distillation loss with centering and sharpening (function H in pseudocode).
"""
teacher_output = teacher_output.detach()
teacher_probs = F.softmax((teacher_output - center) / tau_t, dim=1)
student_probs = F.log_softmax(student_output / tau_s, dim=1)
loss = - (teacher_probs * student_probs).sum(dim=1).mean()
return loss
def teacher_update(self, beta: float):
for teacher_params, student_params in zip(self.teacher.parameters(), self.student.parameters()):
teacher_params.data.mul_(beta).add_(student_params.data, alpha=(1 - beta))
为了更新教师的参数,我们使用论文中提出公式,即gt.param = gt.param*beta + gs.param*(1 — beta),其中beta是移动平均衰减,gt、gs分别是相应的教师和学生架构。
进一步,我们在__init__下看到,教师的参数已设置为“required_grads = False”,因为我们不希望在反向传播期间更新它们,而是应用移动平均更新。
此外,在PyTorch中将变量初始化为bugger是一种常见方法,用于将其保持在
梯度图之外,并不参与反向传播。
Dino模型进一步需要如下调用
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dino = DINO(ViT(), ViT(), device)
在这里,我们传递学生和教师架构,这不过是标准的视觉变换器,即ViT-B/16或ViT-L/16,正如第一篇论文中提出的。
最终训练
现在可以将整个实现放入训练循环中,正如论文中提出的。
def train_dino(dino: DINO,
data_loader: DataLoader,
optimizer: Optimizer,
device: torch.device,
num_epochs,
tps=0.9,
tpt= 0.04,
beta= 0.9,
m= 0.9,
):
"""
Args:
dino: DINO Module
data_loader (nn.Module): Dataloader for training
optimizer (nn.optimizer): Optimizer for optimization (SGD etc.)
defice (torch.device): 'cuda', 'cpu'
num_epochs: Number of Epochs
tps (float): tau for sharpening student logits
tpt: for sharpening teacher logits
beta (float): moving average decay
m (float): center moveing average decay
"""
for epoch in range(num_epochs):
print(f"Epoch: {epoch+1}/{len(num_epochs)}")
for x in data_loader:
x1, x2 = global_augment(x), multiple_local_augments(x)
student_output1, student_output2 = dino.student(x1.to(device)), dino.student(x2.to(device))
with torch.no_grad():
teacher_output1, teacher_output2 = dino.teacher(x1.to(device)), dino.teacher(x2.to(device))
loss = (dino.distillation_loss(teacher_output1, student_output2, dino.center, tps, tpt) +
dino.distillation_loss(teacher_output2, student_output1, dino.center, tps, tpt)) / 2
optimizer.zero_grad()
loss.backward()
optimizer.step()
dino.teacher_update(beta)
with torch.no_grad():
dino.center = m * dino.center + (1 - m) * torch.cat([teacher_output1, teacher_output2], dim=0).mean(dim=0)
-
我们用不同的全局和局部增强计算x1和x2。
-
之后,我们根据论文中提出的,为学生和教师模型获取输出,回想上面的算法循环图。
-
在这里,我们将torch设置为no_grad()函数,以确保教师的参数不会通过反向传播更新。
-
最后,我们再次根据论文中提出的方法计算蒸馏损失。
-
在蒸馏损失中,我们首先中心化教师模型的输出,这样学生模型就不容易崩溃,也不会只学习不重要的特征,或者比另一个特征更多地学习一个特征,而是专注于从教师模型中学习最独特和潜在的特征。
-
然后我们锐化特征,以便在计算损失时,我们现在能够比较两个特征(学生和教师的)具有非常不同的数据分布,这意味着锐化后,更重要的特征会被锐化,而不太重要的特征则不会,这将创建一个更独特的特征图,使学生更容易学习。
-
然后我们执行反向传播并执行optimizer.step(),更新学生模型并通过之前实现的指数移动平均更新教师网络。
-
作为最后一步,我们将再次将torch设置为no_grad()并通过移动平均更新中心。我们根据教师的输出更新中心,因此它与训练过程中输出数据分布的变化保持一致。