2026-06-29
多头注意力完全图解:MHA/MQA/GQA/MLA的头部关系与数量约束
如果你用过 LLaMA、Mistral、Qwen 或者 DeepSeek-V2,肯定在它们的 config.json 里见过 num_attention_heads 和 num_key_value_heads 这两个参数。前者等于 32 或 64,后者可能是 8、16 或者跟前者一样——这不是随便选的数字,这两者的关系直接决定了推理时的 KV Cache 大小和加速比。
本文用三张 Excalidraw 图解,把 MHA(标准多头注意力)、MQA(多查询注意力)、GQA(分组查询注意力)和 MLA(多头潜在注意力)的头部数量关系、整除约束和数值例子讲清楚。
📐 引子:注意力「头」到底是什么?
在 Transformer 中,每个 attention head 是一个独立的"视角":
其中:
- **Q (Query)**:当前 token 发出的"我关注谁"
- **K (Key)**:所有 token 的"我是谁"
- **V (Value)**:所有 token 的"我有什么"
多个 attention head 的结果拼接后通过输出投影融合:
所有变体(MHA/MQA/GQA/MLA)的区别,本质上是 Q、K、V 的头数关系不同。
🧩 第一章:标准多头注意力 (MHA)
基本概念
MHA(Multi-Head Attention)是 Transformer 原始论文中提出的方案。每个 attention head 拥有独立的 Q、K、V 投影矩阵。
数量关系
| 变量 | 符号 | MHA 中的值 |
|---|---|---|
| Q 头数 | ||
| K 头数 | ||
| V 头数 | ||
| 总注意力头数 | 模型配置参数 |
核心关系:
参数规模
每个 head 的维度 ,投影矩阵:
一个数值例子:d_model=4096, H=32, d_h=128
- 每个 head 的 Q/K/V 维度:128
- 总参数:3 × (4096 × 4096) ≈ 50M 参数(仅注意力投影层)
- KV Cache 大小(每层):2 × H × d_h × seq_len × 精度字节 = 2 × 32 × 128 × seq_len × 2 = 16,384 × seq_len 字节
推理瓶颈
MHA 每生成一个 token 都需要缓存所有 H 个头的 K 和 V 矩阵。当 batch size 和序列长度增加时,KV Cache 占用线性增长,成为推理时的主要显存瓶颈。
🔗 第二章:多查询注意力 (MQA)
动机
MQA(Multi-Query Attention)由 Shazeer 在 2019 年提出,核心观察是:解码时 K/V 缓存的带宽压力远大于计算压力。如果能减少 K/V 头数,KV Cache 就能大幅缩减。
数量关系
MQA 将所有 Q head 映射到同一个 K head 和同一个 V head:
| 变量 | MQA 中的值 | 与 MHA 对比 |
|---|---|---|
| 不变 | ||
| 1 | 减少到 1/H | |
| 1 | 减少到 1/H |
KV Cache 节省
MQA 的 KV Cache 大小只相当于 MHA 的 1/H:
MHA KV Cache / 层 = 2 × H × d_h × seq_len × 2 bytes (FP16)
MQA KV Cache / 层 = 2 × 1 × d_h × seq_len × 2 bytes
以 H=32 为例,MQA 的 KV Cache 仅为 MHA 的 **3.125%**。
代价
共享 K/V 会降低模型的表达能力。在大模型场景下,MQA 的质量损失显著——尤其在不同 head 需要关注不同语义特征时,强制所有 head 共享同一个 K/V 过于激进。
👥 第三章:分组查询注意力 (GQA)
GQA 的折中方案
GQA(Grouped Query Attention)由 Ainslie 等人 2023 年提出,是 MHA 和 MQA 的中间地带。
数量关系
| 变量 | GQA 中的值 | 说明 |
|---|---|---|
| 不变 | ||
| 分组数 (1 ≤ G ≤ H) | ||
| 与 K 相同 |
核心约束:,即 Q 头数必须能被分组数整除。
每组的 Q 头数:
边界情况
| G 的值 | 等价于 | 说明 |
|---|---|---|
| G = 1 | MQA | 所有 Q 共享 1 个 K/V |
| G = H | MHA | 每个 Q 有独立的 K/V |
| G = 2, 4, 8 | GQA | 折中方案,主流选择 |
工业界实践
| 模型 | d_model | (G) | 每组 Q 头数 | KV Cache 节省 | |
|---|---|---|---|---|---|
| LLaMA 1 65B | 8192 | 64 | 64 | 1 (MHA) | 1× |
| LLaMA 2 7B | 4096 | 32 | 32 | 1 (MHA) | 1× |
| LLaMA 2 70B | 8192 | 64 | 8 | 8 | 8× |
| LLaMA 3 8B | 4096 | 32 | 8 | 4 | 4× |
| LLaMA 3 70B | 8192 | 64 | 8 | 8 | 8× |
| LLaMA 3.1 405B | 16384 | 128 | 8 | 16 | 16× |
| Mistral 7B | 4096 | 32 | 8 | 4 | 4× |
| Mistral Large (2) | 12288 | 96 | 8 | 12 | 12× |
| Mixtral 8×7B | 4096 | 32 | 8 | 4 | 4× |
| Qwen2 72B | 8192 | 64 | 16 | 4 | 4× |
| Gemma 2 9B | 3584 | 16 | 8 | 2 | 2× |
| Gemma 2 27B | 4608 | 32 | 16 | 2 | 2× |
可以看到,G=8 是目前最流行的配置——LLaMA 70B、Mistral、Mixtral 都选了它。LLaMA 3.1 405B 进一步把 G 降到 8(H=128→每组 16 个 Q 头),以照顾超大规模推理的 KV Cache 开销。
为什么是 G=8 最流行?
这是一个经验权衡:
G 越小 → KV Cache 越小 → 推理越快 → 但质量越差
G 越大 → 质量越好 → 但 KV Cache 越大 → 推理越慢
G=8 在质量和速度之间找到了一个较好的平衡点。
🔢 第四章:数值约束详解
4.1 整除约束
硬性要求:
为什么?因为注意力计算中,Q 需要按组与 K 做带 batch 维度的矩阵乘法:
def gqa_attention(Q, K, V, H_q, H_kv):
# Q shape: [batch, seq, H_q * d_h]
# K shape: [batch, seq, H_kv * d_h]
# 将 Q 重新排列成 [batch, seq, G, h_per_group, d_h]
# 其中 G = H_kv, h_per_group = H_q / H_kv
Q = Q.view(batch, seq, H_kv, H_q // H_kv, d_h)
K = K.view(batch, seq, H_kv, 1, d_h) # 广播
# batched matmul:
scores = torch.matmul(Q, K.transpose(-2, -1))
# 这里要求 H_q // H_kv 是整数
如果 不是整数,就会产生"不均匀分组"——有些组多一个 Q head,有些组少一个。这会破坏计算图的规则性,导致无法用统一的 batch matmul 实现。
4.2 Head 维度约束
每个 attention head 的维度:
常见取值:64、96、128。
对 GQA 来说,K/V head 的维度也是 ,只是数量减少到 :
- K 的总维度:
- V 的总维度:同上
4.3 KV Cache 大小的精确公式
def kv_cache_size(
num_layers: int,
batch_size: int,
seq_len: int,
d_model: int,
H_q: int,
H_kv: int,
precision_bytes: int = 2, # FP16
) -> int:
"""返回 KV Cache 总大小(字节)"""
d_h = d_model // H_q
kv_size_per_layer = 2 * H_kv * d_h * seq_len * precision_bytes # K + V
return num_layers * batch_size * kv_size_per_layer
一个具体例子:LLaMA 3 8B 推理,batch=4, seq=8192
d_model = 4096, H_q = 32, H_kv = 8
d_h = 4096 / 32 = 128
num_layers = 32
每层 KV Cache = 2 × 8 × 128 × 8192 × 2 = 33,554,432 字节 ≈ 32 MB
总 KV Cache = 32 × 4 × 32 MB ≈ 4 GB
如果换成 MHA(H_kv=32),总 KV Cache 将变成 16 GB——4 倍差距。
4.4 节省率公式
| 机制 | 节省倍数 | ||
|---|---|---|---|
| MHA | 32 | 32 | 1× (基准) |
| GQA (G=8) | 32 | 8 | 4× |
| GQA (G=4) | 32 | 4 | 8× |
| MQA | 32 | 1 | 32× |
🧪 第五章:多头潜在注意力 (MLA)
DeepSeek 的创新
MLA(Multi-head Latent Attention)是 DeepSeek-V2 提出的 KV Cache 压缩方案。它的核心思想是:对 KV 进行低秩压缩,只缓存压缩后的 latent 表示,在计算时再展开。
与传统 GQA 的关键区别
| 维度 | 传统 GQA | MLA |
|---|---|---|
| KV Cache 存储 | 完整的 | 压缩后 维 latent |
| 共享方式 | Q 头分组共享 K/V | 所有 Q 头共享 latent KV |
| 展开方式 | 无需展开 | 通过 up-projection 展开 |
| 额外操作 | 无 | Matrix Absorption |
数学原理
MLA 将 K 和 V 的投影矩阵分解为两个低秩矩阵:
其中:
- (down-projection,压缩)
- (up-projection,展开)
- (latent 维度)
推理时,KV Cache 仅存储 latent 表示:
C_{KV} = W_{DK} \cdot [h_1, h_2, ..., h_t] \quad \text{(shape: [seq_len, d_l])}
Matrix Absorption
这是 MLA 最关键也最精妙的优化。传统 naive 做法是先展开 latent KV 再算注意力:
K = C_KV @ W_UK # [seq, d_model] — 展开,浪费带宽
A = Q @ K^T # 注意力计算
Matrix Absorption 将 up-projection 吸收进 Q:
Q_absorbed = Q @ W_UK^T # 先吸收:Q [d_model] → Q [d_l]
A = Q_absorbed @ C_KV^T # 在 latent space 计算
这样无需显式展开 latent KV——计算的中间状态在更低的 维空间中就完成了。这需要自定义 CUDA kernel 来实现,但能带来显著的带宽节省。
KV Cache 对比(DeepSeek-V2 参数)
| 方法 | 每层 KV Cache 大小 | 节省 |
|---|---|---|
| GQA (H=128, G=16, d_h=128) | 2 × 16 × 128 = 4096 个元素 | 基准 |
| MLA (d_l=512) | 2 × 512 = 1024 个元素 | 4× |
| MHA (H=128) | 2 × 128 × 128 = 32768 个元素 | 8× (vs GQA) |
DeepSeek-V2 实测 MLA 可减少约 75% 的 KV Cache 占用,同时保持与全注意力相当的质量。
📊 第六章:总结与速查
核心公式速查
| 概念 | 公式 |
|---|---|
| Head 维度 | |
| GQA 分组数 | |
| 每组 Q 头数 | |
| 整除约束 | |
| KV Cache 节省倍率 | |
| MLA latent 维度 |
选择指南
| 场景 | 推荐方案 | 理由 |
|---|---|---|
| 小模型 (1B-7B) | MHA | KV Cache 压力小,质量优先 |
| 大模型 (7B-70B) 推理 | GQA (G=8) | 质量和速度的最佳平衡 |
| 超大模型 (70B+) 推理 | GQA (G=4~8) | 更大模型需更多 KV Cache 优化 |
| 极致推理吞吐 | MQA | 质量可牺牲时 |
| 长上下文 (128K+) | MLA 或 KV 量化 | KV Cache 成为主要瓶颈 |
| 显存受限部署 | GQA (G=4) + KV Cache 量化 | 双重压缩 |
推荐阅读
- 📖 Attention Is All You Need(Vaswani et al., 2017)— Transformer 原始论文,MHA 的起源
- 📖 Fast Transformer Decoding: One Write-Head is All You Need(Shazeer, 2019)— MQA 论文
- 📖 GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints(Ainslie et al., 2023)— GQA 论文
- 📖 DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model — MLA 论文
- 📖 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness — IO-aware 注意力计算的基础
🚀 下一步
理解头部关系后,你可以:
- 动手改 config:试试把一个 MHA 模型改为 GQA(通过平均池化合并 K/V head),感受质量损失和速度提升
- 算你的 KV Cache:用上面的公式算算自己的模型在目标序列长度和 batch size 下需要多大显存
- 用代码验证:在推理引擎(vLLM / llama.cpp)中比较 MHA/GQA/MQA 的实际吞吐差异
理解了"头"的关系,你就能在模型配置层面做出明智的推理优化决策——而不是只知道"跑起来"。
📊 本文配套图解
图1:MHA vs MQA vs GQA 对比图 — Q/K/V 头数和映射关系 在线查看:https://excalidraw.com/#json=0osGQ_FXYcqF7-T9nZMRR,Ys0w7R8wGVTctOOLQnh00A 文件位置:
images/mha-vs-mqa-vs-gqa.excalidraw
图2:注意力头数量约束图解 — 整除约束、GQA 连续谱、实战示例 在线查看:https://excalidraw.com/#json=sQYUP3nTQ51NTjdq70GV0,HvWn_12aT6iLzRzHVE_bsQ 文件位置:
images/head-count-constraints.excalidraw
图3:MLA — Multi-head Latent Attention 原理 在线查看:https://excalidraw.com/#json=91-QXaL8Ll0TjoKy6fuUy,PLJ4c6oIEyb7uNHEkH1wJw 文件位置:
images/mla-principle.excalidraw