深度学习IR/编译器的设计与思考
Table of Contents
传统程序语言 IR
SSA(Single Static Assignment) form
SSA 是过程式语言常用的 IR 格式,它要求每个变量有唯一确定的定义式。
考虑如下的例子
c = a + b if a < 5 then c = (c * 2) + 1 d = (c * 3) - 1 else c = (c * 2) + 2 d = (c * 3) - 1 e = c + 2
如果我们想做一些优化,例如公共子表达式消除,那么 c = (c * 2) + 1
和 c = (c * 2) + 2
中的 (c * 2)
可以合并,但是不能对两个 (c * 3) - 1
做类似操作,因为此处的 c
已经不是指向同一个变量了。
这也就是 SSA 限制条件的动机:我们需要知道每一个变量具体定义式的位置,以便对他们进行区分,同名变量的不同定义可以使用不同下标来表示,对于上面的例子,可以改写成如下形式:
c_0 = a_0 + b_0 if a_0 < 5 then c_1 = (c_0 * 2) + 1 d_1 = (c_1 * 3) - 1 else c_2 = (c_0 * 2) + 2 d_2 = (c_2 * 3) - 1 c_3 = phi(c_1, c_2) e_0 = c_3 + 2
注意到 e = c + 2
中的 c
可能由两个不同的控制流跳转而来,我们并不能确定此处的 c
的定义式具体是哪一个。为了解决这个问题,
SSA 引入了 \(\phi\) 函数,其参数为所有可能的定义式,而取值则在运行时决定。通过在每个控制流交汇的地方插入 \(\phi\) 函数,引入新的变量(上面例子中的 c_3
),保证了每个变量有唯一确定定义这一特性。
\(\phi\) 函数对硬件而言并不友好,它需要我们在运行时维护“当前 BB(基本块)从哪个 BB 跳转而来这一信息)”从而动态取值,一方面有额外开销,另一方面对硬件设计有额外要求。在软件层面则可以通过实现一个虚拟机以执行带 \(\phi\) 函数的 IR。
因此通常我们不把 SSA 作为程序最后生成代码的格式要求,而在 SSA 格式上进行一系列编译优化之后,会将代码转出 SSA 消去 \(\phi\) 函数,最后生成汇编语言交给硬件执行。
CPS(Continuaton Passing Style) form
函数式语言没有显式地制定求值顺序,为我们编译到机器码产生了困难,CPS 指定所有复杂表达式(类似于 (+ (* x y) (/ y x))
)的求值顺序,将其展开为单步计算(某个 variable,或者简单的 function apply,其中所有的参数必须是已经算出的 variable)+ continuation
(一个代表我们如何接下来处理刚才产生的结果的函数,包含剩下的所有计算流)的形式:
;; an expressions: (+ 1 (+ 2 (+ 3 4)) ;; where the continuation of (+ 3 4) is ;; < (+ 2 ...), cont1 > ;; where the hole ... means previous result (7=3+4 in this case), and cont1 is ;; < (+ 1 ...), halt > ;; The CPS form of the expr (1 + (+ 2 (+ 3 4)): (define id (lambda (x) x)) (+& 3 4 (lambda (v0) (+& 2 v0 (lambda (v1) (+& 1 v1 id)))))
要求每个函数带上一个额外参数 k
(continuation),代表如何使用当前函数计算的结果。
ANF form
ANF 是函数式语言常用的 IR 格式,其要求函数的所有参数必须是 trivial 的(不需要经过计算),因此复合函数通常需要展开成最基本的表达式,且每个表达式都需要命名。
e ::= let x = v in e | let x = v(v_1, ... v_n) in e | v | v(v_1, ..., v_n) | if v then e_1 else e_2 | letrec f_1, ... f_n in e f ::= x(x_1, ... x_n) = e
其中 v
代表已经命名的变量,而 e
代表表达式。
CPS, SSA, ANF 之间的关系
SSA Book 的第六章介绍了 SSA 的函数式表示(即 CPS)。
TODO 具体案例 (LLVM)
深度学习语言 IR
MLIR
以往的工作(JVM/LLVM)往往只提供了单层的抽象("one size fits all"):
- LLVM 几乎是 "C with vectors"
- JVM 是 "oo type system with gc"
这种 "one size fits all" 的方法为 从源程序转到 IR 提供了便利。
不同领域通常会设计自己的 IR 以完成解决某些特定问题的抽象,通常鱼龙混杂。
MLIR 试图提供一个统一的解决方案。
TODO 设计原则
后端优化器规则 – TableGen
TVM (TIR / Relay)
Relay
- Relay 关于 IR 选型的讨论:https://discuss.tvm.apache.org/t/choice-about-ir-ssa-or-anf/1757