2026-06-08
Triton 进阶路线图:LayerNorm、Fused MLP 与 Autotune
Triton 基础算子写完之后,下一步往哪走?本文梳理三个进阶方向,构成从"能写 kernel"到"能写生产级 kernel"的桥梁。
当前进度回顾
假设你已经完成了:
- Element-wise(ReLU / GELU / SiLU):理解 memory coalescing、memory-bound 的含义
- Softmax:online stable 算法、跨线程 reduction
- Matmul:tiling、
tl.dot(Tensor Core)、grid 设计 - FlashAttention:fused kernel、online softmax 在 attention 中的应用
- Profiling:GPU event timing、roofline model、arithmetic intensity 分析
这六个 exercise 已经覆盖了 Triton 的核心原语。接下来三个方向各有侧重:
| 方向 | 难度 | 学到什么 | 实际价值 |
|---|---|---|---|
| LayerNorm / RMSNorm | ⭐⭐ | 跨 block reduction、Welford 算法 | 每个 Transformer block 至少两次 |
| Autotune | ⭐⭐ | Config 空间搜索、寄存器压力、occupancy | 所有 kernel 的最终优化手段 |
| Fused MLP | ⭐⭐⭐ | Kernel fusion、tile 间依赖管理 | 端到端推理延迟的直接优化 |
方向一:LayerNorm & RMSNorm
为什么重要
一个 7B 模型的 transformer block 里,LayerNorm 只占约 5-15% 的算力,但它无处不在——每个 attention 和 MLP 前后都有 normalization。在长序列推理中(如 128K context),norm kernel 的调用次数可能超过 attention 本身。
算法本质
LayerNorm(x) = (x - mean(x)) / sqrt(var(x) + ε) * γ + β
RMSNorm(x) = x / sqrt(mean(x²) + ε) * γ # 无 β,无减均值
两者都是逐行独立的——每行(每个 token)独立计算统计量,天然适合 GPU 并行。
核心挑战:跨 Block 的 Reduction
如果一行数据(hidden_dim,通常 768-8192)大于一个 BLOCK,就需要跨 block 合并统计量。Welford 算法可以在单次遍历中同时计算均值和方差,但跨 block 合并需要额外的 kernel launch:
Kernel 1: 每个 block 计算自己的 (count, mean, M2),存到临时 buffer
Kernel 2: 合并临时 buffer,计算全局统计量
Kernel 3: 用全局统计量做 normalize + affine
不过大多数 LLM 的 hidden_dim ≤ 1024(如 LLaMA-7B 的 4096 通过 tensor parallelism 分片后),一行可以放进一个 block,单 kernel 搞定。
关键差异
- Welford 并行合并公式:需要正确处理两组统计量的合并,和简单求和不同
- RMSNorm vs LayerNorm:少一次 mean reduction,约快 10-15%
- Affine 参数:γ 和 β 存在 global memory 里,每次都要 load——可以 fuse 进 kernel
方向二:Autotune
为什么手选参数不够
写完 matmul 后,你手动选了 BLOCK_M=64, BLOCK_N=64, BLOCK_K=32。这组参数对特定形状可能不错,但换一个 M×N×K 组合就不一定了。
Triton 的 @triton.autotune 可以自动搜索参数空间,找到每个输入形状的最优配置。
工作原理
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_warps=8),
# ... 更多 config
],
key=['M', 'N', 'K'], # 缓存的键
)
@triton.jit
def matmul_kernel(...):
...
首次运行时,Triton 编译所有 config,逐个 benchmark,缓存最优结果。后续调用直接使用缓存。
参数空间的设计直觉
- BLOCK_M × BLOCK_N:越大 → 每次 load 做更多计算 → 更高 AI。但受限于寄存器(~256KB / SM for fp32 accumulators)
- num_warps:4 warps × 4 CTAs = 16 warps/SM。8 warps × 2 CTAs = 也是 16 warps/SM。但 8 warps 的 block 更大 → 更多寄存器 per block → 可能降低 occupancy
- num_stages:pipeline stages。2 → 更少 shared memory 占用。4 → 更好隐藏访存延迟,但占更多 shared memory
Profiling 与 Autotune 的关系
Profiling 告诉你"哪里慢了"(memory-bound vs compute-bound),Autotune 帮你"自动修"。两者配合使用才是完整的优化流程。
方向三:Fused MLP
为什么需要 fusion
标准的 Transformer MLP:
x = Linear_1(input) # [N, D] → [N, 4D] ← 写回 global memory(慢)
x = GELU(x) # [N, 4D] ← 读 + 写
x = Linear_2(x) # [N, 4D] → [N, D] ← 读 + 写
每一步的中间结果都写回 global memory(GDDR6X 504 GB/s),但 GPU 的 L1 cache 和 shared memory 速度是它的 10 倍(5 TB/s)。
Fusion 的思路:在寄存器 / shared memory 里完成全部计算,只写一次最终结果。
Fused: load → matmul_1 → GELU → matmul_2 → store
↑_____只走一次 global memory 来回_____↑
数据流设计
关键的约束是寄存器容量。中间值 acc1 是 [BLOCK_M, 4K],以 4K = 16384(LLaMA-7B 的 intermediate size)为例,64 行 × 16384 列 × 4 字节 = 4 MB,远超寄存器容量。
解决方法是把中间维度也分 tile:
for k4_block in range(0, 4*K, BLOCK_K4):
# Step 1: acc1 = x_tile @ W1_chunk → [BLOCK_M, BLOCK_K4]
# Step 2: acc1 = GELU(acc1) → 就地修改,寄存器内
# Step 3: acc2 += acc1 @ W2_chunk → 立即消费掉 acc1
这是一个 tiled-gemm → elementwise → tiled-gemm 的流水线。W1 和 W2 的加载顺序需要对齐,确保每次只加载当前 chunk 需要的权重。
实际收益
以 LLaMA-7B 为例,一次 MLP forward 需要读写:
- 未 fusion:x(2MB) + x1(8MB write + 8MB read) + x2(2MB write) = 20 MB
- Fusion 后:x(2MB read) + x2(2MB write) = 4 MB
节省 16 MB 的 memory traffic。在 batch inference 中,这直接降低延迟。
建议学习顺序
Autotune 先做——这是工具技能,给现有的 matmul 加上 autotune。学会后,给 LayerNorm 和 Fused MLP 加 autotune 就是顺手的事。运行
TRITON_PRINT_AUTOTUNING=1看所有 config 的 benchmark 结果,理解为什么某些 config 赢、某些输。LayerNorm / RMSNorm——相对简单,跨 block reduction 是唯一难点。先实现"一行能放进一个 block"的简单版本(
D ≤ 1024),跑通后再扩展到更大的D。Fused MLP 最后——最复杂,涉及 tile 间依赖和流水线管理。建议先用 PyTorch 写 reference(确保算法正确),再逐步翻译成 Triton。
资源
所有练习代码和理论笔记在 github.com/freezetheflame/triton-benchmark-prep:
exercises/
06_autotune.py # Autotune 骨架,待你填完
notes/
layernorm-rmsnorm.md # LayerNorm/RMSNorm 理论 + Welford 算法
fused-mlp.md # Fused MLP 设计 + 伪代码
每个 note 都包含算法推导、伪代码、自测框架和常见陷阱——不提供完整实现,只给你搭好脚手架。