专栏名称: 机器之心
专业的人工智能媒体和产业服务平台
目录
相关文章推荐
机器之心  ·  创造历史!DeepSeek超越ChatGPT ... ·  10 小时前  
黄建同学  ·  OpenAI的CUA和Antropic的MC ... ·  2 天前  
新智元  ·  颠覆LLM格局!AI2新模型OLMo2,训练 ... ·  3 天前  
爱可可-爱生活  ·  【[46星]Humanity's Last ... ·  3 天前  
机器学习研究组订阅  ·  ICLR ... ·  4 天前  
51好读  ›  专栏  ›  机器之心

资源 | 注意迁移的PyTorch实现

机器之心  · 公众号  · AI  · 2017-02-06 12:10

正文

选自Github

作者:szagoruyko

机器之心编译

参与:赵华龙、吴攀


本项目是论文《要更加注重注意力:通过注意迁移技术提升卷积神经网络的性能(Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer)》PyTorch 实现。点击文末「阅读原文」可查阅原论文。


项目地址:https://github.com/szagoruyko/attention-transfer


这篇论文已经提交给了 ICLR 2017 会议,正在 review 状态:https://openreview.net/forum?id=Sks9_ajex



到目前为止该代码库里的内容包括:


  • CIFAR-10 实验的基于激活技术的 AT 代码

  • ImageNet 实验的代码(ResNet-18-ResNet-34 student-teacher)


即将上线:


  • 基于梯度的 AT

  • 场景和基于 CUB 激活的 AT 代码

  • 预训练的基于激活的 AT ResNet-18


代码使用 PyTorch。原始的实验是用 torch-autograd 做的,我们目前已经验证了 CIFAR-10 实验结果能够完全在 PyTorch 中复现,而且目前正在针对 ImageNet 做类似的工作(由于超参数的原因,PyTorch 的结果有一点点变差)


引用:


@article{Zagoruyko2016AT,    author = {Sergey Zagoruyko and Nikos Komodakis},    title = {Paying More Attention to Attention: Improving the Performance of             Convolutional Neural Networks via Attention Transfer},    url = {https://arxiv.org/abs/1612.03928},    year = {2016}}


要求


先安装 PyTorch,再安装 torchnet:


git clone https://github.com/pytorch/tnt cd tnt python setup.py install


安装 OpenCV 以及 Python 支持包,以及带有 OpenCV 变换的 torchvision:


git clone https://github.com/szagoruyko/vision cd vision; git checkout opencv python setup.py install


最后,安装其他的 Python 包:


pip install -r requirements.txt


实验


CIFAR-10


这一节讲述如何得到本文中第一个表里的那些结果。

首先,训练老师:


python cifar.py --save logs/resnet_40_1_teacher --depth 40 --width 1 python cifar.py --save logs/resnet_16_2_teacher --depth 16 --width 2 python cifar.py --save logs/resnet_40_2_teacher --depth 40 --width 2


用基于激活的 AT 来训练:


python cifar.py --save logs/at_16_1_16_2 --teacher_id resnet_16_2_teacher --beta 1e+3


用 KD 来训练:


python cifar.py --save logs/kd_16_1_16_2 --teacher_id resnet_16_2_teacher --alpha 0.9


我们下一步计划增加带有 beta 衰退的 AT+KD 来得到最优的知识转换结果。


ImageNet


预训练模型


我们提供带有基于激活 AT 的 ResNet-18 预训练模型:



从头开始训练


下载 ResNet-34 的预训练权值(functional-zoo 里有更多介绍):



wget https://s3.amazonaws.com/pytorch/h5models/resnet-34-export.hkl


根据 fb.resnet.torch 准备数据,然后进行训练(比如使用 2 个 GPU):


python imagenet.py --imagenetpath ~/ILSVRC2012 --depth 18 --width 1 \                   --teacher_params resnet-34-export.hkl --gpu_id 0,1 --ngpu 2 \                   --beta 1e+3




©本文为机器之心编译,转载请联系本公众号获得授权

✄------------------------------------------------

加入机器之心(全职记者/实习生):[email protected]

投稿或寻求报道:[email protected]

广告&商务合作:[email protected]