Triton IR
Triton IR Syntax
Statements in Triton IR follow the syntax conventions of MLIR dialects. For example:
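%3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)
where:
%3: the name of the result value (the Value's name) produced by the expression on the right-hand side
tt: the dialect namespace, here tt (Triton)
splat: the name of the operation
%1: the operand (input) of the operation
i32: the type of %1
tensor<1024xi32>: the result type of the operation (that is, the type of %3)
loc(#loc5): debug information that maps the operation back to a location (line and column) in the source code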
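To make the breakdown above concrete, here is a minimal Python sketch that splits a statement of this shape into its parts. The `OP_PATTERN` regex and the `parse_op` helper are hypothetical names introduced for illustration; the pattern only handles the simple single-operand form shown here and is not a real MLIR parser.

```python
import re

# Hypothetical helper: covers only the simple shape
# "%res = dialect.op %operand : in_type -> out_type loc(...)".
OP_PATTERN = re.compile(
    r"(?P<result>%\w+)\s*=\s*"          # result value name
    r"(?P<dialect>\w+)\.(?P<op>\w+)\s+"  # dialect namespace and operation name
    r"(?P<operand>%\w+)\s*:\s*"          # single operand
    r"(?P<operand_type>\S+)\s*->\s*"     # operand type
    r"(?P<result_type>\S+)"              # result type
    r"(?:\s+loc\((?P<loc>[^)]*)\))?"     # optional location
)


def parse_op(line: str) -> dict:
    """Split one statement into named parts, or raise if the shape differs."""
    match = OP_PATTERN.fullmatch(line.strip())
    if match is None:
        raise ValueError(f"unsupported statement: {line!r}")
    return match.groupdict()


parts = parse_op("%3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)")
for name, value in parts.items():
    print(f"{name:>12}: {value}")
```

Real IR has multi-operand operations, attributes, and regions, so anything beyond a quick inspection should go through the MLIR parser rather than a regex.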
The following is the Triton DSL for a PyTorch cat operator (generated by Inductor):
import triton
import triton.language as tl

@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 3645440
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex % 890        # column index in the output
    x1 = (xindex // 890)     # row index in the output
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 390, tl.int64)
    tmp4 = tmp0 < tmp3       # columns [0, 390) come from in_ptr0
    tmp5 = tl.load(in_ptr0 + ((390*x1) + x0), tmp4, eviction_policy='evict_last', other=0.0)
    tmp6 = tmp0 >= tmp3      # columns [390, 890) come from in_ptr1
    tmp7 = tl.full([1], 890, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = tl.load(in_ptr1 + ((500*x1) + ((-390) + x0)), tmp6, eviction_policy='evict_last', other=0.0)
    tmp10 = tl.where(tmp4, tmp5, tmp9)
    tl.store(out_ptr0 + (x2), tmp10, None)
The compiled Triton IR is as follows (tensor shapes follow from XBLOCK = 1024 and the constants from the operations that use them; the f32 element type is an assumption):
#loc = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":18:0)
module {
  tt.func public @triton_(
      %arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":18:0),
      %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":18:0),
      %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":18:0),
      %arg3: i32 {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":18:0)) attributes {noinline = false} {
    %cst = arith.constant dense<-390> : tensor<1024xi32> loc(#loc1)
    %cst_0 = arith.constant dense<500> : tensor<1024xi32> loc(#loc1)
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024xf32> loc(#loc1)
    %cst_2 = arith.constant dense<390> : tensor<1024xi32> loc(#loc1)
    %cst_3 = arith.constant dense<390> : tensor<1024xi64> loc(#loc1)
    %cst_4 = arith.constant dense<890> : tensor<1024xi32> loc(#loc1)
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)
    %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5)
    %5 = arith.remsi %4, %cst_4 : tensor<1024xi32> loc(#loc6)
    %6 = arith.divsi %4, %cst_4 : tensor<1024xi32> loc(#loc7)
    %7 = arith.extsi %5 : tensor<1024xi32> to tensor<1024xi64> loc(#loc8)
    %8 = arith.cmpi slt, %7, %cst_3 : tensor<1024xi64> loc(#loc8)
    %9 = arith.muli %6, %cst_2 : tensor<1024xi32> loc(#loc9)
    %10 = arith.addi %9, %5 : tensor<1024xi32> loc(#loc10)
    %11 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc11)
    %12 = tt.addptr %11, %10 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc11)
    %13 = tt.load %12, %8, %cst_1 evictionPolicy = evict_last : tensor<1024x!tt.ptr<f32>> loc(#loc12)
    %14 = arith.cmpi sge, %7, %cst_3 : tensor<1024xi64> loc(#loc13)
    %15 = arith.muli %6, %cst_0 : tensor<1024xi32> loc(#loc14)
    %16 = arith.addi %5, %cst : tensor<1024xi32> loc(#loc15)
    %17 = arith.addi %15, %16 : tensor<1024xi32> loc(#loc16)
    %18 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc17)
    %19 = tt.addptr %18, %17 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc17)
    %20 = tt.load %19, %14, %cst_1 evictionPolicy = evict_last : tensor<1024x!tt.ptr<f32>> loc(#loc18)
    %21 = arith.select %8, %13, %20 : tensor<1024xi1>, tensor<1024xf32> loc(#loc19)
    %22 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc20)
    %23 = tt.addptr %22, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc20)
    tt.store %23, %21 : tensor<1024x!tt.ptr<f32>> loc(#loc21)
    tt.return loc(#loc22)
  } loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":20:28)
#loc3 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":20:33)
#loc4 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":21:36)
#loc5 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":21:23)
#loc6 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":23:18)
#loc7 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":24:20)
#loc8 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":30:18)
#loc9 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":31:35)
#loc10 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":31:41)
#loc11 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":31:30)
#loc12 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":31:46)
#loc13 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":32:19)
#loc14 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":35:35)
#loc15 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":35:51)
#loc16 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":35:42)
#loc17 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":35:30)
#loc18 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":35:57)
#loc19 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":36:33)
#loc20 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":37:25)
#loc21 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":37:37)
#loc22 = loc("/tmp/torchinductor_vincent/yx/cyxlesomwjvogzqnvxmuj2p2z2gml7hud473ghmveu4fg6jbtlmz.py":37:4)
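The index arithmetic of the kernel (the `% 890` and `// 890` that become arith.remsi and arith.divsi, plus the two masked loads joined by a select) can be checked on small inputs with a plain-Python model. This is only a sketch: `cat_reference` is a hypothetical name, it works on flat row-major lists, and it replaces the masked loads with an ordinary branch.

```python
def cat_reference(in0, in1, cols0, cols1):
    """Concatenate two row-major matrices along the column axis, one flat index at a time."""
    out_cols = cols0 + cols1          # 390 + 500 = 890 in the kernel above
    rows = len(in0) // cols0
    out = [0.0] * (rows * out_cols)
    for xindex in range(rows * out_cols):
        x0 = xindex % out_cols        # output column (arith.remsi)
        x1 = xindex // out_cols       # output row (arith.divsi)
        if x0 < cols0:                # mask of the first load
            out[xindex] = in0[cols0 * x1 + x0]
        else:                         # mask of the second load
            out[xindex] = in1[cols1 * x1 + (x0 - cols0)]
    return out


a = [1.0, 2.0, 3.0, 4.0]                      # 2 x 2
b = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0]      # 2 x 3
result = cat_reference(a, b, cols0=2, cols1=3)
print(result)
```

With `cols0=390` and `cols1=500` the two index expressions are the same as `(390*x1) + x0` and `(500*x1) + ((-390) + x0)` in the kernel.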
Dialects Used by Triton IR
When you export the IR of a Triton program, it contains not only Triton (tt) operations but also operations from other MLIR dialects, including:
- Arith: addf, addi, andi, cmpf, cmpi, divf, fptosi, …
- Math: exp, sin, cos, log, …
- StructuredControlFlow (scf): for, if, while, yield, condition
- ControlFlow (cf): br, cond_br
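A quick way to see which dialects appear in an exported IR dump is to count operation names by their dialect prefix. The sketch below is an assumption-laden heuristic rather than an MLIR parser: `OP_NAME` and `count_dialects` are hypothetical names, and the regex only looks for a `dialect.op` token at the start of a line or right after an `=`.

```python
import re
from collections import Counter

# Heuristic: an operation name appears either at the start of a line
# or immediately after "=" (when the operation has results).
OP_NAME = re.compile(r"(?:=\s*|^\s*)([a-z_][a-z0-9_]*)\.([a-z_][a-z0-9_]*)", re.MULTILINE)


def count_dialects(ir_text: str) -> Counter:
    """Count operations per dialect prefix in a textual IR dump."""
    return Counter(dialect for dialect, _ in OP_NAME.findall(ir_text))


SAMPLE_IR = """
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
%4 = arith.addi %3, %2 : tensor<1024xi32>
tt.return
"""

counts = count_dialects(SAMPLE_IR)
print(dict(counts))
```

Attributes such as `{tt.divisibility = 16 : i32}` are not counted, because they follow a `{` rather than an `=` or a line start.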
Triton IR Operations
tt.call
(triton::CallOp)
Syntax:
operation ::= `tt.call` $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
tt.call represents a direct call to a function that lives in the same symbol scope as the call.
Example:
%2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32
tt.func
(triton::FuncOp)
A function declaration or definition. A function holds a single SSACFG region.
Operations inside a function cannot implicitly capture values defined outside the function. All external references must be passed in through function arguments or attributes. In MLIR, the arguments of a function are represented as the block arguments of its first block.
Example:
// External function definitions.
tt.func @abort()
tt.func @scribble(i32, i64, memref<? x 128 x f32, #layout_map0>) -> f64

// A function that returns its argument twice:
tt.func @count(%x: i64) -> (i64, i64)
  attributes {fruit = "banana"} {
  tt.return %x, %x : i64, i64
}

// A function with an argument attribute
tt.func @example_fn_arg(%x: i32 {swift.self = unit})

// A function with a result attribute
tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64})

// A function with an attribute
tt.func @example_fn_attr() attributes {dialectName.attrName = false}
SSACFG region
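An SSACFG region is one of the two region kinds in MLIR (the other is the graph region). Its blocks form a control-flow graph, the operations in each block execute in order, and every use of a value must be dominated by its definition (SSA form). The body of the following function is an SSACFG region made up of several blocks:
func.func @example(%a : i32) -> i32 {  // the function body is an SSACFG region
  %c1 = arith.constant 1 : i32
  %c10 = arith.constant 10 : i32
  %cmp = arith.cmpi slt, %a, %c10 : i32
  cf.cond_br %cmp, ^bb1, ^bb2
^bb1:
  %x = arith.addi %a, %c1 : i32
  cf.br ^exit(%x : i32)
^bb2:
  %y = arith.subi %a, %c1 : i32
  cf.br ^exit(%y : i32)
^exit(%result : i32):
  return %result : i32
}
Operations inside an SSACFG region may themselves hold nested regions. For example, the body of scf.if below is a separate region nested inside the enclosing one; it is not a block in the enclosing region's control-flow graph:
scf.if %cond {
  // this is a new, nested region
  scf.yield
}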
Block Arguments
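Consider the following function:
func.func @foo(%arg0: i32, %arg1: f32) -> f32 {
  // the function body uses %arg0 and %arg1 directly
  %result = arith.addf %arg1, %arg1 : f32
  return %result : f32
}
Internally, MLIR stores the arguments of a function as the block arguments of the function's first basic block (the entry block). The generic form of the same function makes this visible:
"func.func"() ({
^bb0(%arg0: i32, %arg1: f32):  // the arguments actually belong to the entry block
  %result = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32
  "func.return"(%result) : (f32) -> ()
}) {function_type = (i32, f32) -> f32, sym_name = "foo"} : () -> ()
The reason is that MLIR requires every SSA value to be defined either as a block argument or as the result of an operation. The same mechanism also replaces the phi nodes used in LLVM IR.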
In LLVM IR, a phi node merges the values that arrive from different predecessor blocks:
entry:
  br i1 %cond, label %left, label %right
left:
  br label %merge
right:
  br label %merge
merge:
  %x = phi i32 [ %v1, %left ], [ %v2, %right ]  ; phi node
  ret i32 %x
In MLIR, block arguments achieve the same effect:
func.func @foo(%cond: i1, %v1: i32, %v2: i32) -> i32 {
  cf.cond_br %cond, ^left, ^right
^left:
  cf.br ^merge(%v1 : i32)  // pass %v1 as the argument of ^merge
^right:
  cf.br ^merge(%v2 : i32)  // pass %v2 as the argument of ^merge
^merge(%x : i32):          // the block argument replaces the phi node
  return %x : i32
}
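The way block arguments stand in for phi nodes can be modeled with a tiny interpreter. The sketch below is a toy, not MLIR: the `CFG` encoding and the `run` function are hypothetical, and each block consists only of its parameter list and a terminator. Binding the branch operands to the target block's parameters is the step that a phi node would otherwise perform.

```python
# Each block maps to (parameter names, terminator). A terminator is one of
#   ("br", target, operand_names)
#   ("cond_br", condition_name, true_target, false_target)
#   ("return", value_name)
CFG = {
    "entry": ((), ("cond_br", "cond", "left", "right")),
    "left": ((), ("br", "merge", ("v1",))),
    "right": ((), ("br", "merge", ("v2",))),
    "merge": (("x",), ("return", "x")),
}


def run(cfg, env, block="entry", args=()):
    """Execute the toy CFG; env holds the values that are already defined."""
    while True:
        params, term = cfg[block]
        # Binding the incoming operands to the block parameters plays the role of a phi node.
        env = {**env, **dict(zip(params, args))}
        if term[0] == "return":
            return env[term[1]]
        if term[0] == "br":
            _, block, names = term
            args = tuple(env[name] for name in names)
        else:  # cond_br without operands
            _, cond, true_block, false_block = term
            block, args = (true_block if env[cond] else false_block), ()


took_left = run(CFG, {"cond": True, "v1": 1, "v2": 2})
took_right = run(CFG, {"cond": False, "v1": 1, "v2": 2})
print(took_left, took_right)
```

Because the value is passed on the edge itself, the merge block never needs to know which predecessor it came from.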
tt.return
(triton::ReturnOp)
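Syntax:
operation ::= `tt.return` attr-dict ($srcs^ `:` type($srcs))?
Represents the return operation of a function. It takes a variable number of operands, and the number and types of the operands must match the result types in the function signature.
Example:
tt.func @foo() -> (i32, f8) {
  ...
  tt.return %0, %1 : i32, f8
}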
tt.addptr
(triton::AddPtrOp)
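Syntax:
operation ::= `tt.addptr` $ptr `,` $offset attr-dict `:` type($result) `,` type($offset)
Computes a linear address offset for a scalar pointer or a tensor of pointers.
Example (f32 is used as the element type for illustration):
%base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%idx = tt.make_range {start = 0 : i32, end = 1024 : i32} : tensor<1024xi32>
// compute the offset addresses
%ptrs = tt.addptr %base, %idx : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// load the data
%vals = tt.load %ptrs : tensor<1024x!tt.ptr<f32>>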
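The address arithmetic behind the splat, make_range, and addptr sequence can be sketched in plain Python. The model treats pointers as integers and assumes that offsets count elements (as in C pointer arithmetic), so each offset is scaled by the element size. The `addptr` function, the base address `0x1000`, and the 4-byte element size are all assumptions made for illustration.

```python
def addptr(ptrs, offsets, elem_bytes):
    """Element-wise pointer offset; offsets count elements, not bytes."""
    if len(ptrs) != len(offsets):
        raise ValueError("pointer and offset tensors must have the same shape")
    return [p + o * elem_bytes for p, o in zip(ptrs, offsets)]


base_address = 0x1000                    # hypothetical address held in %arg0
base = [base_address] * 8                # models tt.splat (8 lanes instead of 1024)
idx = list(range(8))                     # models tt.make_range {start = 0, end = 8}
ptrs = addptr(base, idx, elem_bytes=4)   # 4-byte elements assumed (f32)
print([hex(p) for p in ptrs])
```

Each lane ends up pointing at a consecutive element, which is why the following load reads a contiguous block.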
tt.advance
(triton::AdvanceOp)
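Syntax:
operation ::= `tt.advance` $ptr `,` `[` $offsets `]` attr-dict `:` type($result)
Offsets a pointer of type !tt.ptr<tensor<...>> by the given multi-dimensional offsets and returns a new tensor pointer.
Example (the tile type tensor<32x32xf32> is assumed for illustration):
scf.for %i = %c0 to %c128 step %c32 iter_args(%tile_ptr = %base_ptr) -> (!tt.ptr<tensor<32x32xf32>>) {
  // use the current tile
  %vals = tt.load %tile_ptr : !tt.ptr<tensor<32x32xf32>>
  // advance to the next tile (move by 32 along the first dimension, keep the second dimension unchanged)
  %next_ptr = tt.advance %tile_ptr, [%c32_i32, %c0_i32] : !tt.ptr<tensor<32x32xf32>>
  scf.yield %next_ptr : !tt.ptr<tensor<32x32xf32>>
}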
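The loop-carried pointer in the example can be modeled as a small immutable record that tracks only the current tile offsets. This is a sketch under stated assumptions: `TilePtr` and `advance` are hypothetical names, and a real tensor pointer also carries the base address, shape, strides, and tile size.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TilePtr:
    """Hypothetical stand-in for a tensor pointer: only the current offsets are tracked."""
    offsets: Tuple[int, ...]


def advance(ptr: TilePtr, deltas: Tuple[int, ...]) -> TilePtr:
    """Return a new pointer moved by one delta per dimension; the input is left unchanged."""
    if len(deltas) != len(ptr.offsets):
        raise ValueError("one offset per dimension is required")
    return TilePtr(tuple(o + d for o, d in zip(ptr.offsets, deltas)))


tile_ptr = TilePtr((0, 0))
visited = []
for _ in range(0, 128, 32):               # mirrors scf.for %i = %c0 to %c128 step %c32
    visited.append(tile_ptr.offsets)      # where the load would read
    tile_ptr = advance(tile_ptr, (32, 0))  # mirrors tt.advance with [32, 0]
print(visited)
```

Returning a new value instead of mutating in place matches the SSA style of the IR, where the advanced pointer is yielded as the next iteration argument.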
tt.assert
(triton::AssertOp)
Syntax:
operation ::= `tt.assert` $condition `,` $message attr-dict `:` type($condition)
tt.assert runs on the device. It takes a condition (a scalar or tensor of type i1) and a string message. If the condition is false, it prints the message and aborts the program.
Example:
%in_bounds = arith.cmpi slt, %idx, %size : i32
tt.assert %in_bounds, "index out of bounds" : i1
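The behavior described above can be approximated on the host with a few lines of Python. This is only a model: `device_assert` is a hypothetical name, a tensor condition is represented as a list of booleans, and raising an exception stands in for printing the message and aborting the kernel.

```python
def device_assert(condition, message):
    """Model of tt.assert: condition is a bool or a list of bools (one per element)."""
    conditions = condition if isinstance(condition, list) else [condition]
    if not all(conditions):
        # The real operation prints the message and aborts; raising is a stand-in.
        raise RuntimeError(message)


idx, size = [0, 1, 2, 3], 4
device_assert([i < size for i in idx], "index out of bounds")  # every element is in bounds

try:
    device_assert([i < 3 for i in idx], "index out of bounds")  # the last element fails
    failed = False
except RuntimeError as err:
    failed = True
    print(err)
```

In this model a tensor condition fails as soon as any single element is false.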
TODO
References:
TritonOps — Triton documentation