在o1的整体框架篇中(https://zhuanlan.zhihu.com/p/773907223),我们从现有开源的论文和代码中(https://github.com/hijkzzz/Awesome-LLM-Strawberry),抽象出了o1可能的技术实现路径,如下图:
这里对于这张框架图我们不再做赘述,详情可以参见上面《框架篇》的文章链接。
我们之前说过,
这是一张高度抽象的框架图
,旨在说明o1官方技术报告中提到的“把更多算力花在inference阶段上,以提升模型的逻辑推理能力”的含义。而从本文开始,我们将以具体的算法去扩展这张框架图的细节。
今天我们要具体扩展的,就是框架图中的Inference部分(黄色块),
从框架图可知,Inference部分一般有两个作用:
作用1:直接对inference过程进行优化,具体的优化方法例如
:
PRM + some search methods
。其中PRM表示我们额外训练的、用于评估“模型中间步骤”而不是“模型答案结果”的奖励模型。我们在框架篇中给过使用这种优化方法的具体例子,这里不再赘述
MCTS(Monte Carlo Tree Search)
。使用蒙特卡洛树搜索的方法(AlphaGo中采用过),通过self-play的方式来
找到一条最佳的“原始问题->中间步骤->答案”路径。从广义上来说,PRM + some search methods的方法其实也算是一种MCTS-style类型的搜索方法
,只不过在MCTS中,我们通过“探索”步骤去估计结点的reward,而一个训练好的PRM则是直接替代了这种“探索-评估”过程。如果你对这些描述觉得抽象,那也没关系,MCTS是本文讲述的重点,我们马上会在文章中看到它的实现细节。
作用2:用于在post-training过程中筛选高质量的数据进行训练
。
从对目前的一些开源工作的总结中,我们发现,
在提升模型推理能力这一环节有一个核心的原则:尽量少用人工标注,多借助已有模型(base generator)本身的能力,去自动化地生产训练数据。然后再利用这些训练数据
,通过sft或者强化学习等等post-training的方法,去提升模型的推理能力。
为了保证这些自动化生成的训练数据的质量,我们可以引入Inference模块,帮助我们搜索出高质量的数据。
所以,
Inference模块可以看作是o1实现中的一块积木
。当你理解这块积木的目的、以及一些可能的实现方法后。你就可以按需要灵活把它组装在你心目中o1的任何一个环节。在网上关于o1的资料中,我们可能经常会看见“MCTS,self-play”这样的关键词,它其实就是这块黄色积木的一种实现方式。
不过笔者认为,o1走的不是纯靠优化inference的路线(即上图中的framework1),更可能走的是post-training + inference路线(即上图中的framework3,因为o1的技术报告中提过它把算力也花在了RL阶段上)。但是无论如何,了解这块积木的实现总是必要的。
在这篇文章中,我们将以微软在今年开源的
rStar
这个工作为例(https://github.com/zhentingqi/rStar),全
面从源码出发,来详细看下MCTS技术是如何运用在nlp的逻辑推理任务上的
(毕竟我们对MCTS的主要了解都来自AlphaGO,我们肯定非常好奇它要如何运作在自然语言上,特别是这个前提下它的搜索空间是什么)。
阅读本文不需要任何MCTS先验知识,文中会循序渐进地做介绍
。
一、为什么选择rStar
rStar的目的同样是提升模型的逻辑推理能力,
但是它走的是上图中的framework1,也就是纯靠inference的搜索优化来实现目标,同时它选择的是MCTS而非PRM + search methods的方法
。rStar作出这样选择的原因如下:
Base generator是个小模型
(SLM, Small Language Model)。rStar针对的是小模型场景,对于小模型来说,它本身的能力就不强,所以我们不能指望小模型能借助pretrain阶段的能力去生产高质量的训练数据,也即post-training自产自消的方法在小模型上难以走通。同时,在大部分业务场景下,我们可能也没那么多训练资源。
PRM的训练是费钱的
。如果非要用人工标注,那么大概率这个标注会花在PRM的训练上(参考框架篇中对openai的PRM的训练方式介绍)。对于身处贫苦环境中的我们,以及被落地okr催促的老板们,时间和金钱成本是能省则省。
正因为rStar走的是纯Inference的路线,所以更便于我们从”一块积木”的视角来理解框架图中的黄色块。同时,利好小模型的场景也更适合资源有限的我们。
最后,当然是rStar的代码完全开源,
方便我们一探所有的细节,少一些自己的想象(rStar的论文其实写得比较精简,少了很多细节的描述,也一定程度上造成代码不太好读)
。
二、按照人的思考方式构造一棵搜索树
这里我们先不谈MCTS的任何概念,我们只看:对于某个问题,你会采用什么样的思维链来解决它?
假设我们有一个简单的问题:
user_question: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
为了解决它,我们可能有如下思考方式(所有的思考方式都以字母A开头,表示Action)
2.1 A1(propose a one-step-thought)
我们会做过程的拆解,每次提出一个推理step,直到生成最后的答案。我们记这种思考方式为A1。例如:
A1(propose a one-step-thought)### Instruction: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?### Response: Let's think step by step. Step 1: Start with the number of cars that are already in the parking lot, which is 3 cars. Step 2: Add the number of cars that arrive, which is 2 cars. Step 3: Add the numbers together. there are 3 cars + 2 cars = 5 cars in the parking lot. Step 4: The answer is 5.
观察上面的steps,我们会发现:
总是以Let's think step by step.开头
每个step的形式是“该step的推理文字+该step的答案”。例如step1中,在一段推理相关的文字结束后,能提取出“3”这个数字答案
最后一个step以“The answer is”开头,表示产出了原始问题的最终答案。
2.2 A2(propose the remaining thought steps)
对于一些简单的问题,我们可能并不会步步思考。我们会一次性通过一些简单的推理后直接给出答案,例如:
### Instruction: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?### Response: Let's think step by step. There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is: 5.
2.3 A3 (propose next sub-question along with its answer)
有时候,我们会把原始问题拆解成很多子问题,然后回答一个个子问题,最终给出答案,例如:
Question 1: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? Question 1.1: How many cars are there in the park before? Answer 1.1: There are 3 cars in the park before. Question 1.2: How many cars arrive then ? Answer 1.2: 2 more cars arrive. Question 1.3: Now we can answer the question: how many cars are in the parking lot? Answer 1.3: There are 3 + 2 = 5 cars in the parking lot now. The answer is 5.
其中,Question1是原始问题,其余是拆解的子问题。其中,Question 1.3属于终结类型的子问题,因为回答它就等于回答了最终答案。这种拆解子问题的方式更适合用来解决困难问题,我们的例子比较简单,这里只是展现出一个形式。
2.4 A4 (Answer the sub-question again)
这种方式将和A3一起配套使用,例如,对于A3的Question1.1,你可能并不确定Answer1.1是否正确,这时你想重新再思考一次Answer1.1的答案。由于此时你只是对某一个子答案做修正,因此你可能采用A2(propose the remaining thought steps)的方式,做一些简单的推理,重新取得Answer1.1。此时相当于把Answer1.1用A2例子中的输出结果进行替代,这里不再给出具体例子。
2.5 A5(Rephrase the question/sub-question)
有时我们在做题时,通常会在大段的原始题目描述中,把关键信息提取出来,例如:condition1..., condition2等等。我们可以先通过这种方式改写原始题目/子题目,然后再做回答。这个比较好理解,同样也不再给出具体的示例。
2.6 整合:构造一颗搜索树
总结一下,目前为止,我们按照人类的思维方式,总结出了人类解决一个问题时可能采用的5种方法:
A1(propose a one-step-thought)
:步步推理,每一步都有一些中间答案,然后在最后一步中得到最终答案
A2(propose the remaining thought steps)
:一次性推理完毕,直接得出最终答案
A3 (propose next sub-question along with its answer)
:
将原始问题拆解成若干子问题并做相关回答。最后一个子问题的答案即是最终答案(和A1有些类似,但采取的是subquestion-subanswer这种指示方式)
A4 (Answer the sub-question again)
:有时A3中某个子问题的回答不一定可信,我们尝试重新回答它。这时我们会采用A2的模版,重新回答这个子问题
A5(Rephrase the question/sub-question)
:重新复述一个原始问题/子问题。例如去掉大段文字表述信息,只把关键部分提取成condition1..., condition2之类的形式,用这个形式当作新的问题。
在代码操作中,我们会按2.1~2.5的示例,构造相应的prompt来指示模型执行不同的动作。下图给出了A1的prompt示例,更多例子大家可以参见源码中rStar/prompts部分:
当人解决问题时,可能会根据问题的难度,决定不同的解决模式,但是当我们采用模型进行搜索时,模型是很难预知问题难度的,
所以我们总是希望:模型能够尽可能地把这些解决方式(Action)都探索一遍。
那么接下来,我们就
配合着rStar的源码
,一起来看下这棵搜索树长什么样子(
这里我们不使用论文中的图,因为它缺少了太多细节,我们直接从源码出发重新绘制
):
我们先看一些基本信息:
方形node表示终止结点(leaf node),例如图中的cot结点(A2)。但注意,不是只有cot结点才是leaf node。例如A1中的最后一个step,A3中的最后一个子问题-子回答都可以成为leaf node。
虚线表示选择性探索(根据你的脚本配置决定),实现表示必须性探索。
接下来我们来看图中的更多细节:
我们从根结点(第0层)出发,根结点是用户的原始问题,对于根结点来说:
A1(a step),A5(rephrase)是选做的,A2(remaining steps,图中按源码的命名方式称为cot),A3(next subqs and subas)是必做的。其中,经过A5后,相当于从一个全新的用户问题出发,所以A5之后创建分支的规律和根结点一致,因此图中不再画出。现在我们观察图中的第1层,也就是根结点所有的子节点
先看第1层中的A1部分(蓝色块)
,它表示对原始问题并行采样若干个step1(代码中默认采样3个)。每个step1都是一个A1类型结点。并行采样的目的也是为了拓宽搜索空间。从直觉上理解,对于一个原始问题,当我们决定采取A1的思考方式解决它时,不同的人产出的step1是不同的,所以这里我们要做并行采样。
再看第1层中的A2部分(绿色块)
,它表示对原始问题做简单推理后,一步生成最终结果。这里同样采用了并行采样的方式(默认值为32)。
因为已经生成最终答案,所以A2结点都是leaf node,它不会再往下做任何扩展。
再看第1层中的A3部分(红色块)
,
它表示对原始问题产出next subquestion及其对应的subanswer
。这里同样采用并行采样的方式(默认值为3),例如对原始问题,我们并行采样三个Question1.1和Answer1.1。
再看第1层中的A5块(灰色块)
,它表示对原始问题进行重述,前面已提过,这里不再赘述。
那么根结点的子节点中为什么没有A4(re-answer subquestion)呢?
因为根据A4的性质,A4一定只能出现在A3之后,所以这里不会有。
接下来我们从第1层出发,以第1层为例,探索下不同类型的结点可以生成什么类型的子节点,以及最终可能的leaf node类型
。只要搞清楚了第1层,其余层就可以类推了。
它只能产生A1(选做)和A2(必做)两种类型的子节点
。也就是说,基于当前step,我们必须产出一个一步到位的推理结果,选择性产出一个step by step的结果。
A1结点最后的leaf node既可以是A2(一步到位产出了最终结果),也可以是A1类型(最后一个step的结果)
A1类型结点的值包含“该step推理文本 + 该step阶段性答案”,参见2.2中的例子
对于A2(remaining thoughts)类型结点
如前文所说,它已经一步到位产出了答案,所以是leaf node,它不会再有子节点
对于A3(subqs and subas)类型结点
它可以产生A1(选做),A2(必做),A3(必做),A4(必做)类型子节点
图中第1层,我们只画出了并行采样出的第1个A3结点的子节点情况,其余并行采样的结果也是类推
,因此图中没有画出,只用简单的省略号表示。(这个省略号其实也应该画在第2层,因为图的尺寸限制画偏了,特此说明)
A3结点往下延伸的leaf node,可以是A1(最后一个step结果),也可以是A2(一步到位产出最终结果),也可以是A3(最后一个subqs + subas结果,参考2.3示例)。
对于A4(re-answer subqs)类型结点(这一部分我们参考第2层)
它可以产生A1(选做),A2(必做),A3(必做)类型子节点
当我们执行A4时,你可以理解成只是重新修改了它的parent层的sub answer。
它的leaf node可以是A1(最后一个step结果),也可以是A2(一步到位产出最终结果),也可以是A3(最后一个subqs + subas结果,参考2.3示例)
总结一下,到目前为止我们已经解决了:
我们先根据人类思考问题的模式,设置搜索动作空间(Action,缩写为A)。
搜索空间中的不同动作之间可能有前-后(parent-child)的依赖条件,我们根据这些条件,决定了一个完整的搜索树要长什么样。
但是,仍有一些重要但未解的问题:
这棵树是我们站在上帝视角,(基本)穷尽所有的动作可能后构造出来的。那么对于模型,它应该怎么按我们的想法构造出这棵搜索树呢?
有了这棵搜索树后,我们要如何从根结点(user question)开始,选择一个最佳的推理路径并产出最后的答案呢?
为了解决这两个问题,现在我们可以请出MCTS这个算法了
。
三、使用MCTS搜索最佳推理路径
3.1 使用rollout构造搜索树
对于模型来说,现在它将从原始问题出发,构造一棵搜索树。我们先来看从根结点出发,模型构造搜索树的过程:
对于根结点来说:
执行select步骤
。选中根结点(我们马上就来看select更多细节,目前为止我们只用关注这一步select到了根结点)
执行expand步骤
。按照第二节中我们说的各结点间的依赖规则,为被选中的结点(这里是根结点)创建所有可能的子节点。为了绘图简便,这里我们略去了2.6节中所述的“并行采样”的过程,但实操中依然是并行采样的!
执行simulate步骤
。随机选择一个子节点,重复执行“expand-随机”步骤,
直到遇到leaf node或者达到设定的最大搜索深度为止。注意,只有两种类型的node可以成为leaf node(这和第二节中我们列的leaf node的理想情况有些许不同)
。
Terminal A3 node
:如果一个subquestion结点是最后一个子问题(“最后一个的含义”是,子问题中包含原始问题,或者子问题以“Now we can answer the question"开头,参见2.3示例。能做到这一点是因为我们通过相关prompt来指示模型生成结果 )
Terminal A2 node
:这个node本身就是一步推理产出最终结果,前面已经说过,这里不赘述
执行backprop步骤
。这一步我们将计算leaf node的reward,同时将本次搜索路径上所有node.reward += leaf_node.reward,node.freq += 1,其中freq表示node被访问的次数。
那么如何计算leaf node的reward呢?
Terminal A3 node reward
:对于A3类型的leaf node,我们对这最后一个子问题,并行采样若干个子回答。假设我们采样n个子回答,这些回答中指向答案x,y,z的条数分别是a,b,c(n = a+b+c),那么x答案的占比就是a/n,以此类推,我们选择占比最大的那个答案作为最终答案,并将这个占比作为reward。
Terminal A2 node reward
:对于A2类型的leaf node,我们则直接在它的所有并行采样结果中计算答案占比,计算方式同上。
这样一轮select + expand + simulate + backprop的步骤,就称为1次rollout
。不难发现,在1次rollout过后,我们构造出了一部分搜索树(这里我们先只谈构造,不谈搜索,大家不要着急)
接下来我们执行第2轮rollout,继续构造我们的搜索树
(这里不再画图了,我们直接从1st rollout的图例中想象一下):
执行第2轮rollout的select步骤
。第2轮rollout将从第1轮backprop后构造的那棵搜索树出发。同样从根结点开始向下选择,我们先走到第1层,发现有3个子节点都没被探索过,这时我们随机选择一个子节点,例如图中第1层的A5(rephrase),这个子节点将被用作expand。到这里,我们再深度总结一下select步骤要做的事情:
每次都从根结点出发,向下逐层探索(explore),直到找到一个未被探索过的结点为止。
如果从根结点出发,发现某一层(比如第1层)所有的结点都被探索过了。那么我们就计算每个结点的UCT值
(在3.2节中会细说,这个UCT值可以理解成用于计算一个结点的探索价值,它由结点的reward、freq和用于控制探索权重的超参C决定)。我们选择UCT值最大的结点,向下层继续搜寻,以此类推。
所以,总结来看,select步骤的目的就是尽可能找到一条未被探索、或者具有最高探索价值的路径
。以便后续沿着它往下扩展,生成更好的搜索树。
执行第2轮的expand、simulate、backprop步骤
。道理同上,不再赘述
。
这里额外再提一句,生成搜索树的每一层时,我们都需要用前面所有层的推理步骤作为上文,传递给模型做生成
,大家可以自行阅读源码找到构造上文的更多细节,这里不再额外介绍。
好,到这里为止我们已经理清单轮rollout的概念了,
以此循环往复,在执行若干轮rollouts(代码默认值为16)后,我们就有一棵相对完整的搜索树了,接下来我们就可以基于这棵树去找到一条最佳的推理路径了
。但是在介绍具体的搜索方法之前,让我们再来看看,如何计算一个结点的UCT值(UCT值越大,该结点被探索的价值越大)。
3.2 计算结点的UCT值
一个结点的UCT值计算方式如下:
Q:截止到本轮rollout为止,该结点的累积reward