2026-06-11

GPU Kernel 的 Tile、Grid、Wave 与 Launch Overhead

写 Triton kernel 的时候,BLOCK_MBLOCK_NBLOCK_K 这些参数到底怎么影响性能?为什么 autotune 给不同 shape 选不同的 tile 配置?

这篇文章从 CTA、grid、wave、launch overhead 讲起,到三个 shape 的 autotune 决策规律。

从工人到 CTA

每个 tl.program_id(0), tl.program_id(1) 对应一个 CTA(Cooperative Thread Array)。可以把它理解为一个"工人"——一个人拿一块砖:

# grid 决定一共派多少个工人
grid = (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N))

# BLOCK_M × BLOCK_N = 每个工人负责的输出子矩阵

一个 matmul C = A @ B,把 C 切成 BLOCK_M × BLOCK_N 的小块,每个工人负责一块。工人总数是 grid[0] × grid[1]

SM:GPU 的车间

工人不是凭空干活的。他们在 SM(Streaming Multiprocessor)上跑。

RTX 4070 SUPER: 56 个 SM
每个 SM 同时容纳的工人数量 = 受限于寄存器 + 共享内存

我们的 matmul kernel 每个工人约用 128 个寄存器。一个 SM 有 65536 个寄存器,但一个工人只能用 255 个(硬件上限)。

实际上一个 SM 同时能跑的工人数由 寄存器压力 决定:

每个工人 128 regs × 每个工人 4 warps × 32 threads = 16384 regs per worker
SM 有 65536 regs → 65536 / 16384 = 4 个工人/SM ✓

如果每个工人 200 regs: 200 × 4 × 32 = 25600 → 65536 / 25600 = 2 个工人/SM

所以 num_warps 和 tile 尺寸直接影响 occupancy(每 SM 同时跑的工人数)。

Wave:工人排队

如果工人的总数超过 SM 能同时容纳的量,超出的就得排队。wave 就是排队的轮数。

wave 数 = 工人总数 / (SM 数量 × 每 SM 同时工人数)

举个例子:

1024×1024 matmul, BLOCK_M=64, BLOCK_N=64
  工人 = 16 × 16 = 256 个
  每 SM 同时 4 个工人
  256 / (56 × 4) = 256 / 224 ≈ 1.1 wave  ← 几乎不用排队

2048×2048, BLOCK_M=64, BLOCK_N=64
  工人 = 32 × 32 = 1024 个
  1024 / 224 ≈ 4.6 wave  ← 排 5 轮

4096×4096, BLOCK_M=64, BLOCK_N=64
  工人 = 64 × 64 = 4096 个
  4096 / 224 ≈ 18.3 wave  ← 排 18 轮!

Launch Overhead:每次排队的代价

每个 wave 结束到下一个 wave 开始之间,GPU 调度器要:

  • 释放上一批工人的寄存器
  • 分配新一批工人的寄存器和共享内存
  • 加载新工人的 kernel arguments
  • 启动新工人执行

每次切换的时间是 固定的几十微秒。wave 越多,这部分开销累积越大。

在 4096² 的 18 个 wave 场景下,调度开销可以达到毫秒级。而 matmul 本身才 2ms——调度占了一小半。

三种优化:不同 Shape 不同策略

回到 BLOCK_MBLOCK_NBLOCK_K 的选择。

BLOCK_M 和 BLOCK_N 越大 → 工人越少 → wave 越少 → launch overhead 越低。

但 tile 不能无限大——受寄存器限制了。而且 tile 太大也有问题(对小的 M/N 维度会浪费 mask 计算)。

BLOCK_K 越大 → K 循环次数越少 → 从 global memory 加载的次数越少。

但 BLOCK_K 大意味着每个工人用更多共享内存(存 A tile 和 B tile),会限制 occupancy。

这些因素互相制约,最优配置取决于矩阵的 绝对大小(不只是形状)。

1024² — 工人太少,需要 wide tile

BLOCK_N=64:  256 工人, 1.1 wave
BLOCK_N=128: 128 工人, 0.6 wave ← 更少!但 wave 已经不是瓶颈

为什么还选 N=128?因为 wide tile 的算存比更高:

BLOCK_N=64:  (64+32) bytes load → 64×32×2 flops → 21 flops/byte
BLOCK_N=128: (128+32) bytes load → 128×32×2 flops → 26 flops/byte

小矩阵没有足够多工人填满 SM,也谈不上 launch overhead。瓶颈是 memory bandwidth——每次从 global memory 搬的 byte 能做的计算越多越好。

小矩阵的策略:提算存比。

2048² — 工人适中,deep K 减内存流量

BLOCK_K=32: 1024 工人, 4.6 wave, K 循环 2048/32 = 64 次
BLOCK_K=64: 1024 工人, 4.6 wave, K 循环 2048/64 = 32 次 ← 少跑 32 趟!

1024 个工人 × 每人 64 次 K 迭代 × 每次迭代加载 A tile + B tile × 2 bytes(fp16):

K=32: 1024 × 64 × (64×32 + 32×64) × 2 bytes ≈ 16.8 GB global memory traffic
K=64: 1024 × 32 × (64×64 + 64×64) × 2 bytes ≈ 16.8 GB  ← 一样!

等等,总 traffic 一样?对——问题不在总量,在每次 load 的效率和延迟

K=32: 每 tile 只有 64×32 个元素 = 2KB → 每次 load 很小,512 次 load/L1 请求
      加载延迟摊销在 2048 flops 上
K=64: 每 tile 有 64×64 个元素 = 4KB → 每次 load 更大,256 次 load/L1 请求
      加载延迟摊销在 4096 flops 上

更大的 tile 让每次 load 请求的延迟被更多计算覆盖。这是 latency hiding——GPU 的核心设计理念:用大量计算掩盖内存延迟。

中矩阵的策略:deep K 掩盖内存延迟。

4096² — 工人太多,必须降调度开销

BLOCK_N=64:  4096 工人, 18.3 wave → 调度开销 ~1ms+
BLOCK_N=128: 2048 工人, 9.1 wave  → 调度开销砍半

4096 个工人的调度开销已经超过 K 循环优化能带来的任何收益。工人"够多了"——多到成了一种负担。此时减少工人数量是第一优先级。

大矩阵的策略:减少工人,降调度开销。

一张图总结

矩阵规模   主导瓶颈         优化方向        autotune 选了什么
────────  ───────────────  ─────────────  ──────────────────
 1024²    memory BW        wide tile       N=128 (算存比 ↑)
 2048²    memory latency   deep K          K=64  (load 次数 ↓)
 4096²    launch overhead  fewer CTAs      N=128 (wave 数 ↓)

三个规模对应三种瓶颈,没有一种 tile 配置能通吃。

跟 CPU 编程的类比

如果你有 CPU 多线程背景:

  • CTA ≈ 一个线程(但有 128 个硬件线程在里面做 SIMT)
  • SM ≈ CPU 核心(SM 能"超卖"运行多个 CTA)
  • wave ≈ CPU 线程池的任务队列长度(队列短 → 核闲着;队列太长 → 调度开销大)
  • BLOCK_K ≈ 循环展开因子(展开多 → 循环次数少,但寄存器压力大)
  • BLOCK_M/ BLOCK_N ≈ 任务粒度(粗粒度 → 任务少;细粒度 → 任务多但管理成本高)

GPU 和 CPU 都面临同一个根本问题:如何平衡任务数量、每任务工作量、资源占用。GPU 的 twist 是:它是一个大规模并行机器,队列管理成本在极端 worker 数量下不可忽略。

检验理解

试着预测:给一个 M=512, N=512, K=512 的矩阵,最优配置大概是怎样的?然后再看 M=8192, N=8192, K=8192

答案在文章末尾。


512³:工人 = 512/64 × 512/64 = 64 个。太少了——物理 SM 都填不满(56 SM × 4 workers = 需要 224+ 工人才饱和)。此时瓶颈极端是 memory bandwidth,最优配置可能是 BLOCK_N=256(极 wide tile)。甚至 BLOCK_M 也可能需要拉大——因为 M 太小了,grid 只有 8 个 CTA 沿 M。

8192³:工人 = 128×128 = 16384 个。16384 / 224 ≈ 73 wave。调度开销已到荒谬级别。最优配置一定是 BLOCK_M=256, BLOCK_N=256,把工人压到 32×32=1024 个(≈ 4.6 wave)。