2026-06-07
Triton 矩阵乘法深度拆解:从内存层次到分块策略
矩阵乘法是现代深度学习最核心的算子。从 Transformer 的 Attention 到 MLP 的投影层,底层都是矩阵乘法。如果只学一个 Triton kernel,那一定是 matmul。
本文结合 Triton 官方 Tutorial、知乎上的优秀笔记以及实际踩坑经验,从 GPU 内存层次出发,逐步拆解 Triton matmul 的设计思路。
本文参考了知乎博主「自然卷633」的《Triton 概念与编程入门笔记》。他在文中对指针算术、Group 分组策略的讲解非常清晰,感谢他的分享。
1. Triton 的双重身份
在开始写 kernel 之前,先搞清楚 Triton 到底是什么。
Triton 既是编程语言,也是编译器。作为编程语言,它是一门基于 Python 的 DSL(Domain Specific Language),让你用类 Python 语法编写 GPU kernel。作为编译器,它是 torch.compile 的默认后端,负责将 Torch IR 转化为高效的 GPU 代码。
和 CUDA 的关系可以这样理解:
- CUDA 像是微单相机:调整每一个曝光参数,极限画质但学习曲线陡峭。一个工业级 GEMM 用 CUDA 手写,优化到极致可能需要上千行代码。
- Triton 像是手机相机:拿起就拍,自动对焦,出片基本过关。同样的 matmul,Triton 大约 30 行就能达到 cuBLAS 的 80-90% 性能。
Triton 的编程哲学是 Block-wise Programming:Block 以上的——矩阵怎么分块、块之间怎么调度——归用户管;Block 内部的——线程怎么分配、寄存器怎么用——归编译器自动处理。
2. 为什么需要分块?GPU 内存层次
理解 tiling 之前,必须先理解 GPU 的内存体系。我们以 RTX 4070 Super 为例:
| 内存层级 | 带宽 | 延迟 | 容量 |
|---|---|---|---|
| Global Memory (HBM) | ~450 GB/s | ~400 cycles | 12 GB |
| L2 Cache | ~2 TB/s | ~200 cycles | 48 MB |
| Shared Memory (SRAM) | ~10 TB/s | ~20 cycles | 128 KB/SM |
| Register | ~40 TB/s | ~1 cycle | 256 KB/SM |
朴素 matmul 为什么慢?
对于 C = A @ B,其中 A 是 M×K,B 是 K×N:
for i in range(M):
for j in range(N):
for k in range(K):
C[i,j] += A[i,k] * B[k,j]
每个 A[i,k] 被加载了 N 次,每个 B[k,j] 被加载了 M 次。对于 M=N=K=4096 的 fp32 矩阵,朴素方法需要从 HBM 读取约 512 GB 的数据——但算力只需要 0.001 秒就能完成计算。这是典型的 memory-bound 问题:GPU 的计算单元在等数据。
Tiling 的核心思想
把矩阵切成小块(tile),每次将一个 tile 从慢速的 HBM 搬到快速的 SRAM,在 SRAM 上反复计算,算完再搬下一块:
每个 A 元素只从 HBM 读一次 → 总数据量从 512 GB 降到 ~256 MB
减少了 2000 倍的内存读取
3. Tiled Matmul 算法
Triton 的 Grid 是二维的:(cdiv(M, BLOCK_M), cdiv(N, BLOCK_N))。每个 program 独立计算 C 的一个 [BLOCK_M, BLOCK_N] 输出子块。
pid_m = program_id(0) # 在 M 方向第几个块
pid_n = program_id(1) # 在 N 方向第几个块
rm = pid_m * BLOCK_M + arange(0, BLOCK_M) # [BLOCK_M]
rn = pid_n * BLOCK_N + arange(0, BLOCK_N) # [BLOCK_N]
acc = zeros([BLOCK_M, BLOCK_N], float32)
for k in range(0, K, BLOCK_K):
rk = k + arange(0, BLOCK_K) # [BLOCK_K]
a_tile = A[rm, rk] # [BLOCK_M, BLOCK_K]
b_tile = B[rk, rn] # [BLOCK_K, BLOCK_N]
acc += dot(a_tile, b_tile) # Tensor Core 一次完成
C[rm, rn] = acc
K 维度的循环是唯一需要串行执行的部分——每次处理 BLOCK_K 列/行,用 Tensor Core 一次完成 [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N] 的矩阵乘法并累加。
4. 指针算术:从首地址到二维 Tile
Triton kernel 中怎么用指针访问二维矩阵?这是新手最容易困惑的地方。
PyTorch 中,一个张量的数据以一维数组按行连续存储在显存中。tensor.stride(0) 告诉你从第 i 行到第 i+1 行需要跳过多少个元素,tensor.stride(1) 则是列方向上的步长。对于行主序的 M×K 矩阵:
A.stride(0) = K # 沿 M 方向走一步跳过 K 个元素(一整行)
A.stride(1) = 1 # 沿 K 方向走一步跳过 1 个元素
那么 A[rm, rk] 这个二维 tile 的指针范围就是:
# rm: [BLOCK_M] 行索引, rk: [BLOCK_K] 列索引
# 注意 [:, None] 和 [None, :] 的广播方向!
a_ptrs = a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak
# 结果: shape [BLOCK_M, BLOCK_K] 的地址矩阵
这是整个 kernel 中最精巧的一行。把它拆开来看,广播过程是这样的:
rm 是形状 (BLOCK_M,) 的行索引向量,rk 是形状 (BLOCK_K,) 的列/内积轴索引向量。要得到每对 (m, k) 的地址,需要把它们分别扩展然后广播相加:
rm[:, None] → shape (BLOCK_M, 1) 行方向:每行一个不同的行号
rk[None, :] → shape (1, BLOCK_K) 列方向:每列一个不同的列号
相加广播后:
rm[:, None] * stride_am + rk[None, :] * stride_ak
→ shape (BLOCK_M, BLOCK_K) 每个元素是唯一的 (行,列) 地址偏移
具体来说,假设 BLOCK_M=2, BLOCK_K=3,rm = [10, 11],rk = [4, 5, 6],stride_am=100, stride_ak=1:
rm[:, None] = [[10], rk[None, :] = [[4, 5, 6]]
[11]]
rm[:, None] * 100 = [[1000], rk[None, :] * 1 = [[4, 5, 6]]
[1100]]
相加 → [[1004, 1005, 1006], 即 A[10,4] A[10,5] A[10,6]
[1104, 1105, 1106]] A[11,4] A[11,5] A[11,6]
一步广播得到了整个 tile 的地址矩阵,不需要任何显式循环。
K 维度循环时,只需要在指针上向前移动 BLOCK_K 个元素:
a_ptrs += BLOCK_K * stride_ak # stride_ak = 1 (行主序)
b_ptrs += BLOCK_K * stride_bk # stride_bk = N
不需要重新计算整个地址矩阵——这是性能关键。
5. 进阶:Group 分组与 L2 Cache 优化
基础的 tiled matmul 已经能跑出不错的性能,但离工业级还有距离。Triton 官方 Tutorial 引入了一个重要的优化:Group 分组。
问题出在 L2 Cache。GPU 的所有内存访问都必须经过 L2 Cache。如果按简单的行主序顺序处理 C 的各个子块,会导致相邻 program 加载的数据不能复用 L2 Cache 中的内容。
解决方法:将 C 矩阵沿 M 方向分组,每组包含 GROUP_SIZE_M 个 BLOCK_M 块。组内 program 共享 A 矩阵的同一个区域,提高 L2 命中率。
grid 从 2D 变成 1D: total_programs = cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)
每个 program 自己解码位置:
pid = program_id(0)
num_pid_m = cdiv(M, BLOCK_M)
num_pid_n = cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
以 M=N=K=2048, BLOCK_M=BLOCK_N=128, GROUP_SIZE_M=8 为例:
- Grid 总共 16×16=256 个 program
- 每个 Group 包含 8×16=128 个 program(M 方向 8 个 block,N 方向覆盖全部 16 列)
- 两个 Group 共享 A 矩阵的上半和下半区域
- 这比简单的行主序排序节省了约 40% 的 L2 Cache 数据换出
6. 完整实现骨架
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
):
pid = tl.program_id(0)
# ... Group 解码逻辑(见上节)...
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
rk = k * BLOCK_K + tl.arange(0, BLOCK_K)
a = tl.load(a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak,
mask=rk[None, :] < K, other=0.0)
b = tl.load(b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn,
mask=rk[:, None] < K, other=0.0)
acc = tl.dot(a, b, acc) # Tensor Core
# Store
mask_m = rm[:, None] < M
mask_n = rn[None, :] < N
tl.store(c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn,
acc, mask=mask_m & mask_n)
7. 你的代码 vs 工业级:差距在哪
基础 tiled matmul 在 RTX 4070S 上大约能达到 30-40 TFLOPS。cuBLAS 同硬件可达 100+ TFLOPS。差距来源:
| 优化项 | 基础版 | 工业级 | 大约提升 |
|---|---|---|---|
| Shared Memory 显式管理 | 依赖寄存器/编译器 | 手动分配 SRAM,减少寄存器压力 | 1.5-2x |
| Double Buffering | 串行加载+计算 | 加载下一块的同时计算当前块 | 1.3x |
| Bank Conflict 规避 | 无 | Padding + swizzle 模式 | 1.2x |
| L2 Cache 分组 | 无 | GROUP_SIZE_M 分组 | 1.3x |
| Auto-tuning | 固定 BLOCK 大小 | @triton.autotune 搜索最优配置 |
1.5x |
这些优化不需要一次全搞懂。先跑通基础版本,然后逐项叠加——每一个优化都会让你对 GPU 架构的理解深一层。
8. 常见陷阱速查
- 累加器用 float32:即使输入是 fp16,
acc必须是 float32。tl.dot的 dtype 由累加器决定。 - Mask 只加在必要的维度:A[rm, rk] 中 M 维不会越界(grid 已保证),只需 mask K 维度:
rk[None, :] < K。 - tl.dot 数值误差:Tensor Core 使用 TF32,与 PyTorch fp32 参考有 ~2% 误差。测试时用
atol=5e-2而非1e-4。 - BLOCK_K 不宜过小:太小(如 16)导致加载开销占比高。常见范围:32-128。
[:, None]/[None, :]方向:搞反了会导致 shape mismatch 的编译错误——也是最常见的调试时间杀手。
9. 延伸阅读
- Triton 官方 Matmul Tutorial — 本文核心参考
- 谈谈对 OpenAI Triton 的阅读 — Triton 编译链条解析
- 从 MLIR 理解 Triton — 代码生成流程
- 自然卷633: Triton 概念与编程入门笔记 — 本文重要参考,Group 分组的数值示例非常直观