专栏名称: 计算机视觉工坊
专注于计算机视觉、VSLAM、目标检测、语义分割、自动驾驶、深度学习、AI芯片、产品落地等技术干货及前沿paper分享。这是一个由多个大厂算法研究人员和知名高校博士创立的平台,我们坚持工坊精神,做最有价值的事~
目录
相关文章推荐
白鲸出海  ·  Canva将“印度式低价策略”推向更多市场, ... ·  7 小时前  
阿里开发者  ·  如何在IDE里使用DeepSeek-V3 ... ·  昨天  
白鲸出海  ·  2024年AI投资Top5机构出炉,又一AI ... ·  2 天前  
白鲸出海  ·  开年刚一周,GTC2025(Shenzhen ... ·  2 天前  
百度智能云  ·  首日1.5万后,百度智能云千帆助力DeepS ... ·  3 天前  
51好读  ›  专栏  ›  计算机视觉工坊

仅7M参数!Mamba赋能强化学习!笔记本都能训练!

计算机视觉工坊  · 公众号  ·  · 2024-10-17 07:00

正文

点击下方 卡片 ,关注 「3D视觉工坊」 公众号
选择 星标 ,干货第一时间送达

来源:计算机视觉工坊

添加小助理:cv3d008,备注:方向+学校/公司+昵称,拉你入群。文末附3D视觉行业细分群。

扫描下方二维码,加入「 3D视觉从入门到精通 」知识星球 ,星球内凝聚了众多3D视觉实战问题,以及各个模块的学习资料: 近20门秘制视频课程 最新顶会论文 、计算机视觉书籍 优质3D视觉算法源码 等。想要入门3D视觉、做项目、搞科研,欢迎扫码加入!

仅7M参数!Mamba赋能强化学习!笔记本都能训练!

0. 论文信息

标题:Drama: Mamba-Enabled Model-Based Reinforcement Learning Is Sample and Parameter Efficient

作者:Wenlong Wang, Ivana Dusparic, Yucheng Shi, Ke Zhang, Vinny Cahill

机构:Trinity College Dublin

原文链接:https://arxiv.org/abs/2410.08893

代码链接:https://github.com/realwenlongwang/drama

1. 导读

基于模型的强化学习为困扰大多数无模型强化学习算法的数据低效问题提供了解决方案。然而,学习一个健壮的世界模型通常需要复杂和深入的架构,这对于计算和训练来说是昂贵的。在世界模型中,动力学模型对于准确预测尤为重要,已经探索了各种动力学模型架构,每种架构都有其自身的挑战。目前,基于递归神经网络(RNN)的世界模型面临着诸如消失梯度和难以有效捕捉长期依赖性的问题。相比之下,使用变压器会遇到众所周知的自我关注机制问题,其中内存和计算复杂性都随着O(n2),与n表示序列长度。
为了解决这些挑战,我们提出了基于状态空间模型(SSM)的世界模型,特别是基于Mamba的世界模型,其实现了O(n)同时有效地捕捉长期依赖性并促进更长训练序列的有效使用。我们还引入了一种新的采样方法,以减轻在训练的早期阶段由不正确的世界模型引起的次优性,将其与前述技术相结合,以实现与仅使用700万可训练参数世界模型的其他基于模型的RL算法相当的归一化分数。这种模式易于使用,可以在现成的笔记本电脑上进行培训。

2. 引言

深度强化学习(RL)在诸如围棋、Dota、Atari游戏和MuJoCo等一系列具有挑战性的应用中取得了显著成就。然而,训练能够解决复杂任务的策略往往需要数百万次的交互,这在实践中可能并不可行,并成为了现实应用的一大障碍。因此,提高样本效率已成为强化学习算法开发中最关键的目标之一。

世界模型通过一种自动生成过程来产生用于训练强化学习智能体的人工样本,从而在提高样本效率方面展现出巨大潜力,这种方法被称为基于模型的强化学习。在这种方法中,利用序列模型通过交互数据学习环境动态,使智能体能够在由所得动态模型生成的人工经验上进行训练,而不是依赖真实世界的交互。这种方法将问题从直接使用真实样本改进策略(这种方法样本效率低下)转变为提高世界模型的准确性以匹配真实环境(这种方法样本效率更高)。然而,基于模型的强化学习面临着一个众所周知的挑战:当模型因观察到的样本有限而不准确时,尤其是在训练初期,学习到的策略可能会偏向于次优行为,并且模型错误的检测既困难又几乎不可能实现。

在序列建模中,线性复杂度是非常理想的,因为它允许模型在不大幅增加计算和内存资源的情况下高效地处理更长的序列。这对于训练世界模型尤为重要,因为世界模型需要高效的序列建模来模拟长时间范围内的复杂环境。循环神经网络(RNNs),特别是像长短期记忆(LSTM)和门控循环单元(GRU)这样的高级变体,具有线性复杂度,使其在计算上对于这一任务更具吸引力。然而,RNNs仍然难以解决梯度消失问题,并且在捕捉长期依赖关系方面效率低下。最近,在自然语言处理领域占据主导地位的Transformer架构,在图像处理、离线强化学习等领域的开创性工作之后,迅速在这些领域获得了广泛认可。Transformer结构在基于模型的强化学习中也证明了其有效性。然而,Transformer存在内存和计算复杂度都随序列长度n呈O(n²)增长的问题,这对于需要长序列来模拟复杂环境的世界模型来说是一个挑战。 推荐课程: 国内首个面向工业级实战的点云处理课程

目前,状态空间模型(SSMs)因其能够以线性复杂度高效处理长序列问题而备受关注。在SSMs中,Mamba在各种领域(包括自然语言处理、计算机视觉和离线强化学习)中已成为基于Transformer架构的有力竞争对手。将Mamba架构应用于基于模型的强化学习尤其具有吸引力,因为它具有随序列长度线性扩展的内存和计算规模,同时能够有效捕捉长期依赖关系。此外,高效捕捉环境动态可以减少在不准确的世界模型中学习到行为策略的可能性,我们也通过引入一种新颖的动态频率采样方法来解决这一问题。

3. 主要贡献

在本文中,我们做出了三项主要贡献:

• 我们介绍了DRAMA,这是第一个基于Mamba SSM构建的基于模型的强化学习智能体,以Mamba-2作为其架构的核心。我们在Atari100k基准上评估了DRAMA,证明其性能与其他最先进的算法相当,而使用的世界模型仅包含700万个可训练参数。

• 此外,我们比较了Mamba-1和Mamba-2的性能,证明了在Atari100k基准中,尽管Mamba-2为了提升训练效率而略微限制了表达能力,但作为动态模型仍取得了更优的结果。

• 最后,我们提出了一种新颖且直接的采样方法,即基于动态频率的采样(DFS),以缓解不完善动态模型带来的挑战。

4. 方法

Drama世界模型架构。从序列索引i开始,原始游戏帧被编码为zi,并与动作ai结合,作为输入传递给Mamba模块。输入通道维度被头维度p除,以生成确定性循环状态di。该循环状态di用于预测下一个嵌入ˆzi+1、奖励ˆri和终止标志ˆti,这些代表了基于当前帧和动作的结果。解码器是从编码后的嵌入zi而不是从预测的嵌入ˆzi中重构原始帧。Mamba-2模块采用了半可分离矩阵结构,该结构可以分解为q×q子矩阵,从而实现更高效的计算和处理。

5. 实验结果

6. 总结 & 未来工作

总之,我们提出的基于Mamba的世界模型DRAMA解决了基于RNN和Transformer的世界模型在基于模型的强化学习(RL)中所面临的关键挑战。通过实现O(n)的内存和计算复杂度,我们的方法可以使用更长的训练序列。此外,我们提出的新采样方法有效缓解了训练早期阶段的次优性,从而构建出一个既轻量级(仅700万个可训练参数的世界模型)又易于训练(可在标准硬件上训练)的模型。总体而言,我们的方法实现了与其他最先进RL算法相当的标准化分数,为基于模型的RL系统提供了一个实用且高效的替代方案。尽管Drama能够支持更长的训练和推理序列,但它并未表现出在Atari100k基准测试中相对于其他世界模型具有决定性优势。未来研究的一个有趣方向是探索在哪些特定任务中,更长的序列能够在基于模型的RL中带来卓越性能。尽管世界模型取得了进展,但基于模型的RL仍然面临多个挑战,如长期行为规划和学习、有信息的探索以及世界模型和行为策略联合训练的动态性。另一个有前景的未来研究方向是探究Mamba能在多大程度上帮助解决这些挑战。

对更多实验结果和文章细节感兴趣的读者,可以阅读一下论文原文~

本文仅做学术分享,如有侵权,请联系删文。

3D视觉交流群,成立啦!

目前我们已经建立了3D视觉方向多个社群,包括 2D计算机视觉 最前沿 工业3D视觉 SLAM 自动驾驶 三维重建 无人机 等方向,细分群包括:

工业3D视觉 :相机标定、立体匹配、三维点云、结构光、机械臂抓取、缺陷检测、6D位姿估计、相位偏折术、Halcon、摄影测量、阵列相机、光度立体视觉等。

SLAM :视觉SLAM、激光SLAM、语义SLAM、滤波算法、多传感器融合、多传感器标定、动态SLAM、MOT SLAM、NeRF SLAM、机器人导航等。

自动驾驶:深度估计、Transformer、毫米波|激光雷达|视觉摄像头传感器、多传感器标定、多传感器融合、自动驾驶综合群等、3D目标检测、路径规划、轨迹预测、3D点云分割、模型部署、车道线检测、Occupancy、目标跟踪等。

三维重建 :3DGS、NeRF、多视图几何、OpenMVS、MVSNet、colmap、纹理贴图等

无人机 :四旋翼建模、无人机飞控等

2D计算机视觉 :图像分类/分割、目标/检测、医学影像、GAN、OCR、2D缺陷检测、遥感测绘、超分辨率、人脸检测、行为识别、模型量化剪枝、迁移学习、人体姿态估计等

最前沿 :具身智能、大模型、Mamba、扩散模型等

除了这些,还有 求职 硬件选型 视觉产品落地、产品、行业新闻 等交流群

添加小助理: dddvision,备注: 研究方向+学校/公司+昵称 (如 3D点云+清华+小草莓 ), 拉你入群。

▲长按扫码添加助理:cv3d008

3D视觉知识星球

3D视觉从入门到精通 」知识星球,已沉淀6年,星球内资料包括: 秘制视频课程近20门 (包括







请到「今天看啥」查看全文