专栏名称: 小白学视觉
本公众号主要介绍机器视觉基础知识和新闻,以及在学习机器视觉时遇到的各种纠结和坑的心路历程。
目录
相关文章推荐
军武次位面  ·  值!一眼爱上雾面皮衣,只要118元显高级! ·  2 天前  
中国兵器工业集团  ·  北方公司北方国际蒙古选煤厂项目荣获“蒙古最优 ... ·  2 天前  
51好读  ›  专栏  ›  小白学视觉

再也不用担心过拟合的问题了

小白学视觉  · 公众号  ·  · 2025-02-06 18:05

正文

点击上方 小白学视觉 ”,选择加" 星标 "或“ 置顶

重磅干货,第一时间送达

作者 | Sean Benhur J

编译 | ronghuaiyang
转自 | AI公园

使用SAM(锐度感知最小化),优化到损失的最平坦的最小值的地方,增强泛化能力。

论文:https://arxiv.org/pdf/2010.01412.pdf

代码:https://github.com/moskomule/sam.pytorch

动机来自先前的工作,在此基础上,我们提出了一种新的、有效的方法来同时减小损失值和损失的锐度。具体来说,在我们的处理过程中,进行锐度感知最小化(SAM),在领域内寻找具有均匀的低损失值的参数。这个公式产生了一个最小-最大优化问题,在这个问题上梯度下降可以有效地执行。我们提出的实证结果表明,SAM在各种基准数据集上都改善了的模型泛化。

在深度学习中,我们使用SGD/Adam等优化算法在我们的模型中实现收敛,从而找到全局最小值,即训练数据集中损失较低的点。但等几种研究表明,许多网络可以很容易地记住训练数据并有能力随时overfit,为了防止这个问题,增强泛化能力,谷歌研究人员发表了一篇新论文叫做Sharpness Awareness Minimization,在CIFAR10上以及其他的数据集上达到了最先进的结果。

在本文中,我们将看看为什么SAM可以实现更好的泛化,以及我们如何在Pytorch中实现SAM。

SAM的原理是什么?

在梯度下降或任何其他优化算法中,我们的目标是找到一个具有低损失值的参数。但是,与其他常规的优化方法相比,SAM实现了更好的泛化,它将重点放在领域内寻找具有均匀的低损失值的参数(而不是只有参数本身具有低损失值)上。

由于计算邻域参数而不是计算单个参数,损失超平面比其他优化方法更平坦,这反过来增强了模型的泛化。

(左))用SGD训练的ResNet收敛到的一个尖锐的最小值。(右)用SAM训练的相同的ResNet收敛到的一个平坦的最小值。

注意:SAM不是一个新的优化器,它与其他常见的优化器一起使用,比如SGD/Adam。

在Pytorch中实现SAM

在Pytorch中实现SAM非常简单和直接

import torch

class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is Nonecontinue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is Nonecontinue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()


    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        p.grad.norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if






请到「今天看啥」查看全文