跳转至

Triton & Triton Ascend

导言

  • Ascend上训练编译成全图有功能问题,导致下发问题并不能像GPU一样完全解决;
  • 在浦江实验室的经验是,triton确实能快速拿到2~3倍的收益,如果算子还有问题就能考虑

核心概念

Grid

基础语法

编译控制

@triton.heuristics(...)`:智能推导(元编程)

一句话解释:根据传入内核的参数,在编译前自动计算出一些静态的布尔标志位(Flags)或常量。

@triton.heuristics({
    'USE_G': lambda args: args['g'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
  • 当你在 Python 里调用这个 Triton Kernel 时,会传入一堆参数(比如张量 g,张量 cu_seqlens 等)。
  • triton.heuristics 会在真正把代码丢给 GPU 编译之前,先运行这几个 lambda 匿名函数。
  • args['g'] is not None:如果你调用时传了 g 张量,USE_G 就会被硬编码为 True;如果没传,就是 False
  • args['cu_seqlens'] is not None:同理,如果传了累积序列长度(通常用于变长序列,Variable Length Attention),IS_VARLEN 就是 True

为什么要这样做?(核心机密:死代码消除)

在 GPU 编程中,在最内层循环写 if 分支是非常影响性能的。 通过 heuristicsUSE_G 变成了一个编译期常量(constexpr。 * 如果 USE_G = False,Triton 编译器在编译时,会直接把代码里所有 if USE_G: 下面的代码全部删掉(Dead Code Elimination, 死代码消除)! * 这样,你只需要写一份代码,就能自动生成“带门控”和“不带门控”两套极其纯净、没有任何 if 开销的 GPU 机器码。

@triton.jit(do_not_specialize=['T']):编译与特化控制

一句话解释:告诉 Triton 编译器:“这是一个 GPU 函数,请帮我编译它。但是,请千万不要对参数 T 进行死板的过度优化!

@triton.jit(do_not_specialize=['T'])

JIT 就是 Just-In-Time(即时编译)。它把 Python 语法翻译成底层的 GPU 汇编指令(PTX 或 SASS)。这是写 Triton 必带的装饰器。

do_not_specialize?(特化控制)是处理动态 Shape 时的救命稻草

  • Triton 的默认行为(疯狂特化):为了让 GPU 跑得最快,Triton 极其激进。如果你传进来的参数 T(比如序列长度 Sequence Length)这次是 1024,Triton 就会针对 1024 这个数字编译出一个专属的二进制程序。

    • 这很好,因为编译器知道确切数字,能做好多极致的底层优化(比如循环展开)。
    • 但是灾难来了:如果下一轮训练,你传进来的 T 变成了 1025,或者 1028。Triton 一看:“哎呀,数字变了,之前的 1024 专属程序不能用了!” 于是它会从头开始重新编译一次代码
    • 编译是非常非常慢的(可能需要零点几秒甚至几秒)。如果你每次输入的序列长度都不一样,模型就会因为疯狂重新编译(Re-compilation Storm)而卡得像 PPT 一样。
  • 加上 do_not_specialize=['T']

    • 你明确告诉了 Triton 编译器:“老哥,T 这个变量是一个普通的动态数字,它随时会变。请把它当成一个普通的寄存器变量,不要把它硬编码(Hardcode)进二进制程序里!
    • 这样,不管接下来你的 T 是 1024 还是 2048,只要数据类型没变,Triton 就会一直复用第一次编译好的那套 GPU 代码,彻底解决了运行过程中的卡顿问题。

Memory & Control

  • tl.cdiv(K, BK) 向上取整除法 (Ceiling Division)。相当于数学上的 \(\lceil K / BK \rceil\),或者 Python 里的 (K + BK - 1) // BK。
  • tl.make_block_ptr(...) 数据的基地址(base),整个大张量的形状(shape),大张量在内存里是怎么跨步的(strides),你要切的块有多大(block_shape),以及当前切块的起始坐标(offsets)。
  • tl.load(ptr, boundary_check=...) 从全局显存(VRAM)加载数据到 GPU 的极速寄存器(SRAM)中。把上面定义好的指针 p_k 对应的数据块加载进来,变成一个二维张量 b_k。boundary_check=(0, 1) 是它的神仙功能:如果你的数据大小不是块大小的整数倍(比如最后一块超出了边界),它会自动帮你把越界的部分填充为 0,再也不用手动写 mask 了!

tl.make_block_ptr

tl.make_block_ptr 出现之前,写 Triton 代码简直就是一场“计算指针偏移量”的噩梦,稍有不慎就会遇到内存越界(Segfault)直接导致程序崩溃。而 tl.make_block_ptr 就像是一个智能的“取景框”,你只需要告诉它大背景是什么样,取景框多大,放在哪,它就会自动帮你把里面的数据完好无损地搬运到显存(SRAM)中。

在 NVIDIA 的 Hopper 架构(比如 H100)上,这个 API 甚至会直接调用硬件级别的 TMA (Tensor Memory Accelerator) 单元,速度极快。

我们把你之前代码里的这行经典调用拿出来,把它“解剖”开:

# K shape [B, T, H, K] 1, 65536, 32, 128
p_k = tl.make_block_ptr(
    base=k + k_batch_off + (bos * H + i_h) * K, 
    shape=(T_local, K), 
    strides=(H * K, 1), 
    offsets=(i_t * BT, i_k * BK), 
    block_shape=(BT, BK), 
    order=(1, 0)
)

这个函数总共有 6 个核心参数。为了方便理解,你可以想象我们有一张巨大的清明上河图(整个大张量),而我们要拿着一个小放大镜(Block)去截图。


  1. base (基地址)

  2. 字面意思:大张量在 GPU 全局内存(HBM)中的物理起始地址。

  3. 工程师大白话“画卷的绝对起点在哪?”
  4. 在代码中k + k_batch_off + (bos * H + i_h) * K

    • 这里通常不仅是整个大矩阵的首地址(k),往往还包含了 batch 偏移量(k_batch_off)和 多头注意力(Multi-Head)的头偏移量((bos * H + i_h) * K)。
    • 你可以理解为:在这个大矩阵中,我们已经把大本营扎在了当前 Batch、当前 Head 的第 0 行第 0 列的位置。
  5. shape (全局形状)

  6. 字面意思:当前你要操作的这个“完整数据张量”的全局大小(通常是二维或多维的 Tuple)。

  7. 工程师大白话“这幅画的总长度和总宽度是多少?”
  8. 在代码中(T_local, K)

    • 告诉 Triton,我们现在面对的是一个总共有 T_local 行(比如序列长度 1024),总共有 K 列(比如特征维度 128)的巨大矩阵。
    • 为什么需要这个? 主要是为了安全。Triton 知道大矩阵的边界后,当你截取到边缘时,它配合 boundary_check 就能自动进行安全填充(Pad 0),防止越界。
  9. strides (跨步/步长)

  10. 字面意思:在内存这一维的线性空间里,你想在这个矩阵中移动 1 行 或 1 列,内存指针需要跳过多少个元素?

  11. 工程师大白话“物理内存是一条直线,怎么折叠成这幅二维画的?”
  12. 在代码中(H * K, 1)

    • 这是一个元组 (stride_row, stride_col)
    • stride_col = 1:在同一行里,向右走 1 格,内存地址加 1。这说明数据在内存里是连续挨着存的(这叫 Row-Major,行优先)。
    • stride_row = H * K:如果我想走到下一行,内存指针要跳过 H * K 个元素(跳过所有头在这个 Token 上的数据)。
  13. offsets (切块偏移量)

  14. 字面意思:我们要取的这一个小 Block,它的左上角在整个大矩阵里的相对坐标是多少。

  15. 工程师大白话“放大镜的左上角对准画里的哪个坐标?”
  16. 在代码中(i_t * BT, i_k * BK)

    • 由于外面有 for 循环,坐标是动态的。
    • 行坐标:当前处理的 token 块索引 i_t 乘以块大小 BT(比如第 2 块的起始行就是 2 * 64 = 128)。
    • 列坐标:当前处理的维度块索引 i_k 乘以维度大小 BK(比如第 1 块的起始列就是 1 * 32 = 32)。
  17. block_shape (块形状)

  18. 字面意思:我们要从显存里一口气取进寄存器(或共享内存 SRAM)的小数据块的大小。

  19. 工程师大白话“放大镜的框到底有多大?”
  20. 在代码中(BT, BK)

    • 取出一个 BT 行、BK 列的小矩阵(比如 64 x 32)。
    • 注意block_shape 里的值必须是 2 的幂次方(16, 32, 64, 128...)并且在编译时就要确定,不能是动态变量。
  21. order (连续性顺序 / 内存读取模式)

  22. 字面意思:定义这个 Block 里的数据在物理内存上,哪个维度的元素是紧紧挨在一起的(连续的)。

  23. 工程师大白话“GPU 加载数据时,最内层的高速循环应该沿着哪个方向跑?”
  24. 在代码中(1, 0)
    • 这是一个非常影响性能(Performance)的参数,决定了内存的“合并访问(Memory Coalescing)”效率。
    • 数字代表维度的索引。对于 2D 张量,0 是行维度,1 是列维度。
    • order=(1, 0) 的意思是:第 1 维(列维度)变化最快,第 0 维(行维度)变化最慢。换句话说,数据是按行读取的(即每行的数据在物理上挨在一起)。
    • 反过来,如果数据在内存里是列优先存储的(比如经过转置的矩阵),你需要传入 (0, 1),以保证 GPU 一次性能“吞”下连续的内存块,而不是像跳蚤一样在内存里乱跳。

一张图总结

把它们拼接在一起,make_block_ptr 就在 GPU 后台做了一件这样的事:

[ 全局显存 HBM ]
Base: 地址起点
+----------------------------------------+
| 大矩阵 Shape: (T_local, K)             |
| 步长 Strides: (H*K, 1)                 |
|                                        |
|        [Offsets: (i_t*BT, i_k*BK)]     |
|           +---------------+            |
|           |               |            |
|           | 小块 (BT, BK) | <-- Order: |
|           | (Block_shape) |   (1,0)横向|
|           +---------------+            |
|                                        |
+----------------------------------------+
配合上下一行的 tl.load(p_k, boundary_check=(0, 1)): * 如果:你的框(放大镜)跑到了画卷的外面(比如最后一块的尺寸不够 BTBK 了)。 * 结果:Triton 会因为看到了 boundary_check=(0, 1),自动查阅 shapeoffsets,把超出的部分智能用 0.0 填补,这就是 make_block_ptr 被誉为神级 API 的原因。再也不用手写繁琐的越界 Mask 了!

Math & Matrix

  • tl.trans(b_k) 矩阵转置 (Transpose)。
  • tl.dot(b_k, tl.trans(b_k)) 矩阵乘法

Indexing & Masking

  • tl.arange(start, end) 的语法作用与 NumPy (np.arange) 或 PyTorch (torch.arange) 非常相似:它用于生成一个包含连续整数的 一维 Tensor(张量)。
  • [:, None] 和 [None, :]
    • row_indices = tl.arange(0, BT)[:, None] 会把 [0,1,2] 变成一个列向量(3行1列)。
    • col_indices = tl.arange(0, BT)[None, :] 会把它变成一个行向量(1行3列)。

tril_mask = (row_indices > col_indices).to(tl.float32)

Shape 变化(广播发生):这里执行了一个比较运算:[4, 1] 的矩阵 > [1, 4] 的矩阵。 神器的底层逻辑:Triton(或PyTorch)发现形状对不上,但一边是 4x1,另一边是 1x4,符合广播规则。于是它把两者都“拉长”成 [4, 4]。

row_indices 被横向复制了 4 次:

\(\begin{bmatrix} 0&0&0&0 \\ 1&1&1&1 \\ 2&2&2&2 \\ 3&3&3&3 \end{bmatrix}\)

col_indices 被纵向复制了 4 次:

\(\begin{bmatrix} 0&1&2&3 \\ 0&1&2&3 \\ 0&1&2&3 \\ 0&1&2&3 \end{bmatrix}\)

两两做 > 对比:

第0行:0>0(False), 0>1(False), 0>2(False)... 第1行:1>0(True), 1>1(False), 1>2(False)... 第2行:2>0(True), 2>1(True), 2>2(False)...

最终 tril_mask 的形状是 [4, 4](即 [BT, BT]),内容长这样(True转成了1.0,False转成了0.0):

\(\begin{bmatrix} 0 & 0 & 0 & 0 \\ 1 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 \end{bmatrix}\)

(看到了吗?这就是完美的下三角矩阵!表示第1个词不能看别人,第2个词只能看第1个,依此类推。如果你需要包含对角线,代码里应该写 >=)

性能优化

CV版实现

Ascend的硬件复杂,除了vector还有cube,如何利用好两者,写好CV版代码是进阶的要点。