Running several attention heads in parallel, each capturing information from a different representation subspace, is one of the Transformer's key innovations.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
    """
    Standard multi-head attention (MHA).
    """
    def __init__(
        self,
        d_model=512,
        n_heads=8,
        dropout=0.1,
        bias=True
    ):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension of each head
        # Fused QKV projection (a single matmul is more efficient than three)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=bias)
        # Output projection
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.scale = 1.0 / math.sqrt(self.d_k)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        # 1. Project to QKV and reshape into heads
        qkv = self.qkv_proj(x)  # [batch, seq, 3*d_model]
        qkv = qkv.reshape(batch_size, seq_len, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch, heads, seq, d_k]
        q, k, v = qkv[0], qkv[1], qkv[2]
        # 2. Scaled dot-product attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # attn_scores: [batch, heads, seq, seq]
        # 3. Apply the padding mask of shape [batch, seq], if given
        if mask is not None:
            # Broadcast the mask over the head and query dimensions
            mask = mask.unsqueeze(1).unsqueeze(1)  # [batch, 1, 1, seq]
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        # 4. Softmax normalization over the key dimension
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # 5. Weight the values with the attention distribution
        context = torch.matmul(attn_weights, v)
        # context: [batch, heads, seq, d_k]
        # 6. Concatenate the head outputs
        context = context.transpose(1, 2).contiguous()
        context = context.reshape(batch_size, seq_len, d_model)
        # 7. Output projection
        output = self.out_proj(context)
        return output, attn_weights
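A minimal smoke test for the class above; the `[batch, seq]` padding-mask shape is an assumption of this sketch, chosen to match the broadcasting in `forward`:

# Hypothetical usage: batch of 2 sequences, the second one padded.
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)
mask = torch.ones(2, 10, dtype=torch.long)
mask[1, 7:] = 0  # the last three positions of sequence 1 are padding
out, attn = mha(x, mask=mask)
print(out.shape)   # torch.Size([2, 10, 512])
print(attn.shape)  # torch.Size([2, 8, 10, 10])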
class GroupedQueryAttention(nn.Module):
    """
    Grouped-query attention (GQA) -- the Llama 3 / Mistral style.
    """
    def __init__(
        self,
        d_model=4096,
        n_heads=32,
        n_kv_heads=8,  # number of KV heads (groups)
        dropout=0.1
    ):
        super().__init__()
        assert n_heads % n_kv_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_groups = n_heads // n_kv_heads  # query heads per KV head
        self.d_k = d_model // n_heads
        # Separate projections for Q and KV (KV uses fewer heads)
        self.q_proj = nn.Linear(d_model, n_heads * self.d_k, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.scale = 1.0 / math.sqrt(self.d_k)

    def forward(self, x, use_cache=False, past_kv=None):
        batch_size, seq_len, _ = x.shape
        # Project queries
        q = self.q_proj(x)
        q = q.view(batch_size, seq_len, self.n_heads, self.d_k)
        q = q.transpose(1, 2)  # [batch, n_heads, seq, d_k]
        # Project keys and values (fewer heads)
        k = self.k_proj(x)
        v = self.v_proj(x)
        k = k.view(batch_size, seq_len, self.n_kv_heads, self.d_k)
        v = v.view(batch_size, seq_len, self.n_kv_heads, self.d_k)
        k = k.transpose(1, 2)  # [batch, n_kv_heads, seq, d_k]
        v = v.transpose(1, 2)
        # Append to the KV cache (inference-time optimization)
        if use_cache and past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        # Keep the un-repeated KV tensors for caching
        present_kv = (k, v) if use_cache else None
        # Repeat KV heads to match the number of query heads
        if self.n_groups > 1:
            k = k.repeat_interleave(self.n_groups, dim=1)
            v = v.repeat_interleave(self.n_groups, dim=1)
        # Attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # Apply attention to the values
        output = torch.matmul(attn_weights, v)
        # Reshape back to [batch, seq, d_model]
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)
        if use_cache:
            # Return only the un-repeated KV heads for the cache
            return output, present_kv
        return output
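A rough incremental-decoding sketch of the cache interface above (hypothetical sizes, no causal mask, a single decode step):

# Hypothetical prefill-then-decode usage of the GQA cache.
gqa = GroupedQueryAttention(d_model=512, n_heads=8, n_kv_heads=2)
prompt = torch.randn(1, 16, 512)
out, past_kv = gqa(prompt, use_cache=True)                      # prefill: cache 16 positions
next_tok = torch.randn(1, 1, 512)
out, past_kv = gqa(next_tok, use_cache=True, past_kv=past_kv)   # decode one more token
print(past_kv[0].shape)  # torch.Size([1, 2, 17, 64]) -- n_kv_heads=2, 17 cached positions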
class MultiQueryAttention(nn.Module):
    """
    Multi-query attention (MQA) -- the PaLM / StarCoder style.
    All query heads share a single K/V pair.
    """
    def __init__(
        self,
        d_model=2048,
        n_heads=16,
        dropout=0.1
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # Many query heads, a single KV head
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, self.d_k, bias=False)  # single K head
        self.v_proj = nn.Linear(d_model, self.d_k, bias=False)  # single V head
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.scale = 1.0 / math.sqrt(self.d_k)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        # Multi-head queries
        q = self.q_proj(x)
        q = q.view(batch_size, seq_len, self.n_heads, self.d_k)
        q = q.transpose(1, 2)
        # Single-head keys and values
        k = self.k_proj(x)  # [batch, seq, d_k]
        v = self.v_proj(x)  # [batch, seq, d_k]
        # Broadcast K and V across the query heads
        k = k.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        v = v.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        # Standard attention computation
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, seq_len, self.d_model)
        return self.out_proj(output)
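A quick shape check for the MQA module; the values are illustrative only, but the single shared K head is visible directly in the projection's weight shape:

# Hypothetical shape check.
mqa = MultiQueryAttention(d_model=2048, n_heads=16)
x = torch.randn(2, 32, 2048)
print(mqa(x).shape)             # torch.Size([2, 32, 2048])
print(mqa.k_proj.weight.shape)  # torch.Size([128, 2048]) -- one 128-dim K head for 16 query heads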
def compare_attention_variants():
    """
    Compare the parameter counts and KV-cache footprints of the attention variants.
    """
    batch_size = 1
    seq_len = 2048
    d_model = 4096
    n_heads = 32
    # Standard MHA
    mha = MultiHeadAttention(d_model, n_heads)
    mha_params = sum(p.numel() for p in mha.parameters())
    # GQA (8 KV heads)
    gqa = GroupedQueryAttention(d_model, n_heads, n_kv_heads=8)
    gqa_params = sum(p.numel() for p in gqa.parameters())
    # MQA
    mqa = MultiQueryAttention(d_model, n_heads)
    mqa_params = sum(p.numel() for p in mqa.parameters())
    print("Parameter counts:")
    print(f"MHA: {mha_params:,} parameters")
    print(f"GQA: {gqa_params:,} parameters ({(1-gqa_params/mha_params)*100:.1f}% fewer)")
    print(f"MQA: {mqa_params:,} parameters ({(1-mqa_params/mha_params)*100:.1f}% fewer)")
    # KV-cache size at inference time (fp32, 4 bytes per element), in MB
    kv_cache_mha = 2 * batch_size * n_heads * seq_len * (d_model // n_heads) * 4 / (1024**2)
    kv_cache_gqa = 2 * batch_size * 8 * seq_len * (d_model // n_heads) * 4 / (1024**2)
    kv_cache_mqa = 2 * batch_size * 1 * seq_len * (d_model // n_heads) * 4 / (1024**2)
    print(f"\nKV-cache size (seq_len={seq_len}):")
    print(f"MHA: {kv_cache_mha:.1f} MB")
    print(f"GQA: {kv_cache_gqa:.1f} MB ({(1-kv_cache_gqa/kv_cache_mha)*100:.1f}% smaller)")
    print(f"MQA: {kv_cache_mqa:.1f} MB ({(1-kv_cache_mqa/kv_cache_mha)*100:.1f}% smaller)")
class AttentionHeadAnalyzer:
    """
    Analyze the patterns learned by individual attention heads.
    """
    def analyze_attention_patterns(self, model, input_ids, layer_idx=0):
        """
        Summarize the attention patterns of a given layer.
        """
        with torch.no_grad():
            # Get the attention weights from the model
            outputs = model(input_ids, output_attentions=True)
            attention_weights = outputs.attentions[layer_idx]
            # attention_weights: [batch, heads, seq, seq]
            # Characterize each head
            patterns = []
            for head_idx in range(attention_weights.size(1)):
                head_attn = attention_weights[0, head_idx]
                # Attention entropy (how concentrated the distribution is)
                entropy = -(head_attn * torch.log(head_attn + 1e-9)).sum(dim=-1).mean()
                # Mean distance between query and key positions, weighted by attention
                positions = torch.arange(head_attn.size(0))
                distances = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
                avg_distance = (head_attn * distances).sum() / head_attn.sum()
                patterns.append({
                    'head': head_idx,
                    'entropy': entropy.item(),
                    'avg_distance': avg_distance.item(),
                    'pattern_type': self.classify_pattern(head_attn)
                })
        return patterns

    def classify_pattern(self, attn_matrix):
        """
        Classify the type of attention pattern.
        """
        # Positional attention: tokens mostly attend to themselves (strong diagonal)
        diagonal_weight = torch.diagonal(attn_matrix).mean()
        if diagonal_weight > 0.5:
            return "positional"
        # Global attention: most tokens attend to the first position ([CLS]/[SEP]-style)
        first_col_weight = attn_matrix[:, 0].mean()
        if first_col_weight > 0.3:
            return "global"
        # Backward- vs forward-looking attention
        lower_tri = torch.tril(attn_matrix, diagonal=-1).sum()
        upper_tri = torch.triu(attn_matrix, diagonal=1).sum()
        if lower_tri > upper_tri * 2:
            return "backward"
        elif upper_tri > lower_tri * 2:
            return "forward"
        return "mixed"
For a model with 32 query heads, the trade-offs compare roughly as follows:

| Property | MHA | MQA | GQA-8 | GQA-4 |
|---|---|---|---|---|
| Query heads | 32 | 32 | 32 | 32 |
| KV heads | 32 | 1 | 8 | 4 |
| KV-cache size | 100% | 3.1% | 25% | 12.5% |
| Inference speedup (approx.) | 1x | 8x | 3x | 5x |
| Model quality | Best | Fair | Excellent | Very good |
| Representative models | BERT | PaLM | Llama 3, Mixtral | — |