概念定义
多头注意力(Multi-Head Attention)通过并行运行多个独立的注意力头,让模型能够同时从不同的表示子空间学习信息,极大增强了模型捕获复杂模式的能力。详细解释
多头注意力是Transformer成功的关键因素之一。与单一注意力机制不同,多头注意力将模型的表示空间分成多个子空间,每个”头”独立学习不同类型的依赖关系。比如在处理自然语言时,一个头可能关注语法关系,另一个头关注语义相似性,还有的头可能专门捕获长距离依赖。 2024年,多头注意力已经演化出多个重要变体。分组查询注意力(GQA)成为主流选择,被Llama 3、Mistral、Granite 3.0等模型采用。GQA通过在多个查询头之间共享键值对,在保持模型质量的同时显著降低了内存消耗。而多查询注意力(MQA)则走向极端,所有查询头共享单一的键值对,虽然速度更快但可能损失精度。 从GPT-3的96个注意力头到Llama 3的分组设计(8组共享KV),现代模型在头数量和组织方式上进行了精心优化。研究表明,并非头越多越好——关键在于找到计算效率和表达能力的平衡点。通过Flash Attention等优化技术,即使是上百个注意力头也能高效运行。工作原理
🧠 多头注意力机制架构
多头注意力通过h个并行的注意力头,从不同角度理解输入序列:📊 输入嵌入分割为多个头
输入嵌入 (d_model = 512) 被分割为 8 个头:每个头的维度: dk = d_model / h = 512 / 8 = 64
头1 (64)
头2 (64)
头3 (64)
…
🔄 各头独立进行注意力计算
🔗 头1:语法关系
线性投影: Q₁, K₁, V₁ (每个64×64维)注意力计算:专注模式:
- 主谓关系
- 修饰结构
- 语法依存
🎯 头2:语义相似
线性投影: Q₂, K₂, V₂ (每个64×64维)注意力计算:专注模式:
- 词义相似
- 同义词组
- 概念关联
其他头 (3-8): 分别专注于长距离依赖、位置关系、共指消解、语义角色等不同语言模式
🎯 输出拼接与最终投影
拼接所有头的输出:最终投影:其中 W_O 是 (h×dk) × d_model 的权重矩阵
⚡ 2024年现代变体对比
🧠 多头注意力 (MHA)
特点:典型应用: GPT-3, BERT
- 每个头独立Q、K、V
- 最高精度表现
- 内存消耗最大
🎯 分组查询注意力 (GQA)
特点:典型应用: Llama 3, Mistral
- Q头分组共享K、V
- 平衡精度和效率
- 内存节省50-75%
⚡ 多查询注意力 (MQA)
特点:典型应用: PaLM, Falcon
- 所有Q头共享单一K、V
- 最快推理速度
- 内存占用最少
📊 性能对比表格
变体 | 精度 | 推理速度 | 内存消耗 | KV Cache | 典型应用 |
---|---|---|---|---|---|
MHA | 最高 | 基准 | 最大 | h组KV | GPT-3/4, BERT |
GQA | 高 | 1.5-2x | 0.25-0.5x | g组KV | Llama 3, Mistral |
MQA | 中等 | 2-3x | 最少 | 1组KV | PaLM, Falcon |
2024年趋势: GQA成为主流选择,在Llama 3中采用8组共享设计,既保持了模型质量又显著降低了推理成本
🎯 多头注意力的工作流程
- 输入分割:将d_model维度分成h个头,每个头处理d_k=d_model/h维
- 并行注意力:各头独立计算自注意力,专注不同语言模式
- 拼接输出:将所有头的输出拼接
- 线性投影:通过输出矩阵W_O映射回d_model维
实际应用
标准多头注意力实现
分组查询注意力(GQA)实现
多查询注意力(MQA)实现
高效注意力变体比较
特殊功能头分析
性能对比
特性 | MHA | MQA | GQA-4 | GQA-8 |
---|---|---|---|---|
查询头数 | 32 | 32 | 32 | 32 |
KV头数 | 32 | 1 | 8 | 4 |
KV缓存大小 | 100% | 3.1% | 25% | 12.5% |
推理速度 | 1x | 8x | 3x | 5x |
模型质量 | 最佳 | 一般 | 优秀 | 很好 |
代表模型 | BERT | PaLM | Llama 3 | Mixtral |
相关概念
- 自注意力机制 - 基础机制
- Transformer架构 - 整体架构
- Flash Attention - 高效实现
- 分组查询注意力 - GQA详解
- KV缓存 - 推理优化