Triton & Triton Ascend
导言
- Ascend上训练编译成全图有功能问题,导致下发问题并不能像GPU一样完全解决;
- 在浦江实验室的经验是,triton确实能快速拿到2~3倍的收益,如果算子还有问题就能考虑
核心概念¶
Grid¶
¶
基础语法¶
编译控制¶
@triton.heuristics(...)`:智能推导(元编程)
一句话解释:根据传入内核的参数,在编译前自动计算出一些静态的布尔标志位(Flags)或常量。
@triton.heuristics({
'USE_G': lambda args: args['g'] is not None,
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
- 当你在 Python 里调用这个 Triton Kernel 时,会传入一堆参数(比如张量
g,张量cu_seqlens等)。 triton.heuristics会在真正把代码丢给 GPU 编译之前,先运行这几个lambda匿名函数。args['g'] is not None:如果你调用时传了g张量,USE_G就会被硬编码为True;如果没传,就是False。args['cu_seqlens'] is not None:同理,如果传了累积序列长度(通常用于变长序列,Variable Length Attention),IS_VARLEN就是True。
为什么要这样做?(核心机密:死代码消除)
在 GPU 编程中,在最内层循环写 if 分支是非常影响性能的。
通过 heuristics,USE_G 变成了一个编译期常量(constexpr)。
* 如果 USE_G = False,Triton 编译器在编译时,会直接把代码里所有 if USE_G: 下面的代码全部删掉(Dead Code Elimination, 死代码消除)!
* 这样,你只需要写一份代码,就能自动生成“带门控”和“不带门控”两套极其纯净、没有任何 if 开销的 GPU 机器码。
@triton.jit(do_not_specialize=['T']):编译与特化控制
一句话解释:告诉 Triton 编译器:“这是一个 GPU 函数,请帮我编译它。但是,请千万不要对参数 T 进行死板的过度优化!”
JIT 就是 Just-In-Time(即时编译)。它把 Python 语法翻译成底层的 GPU 汇编指令(PTX 或 SASS)。这是写 Triton 必带的装饰器。
do_not_specialize?(特化控制)是处理动态 Shape 时的救命稻草!
-
Triton 的默认行为(疯狂特化):为了让 GPU 跑得最快,Triton 极其激进。如果你传进来的参数
T(比如序列长度 Sequence Length)这次是1024,Triton 就会针对1024这个数字编译出一个专属的二进制程序。- 这很好,因为编译器知道确切数字,能做好多极致的底层优化(比如循环展开)。
- 但是灾难来了:如果下一轮训练,你传进来的
T变成了1025,或者1028。Triton 一看:“哎呀,数字变了,之前的 1024 专属程序不能用了!” 于是它会从头开始重新编译一次代码。 - 编译是非常非常慢的(可能需要零点几秒甚至几秒)。如果你每次输入的序列长度都不一样,模型就会因为疯狂重新编译(Re-compilation Storm)而卡得像 PPT 一样。
-
加上
do_not_specialize=['T']后:- 你明确告诉了 Triton 编译器:“老哥,
T这个变量是一个普通的动态数字,它随时会变。请把它当成一个普通的寄存器变量,不要把它硬编码(Hardcode)进二进制程序里!” - 这样,不管接下来你的
T是 1024 还是 2048,只要数据类型没变,Triton 就会一直复用第一次编译好的那套 GPU 代码,彻底解决了运行过程中的卡顿问题。
- 你明确告诉了 Triton 编译器:“老哥,
Memory & Control¶
- tl.cdiv(K, BK) 向上取整除法 (Ceiling Division)。相当于数学上的 \(\lceil K / BK \rceil\),或者 Python 里的 (K + BK - 1) // BK。
- tl.make_block_ptr(...) 数据的基地址(base),整个大张量的形状(shape),大张量在内存里是怎么跨步的(strides),你要切的块有多大(block_shape),以及当前切块的起始坐标(offsets)。
- tl.load(ptr, boundary_check=...) 从全局显存(VRAM)加载数据到 GPU 的极速寄存器(SRAM)中。把上面定义好的指针 p_k 对应的数据块加载进来,变成一个二维张量 b_k。boundary_check=(0, 1) 是它的神仙功能:如果你的数据大小不是块大小的整数倍(比如最后一块超出了边界),它会自动帮你把越界的部分填充为 0,再也不用手动写 mask 了!
tl.make_block_ptr
在 tl.make_block_ptr 出现之前,写 Triton 代码简直就是一场“计算指针偏移量”的噩梦,稍有不慎就会遇到内存越界(Segfault)直接导致程序崩溃。而 tl.make_block_ptr 就像是一个智能的“取景框”,你只需要告诉它大背景是什么样,取景框多大,放在哪,它就会自动帮你把里面的数据完好无损地搬运到显存(SRAM)中。
在 NVIDIA 的 Hopper 架构(比如 H100)上,这个 API 甚至会直接调用硬件级别的 TMA (Tensor Memory Accelerator) 单元,速度极快。
我们把你之前代码里的这行经典调用拿出来,把它“解剖”开:
# K shape [B, T, H, K] 1, 65536, 32, 128
p_k = tl.make_block_ptr(
base=k + k_batch_off + (bos * H + i_h) * K,
shape=(T_local, K),
strides=(H * K, 1),
offsets=(i_t * BT, i_k * BK),
block_shape=(BT, BK),
order=(1, 0)
)
这个函数总共有 6 个核心参数。为了方便理解,你可以想象我们有一张巨大的清明上河图(整个大张量),而我们要拿着一个小放大镜(Block)去截图。
-
base(基地址) -
字面意思:大张量在 GPU 全局内存(HBM)中的物理起始地址。
- 工程师大白话:“画卷的绝对起点在哪?”
-
在代码中:
k + k_batch_off + (bos * H + i_h) * K- 这里通常不仅是整个大矩阵的首地址(
k),往往还包含了 batch 偏移量(k_batch_off)和 多头注意力(Multi-Head)的头偏移量((bos * H + i_h) * K)。 - 你可以理解为:在这个大矩阵中,我们已经把大本营扎在了当前 Batch、当前 Head 的第 0 行第 0 列的位置。
- 这里通常不仅是整个大矩阵的首地址(
-
shape(全局形状) -
字面意思:当前你要操作的这个“完整数据张量”的全局大小(通常是二维或多维的 Tuple)。
- 工程师大白话:“这幅画的总长度和总宽度是多少?”
-
在代码中:
(T_local, K)- 告诉 Triton,我们现在面对的是一个总共有
T_local行(比如序列长度 1024),总共有K列(比如特征维度 128)的巨大矩阵。 - 为什么需要这个? 主要是为了安全。Triton 知道大矩阵的边界后,当你截取到边缘时,它配合
boundary_check就能自动进行安全填充(Pad 0),防止越界。
- 告诉 Triton,我们现在面对的是一个总共有
-
strides(跨步/步长) -
字面意思:在内存这一维的线性空间里,你想在这个矩阵中移动 1 行 或 1 列,内存指针需要跳过多少个元素?
- 工程师大白话:“物理内存是一条直线,怎么折叠成这幅二维画的?”
-
在代码中:
(H * K, 1)- 这是一个元组
(stride_row, stride_col)。 stride_col = 1:在同一行里,向右走 1 格,内存地址加 1。这说明数据在内存里是连续挨着存的(这叫 Row-Major,行优先)。stride_row = H * K:如果我想走到下一行,内存指针要跳过H * K个元素(跳过所有头在这个 Token 上的数据)。
- 这是一个元组
-
offsets(切块偏移量) -
字面意思:我们要取的这一个小 Block,它的左上角在整个大矩阵里的相对坐标是多少。
- 工程师大白话:“放大镜的左上角对准画里的哪个坐标?”
-
在代码中:
(i_t * BT, i_k * BK)- 由于外面有
for循环,坐标是动态的。 - 行坐标:当前处理的 token 块索引
i_t乘以块大小BT(比如第 2 块的起始行就是2 * 64 = 128)。 - 列坐标:当前处理的维度块索引
i_k乘以维度大小BK(比如第 1 块的起始列就是1 * 32 = 32)。
- 由于外面有
-
block_shape(块形状) -
字面意思:我们要从显存里一口气取进寄存器(或共享内存 SRAM)的小数据块的大小。
- 工程师大白话:“放大镜的框到底有多大?”
-
在代码中:
(BT, BK)- 取出一个
BT行、BK列的小矩阵(比如64 x 32)。 - 注意:
block_shape里的值必须是 2 的幂次方(16, 32, 64, 128...)并且在编译时就要确定,不能是动态变量。
- 取出一个
-
order(连续性顺序 / 内存读取模式) -
字面意思:定义这个 Block 里的数据在物理内存上,哪个维度的元素是紧紧挨在一起的(连续的)。
- 工程师大白话:“GPU 加载数据时,最内层的高速循环应该沿着哪个方向跑?”
- 在代码中:
(1, 0)- 这是一个非常影响性能(Performance)的参数,决定了内存的“合并访问(Memory Coalescing)”效率。
- 数字代表维度的索引。对于 2D 张量,
0是行维度,1是列维度。 order=(1, 0)的意思是:第1维(列维度)变化最快,第0维(行维度)变化最慢。换句话说,数据是按行读取的(即每行的数据在物理上挨在一起)。- 反过来,如果数据在内存里是列优先存储的(比如经过转置的矩阵),你需要传入
(0, 1),以保证 GPU 一次性能“吞”下连续的内存块,而不是像跳蚤一样在内存里乱跳。
一张图总结
把它们拼接在一起,make_block_ptr 就在 GPU 后台做了一件这样的事:
[ 全局显存 HBM ]
Base: 地址起点
+----------------------------------------+
| 大矩阵 Shape: (T_local, K) |
| 步长 Strides: (H*K, 1) |
| |
| [Offsets: (i_t*BT, i_k*BK)] |
| +---------------+ |
| | | |
| | 小块 (BT, BK) | <-- Order: |
| | (Block_shape) | (1,0)横向|
| +---------------+ |
| |
+----------------------------------------+
tl.load(p_k, boundary_check=(0, 1)):
* 如果:你的框(放大镜)跑到了画卷的外面(比如最后一块的尺寸不够 BT 和 BK 了)。
* 结果:Triton 会因为看到了 boundary_check=(0, 1),自动查阅 shape 和 offsets,把超出的部分智能用 0.0 填补,这就是 make_block_ptr 被誉为神级 API 的原因。再也不用手写繁琐的越界 Mask 了!
Math & Matrix¶
- tl.trans(b_k) 矩阵转置 (Transpose)。
- tl.dot(b_k, tl.trans(b_k)) 矩阵乘法
Indexing & Masking¶
- tl.arange(start, end) 的语法作用与 NumPy (np.arange) 或 PyTorch (torch.arange) 非常相似:它用于生成一个包含连续整数的 一维 Tensor(张量)。
- [:, None] 和 [None, :]
- row_indices = tl.arange(0, BT)[:, None] 会把 [0,1,2] 变成一个列向量(3行1列)。
- col_indices = tl.arange(0, BT)[None, :] 会把它变成一个行向量(1行3列)。
tril_mask = (row_indices > col_indices).to(tl.float32)
Shape 变化(广播发生):这里执行了一个比较运算:[4, 1] 的矩阵 > [1, 4] 的矩阵。 神器的底层逻辑:Triton(或PyTorch)发现形状对不上,但一边是 4x1,另一边是 1x4,符合广播规则。于是它把两者都“拉长”成 [4, 4]。
row_indices 被横向复制了 4 次:
\(\begin{bmatrix} 0&0&0&0 \\ 1&1&1&1 \\ 2&2&2&2 \\ 3&3&3&3 \end{bmatrix}\)
col_indices 被纵向复制了 4 次:
\(\begin{bmatrix} 0&1&2&3 \\ 0&1&2&3 \\ 0&1&2&3 \\ 0&1&2&3 \end{bmatrix}\)
两两做 > 对比:
第0行:0>0(False), 0>1(False), 0>2(False)... 第1行:1>0(True), 1>1(False), 1>2(False)... 第2行:2>0(True), 2>1(True), 2>2(False)...
最终 tril_mask 的形状是 [4, 4](即 [BT, BT]),内容长这样(True转成了1.0,False转成了0.0):
\(\begin{bmatrix} 0 & 0 & 0 & 0 \\ 1 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 \end{bmatrix}\)
(看到了吗?这就是完美的下三角矩阵!表示第1个词不能看别人,第2个词只能看第1个,依此类推。如果你需要包含对角线,代码里应该写 >=)
性能优化¶
CV版实现¶
Ascend的硬件复杂,除了vector还有cube,如何利用好两者,写好CV版代码是进阶的要点。