Concept Definition

Multi-head attention (MHA) runs several independent attention heads in parallel, allowing the model to learn from different representation subspaces at the same time and greatly strengthening its ability to capture complex patterns.

Detailed Explanation

Multi-head attention is one of the key ingredients behind the Transformer's success. Unlike a single attention mechanism, it splits the model's representation space into multiple subspaces, and each "head" independently learns a different type of dependency. When processing natural language, for example, one head may focus on syntactic relations, another on semantic similarity, and yet another may specialize in capturing long-range dependencies.

By 2024, multi-head attention had evolved into several important variants. Grouped-query attention (GQA) has become the mainstream choice, adopted by models such as Llama 3, Mistral, and Granite 3.0. By sharing key-value pairs across several query heads, GQA significantly reduces memory consumption while preserving model quality. Multi-query attention (MQA) takes this to the extreme: all query heads share a single key-value pair, which is even faster but may cost some accuracy.

From GPT-3's 96 attention heads to Llama 3's grouped design (8 groups of query heads sharing KV), modern models carefully tune both the number of heads and how they are organized. Research suggests that more heads are not always better; the key is striking a balance between computational efficiency and expressive power. With optimizations such as Flash Attention, even hundreds of attention heads can run efficiently.

How It Works

The multi-head attention workflow (a minimal shape trace follows the list):
  1. Split the input: divide the d_model dimensions into h heads, each handling d_k = d_model/h dimensions
  2. Compute in parallel: each head runs its attention computation independently
  3. Concatenate: join the outputs of all heads back together
  4. Project: map the result back to d_model dimensions through the output matrix W_O
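
As a quick illustration of the steps above, here is a minimal shape trace with assumed toy sizes (d_model=512, h=8, so d_k=64); the tensor names and sizes are purely illustrative.

import torch

batch_size, seq_len, d_model, n_heads = 2, 10, 512, 8
d_k = d_model // n_heads  # 64 dimensions per head

x = torch.randn(batch_size, seq_len, d_model)

# 1. Split the feature dimension into (n_heads, d_k) and move heads before positions
heads = x.view(batch_size, seq_len, n_heads, d_k).transpose(1, 2)
print(heads.shape)   # torch.Size([2, 8, 10, 64]) -- each head sees a 64-dim subspace

# 2./3. Per-head attention keeps this shape; afterwards the heads are concatenated back
merged = heads.transpose(1, 2).reshape(batch_size, seq_len, d_model)
print(merged.shape)  # torch.Size([2, 10, 512])

# 4. A final d_model x d_model linear layer (W_O) projects back to the model dimension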

Practical Applications

Standard Multi-Head Attention Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """
    标准多头注意力实现
    """
    def __init__(
        self,
        d_model=512,
        n_heads=8,
        dropout=0.1,
        bias=True
    ):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension per head
        
        # Fused QKV projection (a single matmul is more efficient)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=bias)
        
        # Output projection
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = 1.0 / math.sqrt(self.d_k)
    
    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        
        # 1. Project QKV and reshape into multiple heads
        qkv = self.qkv_proj(x)  # [batch, seq, 3*d_model]
        qkv = qkv.reshape(batch_size, seq_len, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch, heads, seq, d_k]
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # 2. Compute scaled attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # attn_scores: [batch, heads, seq, seq]
        
        # 3. Apply the mask, if provided (assumes a [batch, seq] padding mask)
        if mask is not None:
            # Broadcast the mask over the head and query dimensions: [batch, 1, 1, seq]
            mask = mask.unsqueeze(1).unsqueeze(1)
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        # 4. Softmax normalization
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # 5. Apply the attention weights to the values
        context = torch.matmul(attn_weights, v)
        # context: [batch, heads, seq, d_k]
        
        # 6. Concatenate the head outputs
        context = context.transpose(1, 2).contiguous()
        context = context.reshape(batch_size, seq_len, d_model)
        
        # 7. Output projection
        output = self.out_proj(context)
        
        return output, attn_weights
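
A brief usage sketch of the class above; the batch size, sequence length, and the [batch, seq] padding mask are illustrative assumptions:

# Usage sketch: batch of 2 sequences, the second one padded after position 12
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 16, 512)

mask = torch.ones(2, 16)   # 1 = real token, 0 = padding
mask[1, 12:] = 0

out, weights = mha(x, mask=mask)
print(out.shape)      # torch.Size([2, 16, 512])
print(weights.shape)  # torch.Size([2, 8, 16, 16])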

Grouped-Query Attention (GQA) Implementation

class GroupedQueryAttention(nn.Module):
    """
    分组查询注意力 - Llama 3/Mistral风格
    """
    def __init__(
        self,
        d_model=4096,
        n_heads=32,
        n_kv_heads=8,  # number of KV heads (i.e. number of groups)
        dropout=0.1
    ):
        super().__init__()
        assert n_heads % n_kv_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_groups = n_heads // n_kv_heads  # query heads per group
        self.d_k = d_model // n_heads
        
        # Separate projections for Q and KV (KV uses fewer heads)
        self.q_proj = nn.Linear(d_model, n_heads * self.d_k, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = 1.0 / math.sqrt(self.d_k)
    
    def forward(self, x, use_cache=False, past_kv=None):
        batch_size, seq_len, _ = x.shape
        
        # Project the queries
        q = self.q_proj(x)
        q = q.view(batch_size, seq_len, self.n_heads, self.d_k)
        q = q.transpose(1, 2)  # [batch, n_heads, seq, d_k]
        
        # Project the keys and values (fewer heads)
        k = self.k_proj(x)
        v = self.v_proj(x)
        k = k.view(batch_size, seq_len, self.n_kv_heads, self.d_k)
        v = v.view(batch_size, seq_len, self.n_kv_heads, self.d_k)
        k = k.transpose(1, 2)  # [batch, n_kv_heads, seq, d_k]
        v = v.transpose(1, 2)
        
        # Handle the KV cache (inference optimization)
        if use_cache and past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        
        # Keep the un-repeated KV tensors for the cache before expanding them
        kv_for_cache = (k, v)

        # Repeat the KV heads so that each query head has a matching KV head
        if self.n_groups > 1:
            k = k.repeat_interleave(self.n_groups, dim=1)
            v = v.repeat_interleave(self.n_groups, dim=1)
        
        # Compute attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply the attention weights to the values
        output = torch.matmul(attn_weights, v)
        
        # Reshape and project the output
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)
        
        if use_cache:
            # Return the un-repeated KV tensors for the cache; slicing the
            # repeated tensors here would return duplicated heads instead
            return output, kv_for_cache
        
        return output
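
A minimal sketch of incremental decoding with the KV cache of the class above; the prompt length and single decode step are illustrative assumptions:

# Usage sketch: prefill a 16-token prompt, then decode one more token with the cache
gqa = GroupedQueryAttention(d_model=4096, n_heads=32, n_kv_heads=8)

prompt = torch.randn(1, 16, 4096)
out, past_kv = gqa(prompt, use_cache=True)               # cache covers 16 positions

next_token = torch.randn(1, 1, 4096)
out, past_kv = gqa(next_token, use_cache=True, past_kv=past_kv)
print(past_kv[0].shape)  # torch.Size([1, 8, 17, 128]) -- 8 KV heads, 17 cached positions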

Multi-Query Attention (MQA) Implementation

class MultiQueryAttention(nn.Module):
    """
    多查询注意力 - PaLM/StarCoder风格
    所有查询头共享单一KV对
    """
    def __init__(
        self,
        d_model=2048,
        n_heads=16,
        dropout=0.1
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Multiple query heads, a single KV head
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, self.d_k, bias=False)  # single K head
        self.v_proj = nn.Linear(d_model, self.d_k, bias=False)  # single V head
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = 1.0 / math.sqrt(self.d_k)
    
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        # Multi-head queries
        q = self.q_proj(x)
        q = q.view(batch_size, seq_len, self.n_heads, self.d_k)
        q = q.transpose(1, 2)
        
        # Single-head keys and values
        k = self.k_proj(x)  # [batch, seq, d_k]
        v = self.v_proj(x)  # [batch, seq, d_k]
        
        # Broadcast k and v across the query heads
        k = k.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        v = v.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        
        # Standard attention computation
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, seq_len, self.d_model)
        
        return self.out_proj(output)
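
A short usage sketch with assumed toy sizes, mainly to confirm the output shape:

# Usage sketch: MQA forward pass
mqa = MultiQueryAttention(d_model=2048, n_heads=16)
x = torch.randn(2, 64, 2048)
print(mqa(x).shape)  # torch.Size([2, 64, 2048])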

Comparing Efficient Attention Variants

def compare_attention_variants():
    """
    比较不同注意力变体的内存和计算效率
    """
    batch_size = 1
    seq_len = 2048
    d_model = 4096
    n_heads = 32
    
    # Standard MHA
    mha = MultiHeadAttention(d_model, n_heads)
    mha_params = sum(p.numel() for p in mha.parameters())
    
    # GQA (8 KV heads)
    gqa = GroupedQueryAttention(d_model, n_heads, n_kv_heads=8)
    gqa_params = sum(p.numel() for p in gqa.parameters())
    
    # MQA
    mqa = MultiQueryAttention(d_model, n_heads)
    mqa_params = sum(p.numel() for p in mqa.parameters())
    
    print(f"参数量对比:")
    print(f"MHA: {mha_params:,} 参数")
    print(f"GQA: {gqa_params:,} 参数 (减少 {(1-gqa_params/mha_params)*100:.1f}%)")
    print(f"MQA: {mqa_params:,} 参数 (减少 {(1-mqa_params/mha_params)*100:.1f}%)")
    
    # KV-cache size per layer at inference time (fp32, 4 bytes per element)
    kv_cache_mha = 2 * batch_size * n_heads * seq_len * (d_model // n_heads) * 4 / (1024**2)  # MB
    kv_cache_gqa = 2 * batch_size * 8 * seq_len * (d_model // n_heads) * 4 / (1024**2)
    kv_cache_mqa = 2 * batch_size * 1 * seq_len * (d_model // n_heads) * 4 / (1024**2)
    
    print(f"\nKV缓存大小 (seq_len={seq_len}):")
    print(f"MHA: {kv_cache_mha:.1f} MB")
    print(f"GQA: {kv_cache_gqa:.1f} MB (减少 {(1-kv_cache_gqa/kv_cache_mha)*100:.1f}%)")
    print(f"MQA: {kv_cache_mqa:.1f} MB (减少 {(1-kv_cache_mqa/kv_cache_mha)*100:.1f}%)")

Analyzing Specialized Attention Heads

class AttentionHeadAnalyzer:
    """
    分析不同注意力头学到的模式
    """
    def analyze_attention_patterns(self, model, input_ids, layer_idx=0):
        """
        可视化特定层的注意力模式
        """
        with torch.no_grad():
            # Get the attention weights for the requested layer
            outputs = model(input_ids, output_attentions=True)
            attention_weights = outputs.attentions[layer_idx]
            # attention_weights: [batch, heads, seq, seq]
            
            # Summarize the pattern of each head
            patterns = []
            for head_idx in range(attention_weights.size(1)):
                head_attn = attention_weights[0, head_idx]
                
                # Attention entropy (how concentrated each row's distribution is)
                entropy = -(head_attn * torch.log(head_attn + 1e-9)).sum(dim=-1).mean()
                
                # Average attention distance between query and key positions
                positions = torch.arange(head_attn.size(0))
                distances = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
                avg_distance = (head_attn * distances).sum() / head_attn.sum()
                
                patterns.append({
                    'head': head_idx,
                    'entropy': entropy.item(),
                    'avg_distance': avg_distance.item(),
                    'pattern_type': self.classify_pattern(head_attn)
                })
            
            return patterns
    
    def classify_pattern(self, attn_matrix):
        """
        分类注意力模式类型
        """
        seq_len = attn_matrix.size(0)
        
        # Check for positional attention (mass on the diagonal)
        diagonal_weight = torch.diagonal(attn_matrix).mean()
        if diagonal_weight > 0.5:
            return "positional"
        
        # Check for global attention (most positions attend to the first token, e.g. [CLS])
        first_token_weight = attn_matrix[:, 0].mean()
        if first_token_weight > 0.3:
            return "global"
        
        # Check for predominantly forward or backward attention
        lower_tri = torch.tril(attn_matrix, diagonal=-1).sum()
        upper_tri = torch.triu(attn_matrix, diagonal=1).sum()
        
        if lower_tri > upper_tri * 2:
            return "backward"
        elif upper_tri > lower_tri * 2:
            return "forward"
        
        return "mixed"

Performance Comparison

Property         MHA    MQA    GQA-4      GQA-8
Query heads      32     32     32         32
KV heads         32     1      8          4
KV cache size    100%   3.1%   25%        12.5%
Inference speed  1x     8x     3x         5x
Model quality    best   fair   excellent  very good
Example models   BERT   PaLM   Llama 3    Mixtral

Related Concepts

Further Reading