23年5月来自KAIST、伯克利分校、谷歌和加拿大多伦多大学的论文“Masked World Models for Visual Control”。
基于视觉模型的强化学习 (RL) 有可能实现机器人从视觉观察中进行样本高效的学习。然而,当前的方法通常端到端训练单个模型来学习视觉表征和动态,这样难以准确地模拟机器人与小目标之间的交互。这项工作引入一个基于视觉模型的 RL 框架,将视觉表征学习和动态学习解耦,即MWM。具体来说,训练一个具有卷积层和视觉Transformer (ViT) 的自动编码器,根据掩码的卷积特征重建像素,并学习一个对来自自动编码器的表征进行操作的潜动态模型。此外,为了对与任务相关的信息进行编码,为自动编码器引入一个辅助奖励预测目标。用从环境交互中收集的在线样本不断更新自动编码器和动态模型。该解耦方法在 Meta-world 和 RLBench 的各种视觉机器人任务中实现了最先进的性能。
Dreamer [15, 21] 是一种基于视觉模型的强化学习方法,它从像素中学习世界模型,并通过潜在想象训练A-C模型。
掩码自动编码器 (MAE) [13] 是一种自监督的视觉表示技术,它训练自动编码器使用由像素组成的随机掩码块重建原始像素。
掩码世界模型 (MWM),这是一个基于视觉模型的 RL 框架,用于通过分别学习视觉表征和环境动态来学习准确的世界模型。该方法重复 (i) 使用卷积特征掩码和辅助奖励预测任务更新自动编码器,(ii) 在自动编码器的潜空间中学习动态模型,以及 (iii) 从环境交互中收集样本。
如图所示,MWM 用从环境交互中收集的在线样本不断更新视觉表征和动态,通过(左)一个具有卷积特征掩码和奖励预测的自动编码器和(右)一个在自动编码器潜空间中潜动态模型的重复训练迭代过程。自动编码器参数在动态学习期间会不会更新。
使用 ViT 架构 [13、34、36] 进行掩码图像建模可以实现计算效率高且稳定的自监督视觉表征学习。
这促使采用这种方法进行基于视觉模型的强化学习,但用常用的像素patch掩码 [13] 进行掩码图像建模,通常很难学习patch内的细粒度细节,例如小目标。
虽然可以考虑小尺寸patch,但由于自注意层的二次复杂度,这会增加计算成本。
为了解决这个问题,训练一个自动编码器,根据随机掩码的卷积特征重建原始像素。与之前利用 patchify 主干 和随机掩码像素patch的方法不同,其采用卷积主干[14,40],通过一系列卷积层和一个平坦层来处理 ot 去获得 hct。
为了对重建目标可能无法单独捕获的任务相关信息进行编码,为自动编码器引入一个辅助目标,与像素联合预测奖励。在实践中,将一个额外的可学习掩码 token 连接到 ViT 解码器的输入,并利用相应的输出表示以线性输出头预测奖励。
引入早期卷积层可能会阻碍掩码重建任务,因为它们会在patch之间传播信息 [18],而模型可以利用这一点找到解决重建任务的捷径。然而,高掩码率(即 75%)会阻止模型找到这样的捷径并产生有用的表示。这也与 Touvron [18] 的观察结果一致,其中使用卷积主干 [45] 的掩码图像建模 [34] 可以在 ImageNet 分类任务 [19] 上实现与 MAE 中 patchify 主干相媲美的性能。
一旦学习了视觉表征,就会利用它们在自动编码器的潜空间中有效地学习动态模型。具体来说,从自动编码器中获得冻结表征 ztc,0,然后训练 RSSM 的变体,其输入和重构目标为 ztc,0。
由于视觉表征以抽象形式捕获高级和低级信息,因此模型可以通过重建它们而不是原始像素来更专注于动态学习。在这里,利用 ztc,0 的所有元素,而 MAE 只利用 CLS 表征进行下游任务。这使得模型能够从重建包含空间信息的所有表征中接收丰富的学习信号。
如下是MWM的算法伪代码: