扫码关注我们
公众号 : 计算机视觉战队
扫码回复:
华为开源
,获取下载链接
作者:杨军
来源:知乎
纯技术讨论,不涉及其他,部分我拿不准的地方,会直接以
(?)
标识出来,欢迎菊花厂同学来指正解惑。
华为的运营同学蛮专业的,在回答里介绍了一些比较重要的技术细节,哪怕不看code,对于做这个方向的同学,也大概也能捕捉到里面的一些core concepts,YongCHN同学也对MindSpore的auto parallel部分有一个小调研,也可以供感兴趣的同学了解。
最近关于AI框架开源的讨论比较多,包括Jittor和MegEngine以及MindSpore,所以我想还是尽量能够提供一些有额外信息量的输入。
1.从框架的设计原则上,个人认为MindSpore还是蛮中规中矩的,能够看到明显从TF,PyTorch里分别借鉴了两家的经验(不要忽略MxNet的Gluon,其实也是把静态图和动态图结合在了一起,可惜这两年日渐势微,其实也从一个侧面反映出引擎技术核心以外因素对AI框架推广的重要性,参见这里的一个回答,更不要忽略了Theano这样的已经deprecated框架里早就有的静态图的渊源以及当年还小引领过风骚的Chainer)。都是做这个方向的,就不客套了,官宣里提到的将python的模型描述JIT编译成计算图的思想,在Google的JAX/auto-graph以及PyTorch的TorchScript相关的工作里比较早就已经touch了,不过有决心去从头build并落实,这个还是值得点赞。
2. 如YongCHN同学关于auto-parallel的分析里所说,在TF里,为了在python层加入auto-parallel的支持,同时还兼顾TF 2.0的eager mode,在python代码里引入了大量的
if (eager_mode)
的判断,让整个代码的可维护性受到了不小的影响,而TF为了加入Distribution Strategy的支持,对其现有的python构图代码也做了大量的修改和插桩,其实让整个框架的python层实现复杂了不少。而MindSpore的整个code base,在auto-parallel的实现上能够感觉到还是清晰了不少,引入了ANF这个IR层(关于ANF的设计理念之前自己并不了解,感谢@叶子豪 同学的输入,稍微补了些课,相关的background材料可以参见这里和这里 ,也对照着重新看了一下ir/anf.h里的定义,认为设计动机确实如ANF的背景材料所说是为了简化source-level的变换复杂性,关于细节,这方面不是我的expertise,如果有了解的同行能够share更精准的理解就太好了),把分布式策略的工作有不少沉到了以ANF IR这一层,避免把python API层改得太惨,这个设计认为是更合理的分拆,也确实在没有历史包袱的框架里更容易make这样的design choice。站在前人的肩膀上总是能够也应该有更好的创新。
3.在我的理解中,MindSpore的架构层次大体上可以拆分为
这几个大层次。
从git repo的组织来看,ME、算子库(部分算子库目前以闭源.so的方式提供)基本都放在MindSpore的主体repo里,Graph Compiler和Runtime放到了GE(Graph Engine)的repo里。按官网的doc似乎主要的fusion工作是在GE里完成的。但是呢,一些涉及到算子fusion的逻辑又散在了MindSpore主repo里,同时也在GE的repo里看到了一些fusion的pass。这让我稍微有些curious为什么是这样组织。从架构设计上,会感觉fusion相关的逻辑应该统一放到GE的repo里才自然。
目前我能够推测出的一个原因是GE主要target Ascend硬件,对于GPU相关的fusion不属于GE的范畴,所以在 MindSpore的repo里针对GPU的fusion做了一些比较直接的工作
(?)
。
4. 整个执行flow从代码结构来看是这样的,
-
用户完成python层的模型构图
-
调用python层的
train
API(
mindspore/train/model.py
)以后,会通过pybind11的接口触发
ExecutorPy::Compile()
,在这个函数的实现里,会将用户python模型描述的AST解析成ANF的格式(了解JAX和auto-graph的同学就能看到相似之处了)
-
完成ANF graph的构建之后,剩下的事情就比较自然了,在这个graph上bla bla bla做一系列的transformation,直到生成一个编译好的byte string,并将这个编译好的结果序列化保存下来
-
再通过一个pybind11的接口触发
ExecutorPy::Run()
,对上面的编译结果进行实际运行。
这个流程其实蛮自然的,至少在JAX和TF auto-graph项目里都有类似的作法,在PyTorch里其实也有类似的尝试,核心的难点我认为是在于Python的语法太灵活了,于是在不太起眼的
AST2XXX
(XXX可能是TF GraphDef,可能是JAX里的HLO graph也可能是MindSpore里的ANF graph)这个步骤可能会出现大量的corner case。PyTorch通过python解释器和PyTorch核心交互调用的方式(算是一种trace的方式)来牺牲一定性能但根本性的避免了这个陷阱,而JAX也好,auto-graph也好,以及MindSpore也好,在想获取动静结合的组合优势的同时,也需要pay for对应的engineering cost。