2026-06-07
FlashAttention:为什么它改变了 Transformer 推理的游戏规则
Transformer 的 Self-Attention 是整个架构的核心,也是最大的显存黑洞。FlashAttention 把 Attention 的整个计算链融合进一个 GPU kernel,把 O(n²) 的显存开销降到 O(n)。这是 Triton 生态中最著名的算法,也是理解「为什么要写 fused kernel」的最佳案例。
本文从标准 Attention 的显存瓶颈出发,拆解 FlashAttention 的 tiling + online softmax 设计,最后给出 Triton 实现的完整骨架。
1. 标准 Attention 的显存灾难
无论 Self-Attention 还是 Cross-Attention,核心计算都一样:
S = Q @ K^T # [B, H, seq_len, seq_len]
P = softmax(S)
O = P @ V
问题出在 S 矩阵。它的大小是 O(n²),并且是整个计算链的必经中间产物:
| seq_len | 单个 head 的 S 矩阵 (fp16) | B=1, H=16 的总量 |
|---|---|---|
| 1024 | 2 MB | 32 MB |
| 4096 | 32 MB | 512 MB |
| 8192 | 128 MB | 2 GB |
| 32768 | 2 GB | 32 GB |
RTX 4070S 只有 12 GB 显存。seq_len=8192 时,仅 S 矩阵就吃掉 2 GB——更不用说还要存 Q、K、V、模型参数、KV Cache。实际推理中,seq_len 超过 4096 就开始吃力。
本质问题:Attention 是 memory-bound,不是 compute-bound。GPU 的计算单元大部分时间在等数据。
2. FlashAttention 的核心思路
FlashAttention 的洞察非常简单:我们不需要存 S 矩阵。
S 矩阵的唯一作用是计算 softmax 的分子(exp(S))和分母(sum(exp(S)))。如果我们能在计算 S 的每一块时,当场完成 softmax 并累加到输出 O,那 S 矩阵就不需要保存。
具体做法:
- Tiling:把 Q 沿序列维度切成大小为 BLOCK_M 的块,把 K/V 切成大小为 BLOCK_N 的块
- Online Softmax:对每个 Q 块,逐块处理 K/V,维护 running max 和 running denominator
- Fused Write:算完一个 Q 块的所有 K/V 块后,一次写回输出 O
伪代码:
对于每个 Q 块 [BLOCK_M, head_dim](外层循环):
m = [-inf] # running max,每个 query 行一个
d = [0] # running denominator
acc = 0 # 累加输出 [BLOCK_M, head_dim]
对于每个 K,V 块 [BLOCK_N, head_dim](内层循环):
S_j = Q_block @ K_block^T * scale # [BLOCK_M, BLOCK_N]
m_new = max(m, max(S_j, axis=1)) # 更新 max
P = exp(S_j - m_new) # 稳定化
d_new = d * exp(m - m_new) + sum(P) # 更新分母
acc = acc * (d/d_new) * exp(m-m_new) + P @ V_block / d_new
m, d = m_new, d_new
存储 acc → output
S_j(形状 [BLOCK_M, BLOCK_N])在每一步都在寄存器/L1 里计算和消费,从未写入显存。这就是 "Flash" 的含义——一闪而过。
而 online softmax 的数学保证:无论中间累加了多少块,最终结果和一次性算完整 softmax 完全等价。
3. Grid 布局:1D self-decoding
和矩阵乘法的 2D grid 不同,FlashAttention 用 1D grid + 自解码:
grid = (cdiv(seq_len, BLOCK_M) * B * H,)
# 每个 program 自己解码:(batch, head, block_m_start)
pid = program_id(0)
num_m_blocks = cdiv(seq_len, BLOCK_M)
bh_id = pid // num_m_blocks # 第几个 (batch, head) 对
block_m = pid % num_m_blocks # Q 的第几个行块
b_idx = bh_id // H # 解码 batch id
h_idx = bh_id % H # 解码 head id
例如 seq_len=256, BLOCK_M=64, B=2, H=4:
- num_m_blocks = 4
- 总 programs = 4 × 2 × 4 = 32
- pid=10 → bh_id=2, block_m=2 → 第0个batch的第2个head,Q 的第2个行块
4. Triton 实现骨架
@triton.jit
def flash_attention_fwd_kernel(
q_ptr, k_ptr, v_ptr, o_ptr,
stride_qb, stride_qh, stride_qm, stride_qd,
stride_kb, stride_kh, stride_kn, stride_kd,
stride_vb, stride_vh, stride_vk, stride_vd,
stride_ob, stride_oh, stride_om, stride_od,
B, H, seq_len, head_dim,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
sm_scale: tl.constexpr,
):
# 1. 解码 pid
pid = tl.program_id(0)
num_m_blocks = tl.cdiv(seq_len, BLOCK_M)
bh_id = pid // num_m_blocks
block_m = pid % num_m_blocks
b_idx = bh_id // H
h_idx = bh_id % H
# 2. 偏移量和 Q block 加载
offs_m = block_m * BLOCK_M + tl.arange(0, BLOCK_M) # [BLOCK_M]
offs_n = tl.arange(0, BLOCK_N) # [BLOCK_N]
offs_d = tl.arange(0, head_dim) # [head_dim]
q = tl.load(q_ptr + (b_idx * stride_qb + h_idx * stride_qh
+ block_m * BLOCK_M * stride_qm)
+ offs_m[:, None] * stride_qm
+ offs_d[None, :] * stride_qd,
mask=offs_m[:, None] < seq_len, other=0.0)
# 3. Online softmax 初始化
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, head_dim], dtype=tl.float32)
# 4. 循环 K/V 块
for block_n in range(0, tl.cdiv(seq_len, BLOCK_N)):
# 加载 K block
k = tl.load(k_ptr + (b_idx * stride_kb + h_idx * stride_kh
+ block_n * BLOCK_N * stride_kn)
+ offs_n[:, None] * stride_kn
+ offs_d[None, :] * stride_kd,
mask=offs_n[:, None] < seq_len, other=0.0)
# S = Q @ K^T * scale
s = tl.dot(q, tl.trans(k)) * sm_scale # [BLOCK_M, BLOCK_N]
# Online softmax update
m_new = tl.maximum(m_i, tl.max(s, axis=1))
alpha = tl.exp(m_i - m_new)
p = tl.exp(s - m_new[:, None])
d_new = alpha * d_i + tl.sum(p, axis=1)
# 加载 V block + 累加输出
v = tl.load(v_ptr + (b_idx * stride_vb + h_idx * stride_vh
+ block_n * BLOCK_N * stride_vk)
+ offs_n[:, None] * stride_vk
+ offs_d[None, :] * stride_vd,
mask=offs_n[:, None] < seq_len, other=0.0)
acc = acc * alpha[:, None] * (d_i / d_new)[:, None]
acc += tl.dot(p, v) / d_new[:, None]
m_i = m_new
d_i = d_new
# 5. 写回输出
tl.store(o_ptr + (b_idx * stride_ob + h_idx * stride_oh
+ block_m * BLOCK_M * stride_om)
+ offs_m[:, None] * stride_om
+ offs_d[None, :] * stride_od,
acc, mask=offs_m[:, None] < seq_len)
5. 常见陷阱
- K 要转置:
S = Q @ K^T,用tl.trans(k)而不是.T - dtype 对齐:Q/K/V 是 fp16,
acc必须是 float32。tl.dot(Q, K)要求两个操作数同 dtype - acc 更新顺序:先 rescale 旧值,再加新贡献——写反了就全错
- Mask 在正确的维度:Q 的 mask 是
offs_m[:, None] < seq_len,K/V 是offs_n[:, None] < seq_len - 省略 logsumexp:生产代码需要存 LSE 给 backward 用。仅 forward 时可以跳过
6. v1 → v2 → v3 演进
| 版本 | 年份 | 关键改进 | 硬件要求 |
|---|---|---|---|
| FlashAttention v1 | 2022 | Tiling + online softmax + kernel fusion | 任意 GPU |
| FlashAttention v2 | 2023 | 更好的并行策略,减少非 matmul 操作 | Ampere+ |
| FlashAttention v3 | 2024 | TMA 异步拷贝、FP8、Hopper 专属优化 | H100 以上 |
我们实现的是 v1 的 forward pass——架构最清晰,最适合理解核心思想。