位置编码是为Transformer提供序列顺序信息的关键技术,本文梳理其从正弦编码到旋转位置编码的演进。
import torch
import torch.nn as nn
import math
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding that is added to the input.

    A table of shape ``[1, max_seq_len, d_model]`` is precomputed once and
    stored in the buffer ``pe`` (a buffer, so it is not trained). Even
    feature indices hold ``sin(pos * w_i)`` and odd feature indices hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-i / d_model)`` for
    ``i = 0, 2, 4, ...``.

    Args:
        d_model: Feature dimension of the input. Odd values are supported;
            the last (even-indexed) feature then only has a sine component.
        max_seq_len: Number of positions to precompute.
    """

    def __init__(self, d_model=512, max_seq_len=5000):
        super().__init__()
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        # One angular frequency per (sin, cos) feature pair.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd index than even index, so
        # trim the trailing frequency to match the cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Buffer: follows .to()/.cuda() and state_dict, but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the positional encoding to ``x``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, d_model]``.

        Returns:
            ``x`` plus the encoding of its first ``seq_len`` positions, in
            the dtype of ``x``.
        """
        seq_len = x.size(1)
        # Cast so half/bfloat16 inputs are not silently promoted to float32.
        return x + self.pe[:, :seq_len, :].to(dtype=x.dtype)
class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) applied to query/key tensors.

    Cosine and sine tables for positions ``0 .. max_seq_len - 1`` are
    precomputed and stored in the buffers ``cos_cached`` / ``sin_cached``
    (each of shape ``[cached_len, dim]``). The tables are rebuilt only when
    a longer sequence than the current cache is seen.

    Args:
        dim: Per-head feature dimension; must be even because features are
            rotated in pairs ``(i, i + dim // 2)``.
        max_seq_len: Number of positions to precompute.
        base: Base of the geometric frequency progression.
        device: Device on which the frequency and cache buffers are created.
        scaling_factor: Positions are divided by this value before the
            angles are computed (position interpolation).

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(
        self,
        dim,
        max_seq_len=8192,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"dim must be even for rotary embeddings, got {dim}")
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        self.scaling_factor = scaling_factor
        # inv_freq[0] == 1.0 is the highest frequency; later entries are lower.
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device).float() / dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        self._set_cos_sin_cache(max_seq_len, device)

    def _set_cos_sin_cache(self, seq_len, device):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``.

        Args:
            seq_len: Number of positions to cache.
            device: Device for the tables; ``None`` means the device of
                ``inv_freq``.
        """
        target_device = self.inv_freq.device if device is None else device
        inv_freq = self.inv_freq.to(target_device)
        t = torch.arange(seq_len, device=target_device, dtype=inv_freq.dtype)
        # Position interpolation: squeeze positions into the trained range.
        t = t / self.scaling_factor
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Duplicate so column i and column i + dim // 2 share one angle,
        # matching the pairing used by rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)
        # Re-registering an existing buffer name replaces its tensor.
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def rotate_half(self, x):
        """Return ``cat(-x2, x1)`` where ``x1, x2`` are the two halves of the last dim."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q, k, seq_len=None):
        """Rotate ``q`` and ``k`` by their positions.

        Args:
            q: Tensor of shape ``[batch, heads, seq_len, head_dim]``.
            k: Tensor of shape ``[batch, heads, seq_len, head_dim]``.
            seq_len: Number of leading cached positions to use; defaults to
                ``q.shape[2]``.

        Returns:
            Tuple ``(q_embed, k_embed)`` with the same shapes as the inputs
            and the dtype of ``q``.
        """
        if seq_len is None:
            seq_len = q.shape[2]
        # Compare against the real cache length, not the configured
        # max_seq_len, so a grown cache is not rebuilt on every call.
        if seq_len > self.cos_cached.shape[0]:
            self._set_cos_sin_cache(seq_len, q.device)
        # Cast so half/bfloat16 inputs are not promoted to float32.
        cos = self.cos_cached[:seq_len].to(dtype=q.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=q.dtype)
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed
class ALiBiPositionalBias(nn.Module):
    """ALiBi (Attention with Linear Biases) additive attention bias.

    Each head gets a fixed slope; the bias added to the score of query
    position ``i`` and key position ``j`` is ``slope * (j - i)``. The
    slopes and a ``[n_heads, max_seq_len, max_seq_len]`` bias table are
    stored as the buffers ``slopes`` and ``bias``.

    Args:
        n_heads: Number of attention heads.
        max_seq_len: Side length of the precomputed bias table. Memory
            grows quadratically with this value.
    """

    def __init__(self, n_heads, max_seq_len=2048):
        super().__init__()
        self.n_heads = n_heads
        self.register_buffer('slopes', self._get_slopes(n_heads))
        # Must come after 'slopes' is registered: the builder reads it.
        self.register_buffer('bias', self._build_alibi_bias(max_seq_len))

    def _get_slopes(self, n_heads):
        """Return a 1-D tensor with one slope per attention head."""

        def get_slopes_power_of_2(n):
            # Geometric sequence whose first term equals its ratio.
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            return [start * start ** i for i in range(n)]

        if math.log2(n_heads).is_integer():
            return torch.tensor(get_slopes_power_of_2(n_heads))

        # Not a power of two: take the slopes of the largest power of two
        # below n_heads, then fill the remaining heads with every other
        # slope from the next power of two.
        closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
        base_slopes = get_slopes_power_of_2(closest_power_of_2)
        extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)
        remaining = n_heads - closest_power_of_2
        return torch.tensor(base_slopes + extra_slopes[0::2][:remaining])

    def _build_alibi_bias(self, seq_len):
        """Build the ``[n_heads, seq_len, seq_len]`` bias table."""
        # Create positions on the slopes' device so rebuilding after the
        # module has been moved does not mix devices.
        positions = torch.arange(seq_len, device=self.slopes.device)
        # relative_positions[i, j] = j - i
        relative_positions = positions.unsqueeze(0) - positions.unsqueeze(1)
        slopes = self.slopes.unsqueeze(1).unsqueeze(1)
        return slopes * relative_positions.unsqueeze(0)

    def forward(self, attention_scores, seq_len):
        """Add the ALiBi bias to attention scores.

        Args:
            attention_scores: Tensor of shape
                ``[batch, heads, seq_len, seq_len]``.
            seq_len: Side length of the bias slice to add.

        Returns:
            ``attention_scores`` plus the bias, in the dtype of
            ``attention_scores``.
        """
        if seq_len > self.bias.shape[-1]:
            # Grow the table on demand; assigning to a registered buffer
            # name replaces the buffer.
            self.bias = self._build_alibi_bias(seq_len).to(attention_scores.device)
        bias = self.bias[:, :seq_len, :seq_len]
        # Cast so half/bfloat16 scores are not promoted to float32.
        return attention_scores + bias.to(dtype=attention_scores.dtype)
class PositionInterpolation:
    """Helpers for extending a RoPE model's context window by position interpolation."""

    @staticmethod
    def interpolate_rope(
        model,
        original_max_len=4096,
        target_max_len=32768,
    ):
        """Rescale every RoPE layer in ``model`` for a longer context.

        Sets ``scaling_factor = target_max_len / original_max_len`` and
        ``max_seq_len = target_max_len`` on each
        ``RotaryPositionalEmbedding`` submodule, then rebuilds its cos/sin
        cache on the device of its ``inv_freq`` buffer.

        Args:
            model: Module whose submodules are scanned via ``model.modules()``.
            original_max_len: Context length the model was trained with.
            target_max_len: Desired context length.

        Returns:
            The same ``model`` object, modified in place.
        """
        scaling_factor = target_max_len / original_max_len
        for module in model.modules():
            if not isinstance(module, RotaryPositionalEmbedding):
                continue
            module.scaling_factor = scaling_factor
            module.max_seq_len = target_max_len
            # The cache depends on scaling_factor, so it must be rebuilt.
            module._set_cos_sin_cache(target_max_len, module.inv_freq.device)
        print(f"位置插值完成: {original_max_len} → {target_max_len}")
        print(f"缩放因子: {scaling_factor}")
        return model

    @staticmethod
    def fine_tune_for_long_context(
        model,
        train_dataloader,
        target_seq_len=32768,
        epochs=1000,
    ):
        """Fine-tune ``model`` while gradually increasing the sequence length.

        The sequence length starts at 4096 and grows linearly with the
        epoch index (one extra 4096 per 100 epochs), capped at
        ``target_seq_len``. Each batch's ``input_ids`` is truncated to that
        length; shorter batches are used as they are (no padding is done).

        Args:
            model: Callable as ``model(input_ids)`` returning an object
                with a ``loss`` attribute.
            train_dataloader: Iterable of mappings with an ``'input_ids'``
                entry indexable as ``[:, :length]``.
            target_seq_len: Upper bound for the sequence length.
            epochs: Number of passes over ``train_dataloader``.
        """
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
        for epoch in range(epochs):
            # Depends only on the epoch, so compute it once per epoch.
            current_seq_len = min(
                4096 * (1 + epoch / 100),
                target_seq_len,
            )
            truncate_to = int(current_seq_len)
            loss = None
            for batch in train_dataloader:
                input_ids = batch['input_ids'][:, :truncate_to]
                outputs = model(input_ids)
                loss = outputs.loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Skip logging when the dataloader yielded nothing this epoch;
            # there is no loss to report in that case.
            if epoch % 100 == 0 and loss is not None:
                print(f"Epoch {epoch}, Seq Len: {current_seq_len}, Loss: {loss.item():.4f}")
class xPos(nn.Module):
    """RoPE followed by a position-dependent scaling of queries and keys.

    Queries are multiplied by ``decay[pos] = 1 / (1 + pos / decay_base)``
    and keys are divided by it. The scaling is applied only when the
    sequence is longer than ``decay_base``.

    Args:
        dim: Per-head feature dimension passed to the rotary embedding.
        max_seq_len: Number of positions the rotary embedding precomputes.
        base: Frequency base passed to the rotary embedding.
        decay_base: Controls how fast the decay falls off with position and
            the sequence length above which it is applied.
    """

    def __init__(
        self,
        dim,
        max_seq_len=8192,
        base=10000,
        decay_base=512,
    ):
        super().__init__()
        self.rope = RotaryPositionalEmbedding(dim, max_seq_len, base)
        self.decay_base = decay_base

    def apply_decay(self, q, k, seq_len):
        """Scale ``q`` by the decay and ``k`` by its reciprocal.

        With this pairing the score between query position ``m`` and key
        position ``n`` is multiplied by ``decay[m] / decay[n]``, so it is
        unchanged when ``m == n``.

        Args:
            q: Tensor of shape ``[batch, heads, seq_len, head_dim]``.
            k: Tensor of shape ``[batch, heads, seq_len, head_dim]``.
            seq_len: Number of positions; must match dim 2 of ``q``/``k``.

        Returns:
            Tuple ``(q, k)`` of scaled tensors in the dtype of ``q``.
        """
        positions = torch.arange(seq_len, device=q.device)
        decay = (1 + positions / self.decay_base) ** -1
        # Shape [1, 1, seq_len, 1] to broadcast over batch, heads and
        # features; cast so half/bfloat16 inputs are not promoted.
        decay = decay.to(dtype=q.dtype).unsqueeze(0).unsqueeze(0).unsqueeze(-1)
        return q * decay, k / decay

    def forward(self, q, k):
        """Apply RoPE, then the decay for sequences longer than ``decay_base``.

        Args:
            q: Tensor of shape ``[batch, heads, seq_len, head_dim]``.
            k: Tensor of shape ``[batch, heads, seq_len, head_dim]``.

        Returns:
            Tuple ``(q, k)`` of transformed tensors.
        """
        q, k = self.rope(q, k)
        seq_len = q.shape[2]
        if seq_len > self.decay_base:
            q, k = self.apply_decay(q, k, seq_len)
        return q, k
class RoPEFrequencyAnalysis:
    """Utilities for inspecting and editing the frequency components of RoPE."""

    @staticmethod
    def analyze_frequency_usage(model, input_ids):
        """Summarise the 2-D FFT spectrum of each head's attention map.

        Only the first batch element is analysed. For every head the FFT
        magnitude is averaged over the ``[:10, :10]`` corner ("low") and
        over the ``[10:, 10:]`` remainder ("high"). When the attention map
        has 10 or fewer positions the "high" region is empty and its mean
        (and the ratio) is NaN.

        Args:
            model: Callable as ``model(input_ids, output_attentions=True)``
                returning an object whose ``attentions`` attribute iterates
                over per-layer tensors indexed ``[batch, head, seq, seq]``.
            input_ids: Input passed to ``model`` unchanged.

        Returns:
            A list with one entry per layer; each entry is a list with one
            dict per head holding ``'head'``, ``'low_freq_importance'``,
            ``'high_freq_importance'`` and ``'ratio'``.
        """
        cutoff = 10  # boundary index between the "low" and "high" regions
        with torch.no_grad():
            outputs = model(input_ids, output_attentions=True)
        attentions = outputs.attentions

        frequency_importance = []
        for layer_attn in attentions:
            layer_freq = []
            for head_idx in range(layer_attn.size(1)):
                head_attn = layer_attn[0, head_idx]
                freq_magnitude = torch.abs(torch.fft.fft2(head_attn))
                low_freq = freq_magnitude[:cutoff, :cutoff].mean()
                high_freq = freq_magnitude[cutoff:, cutoff:].mean()
                layer_freq.append({
                    'head': head_idx,
                    'low_freq_importance': low_freq.item(),
                    'high_freq_importance': high_freq.item(),
                    # Epsilon avoids division by zero for a flat spectrum.
                    'ratio': (low_freq / (high_freq + 1e-8)).item(),
                })
            frequency_importance.append(layer_freq)
        return frequency_importance

    @staticmethod
    def remove_lowest_frequencies(rope_module, num_freqs_to_remove=2):
        """Zero the lowest-frequency RoPE components and rebuild the cache.

        ``inv_freq`` is ordered from the highest frequency (index 0) to the
        lowest (last index), so the lowest frequencies are the trailing
        entries.

        Args:
            rope_module: Object exposing ``inv_freq``, ``max_seq_len`` and
                ``_set_cos_sin_cache(seq_len, device)``.
            num_freqs_to_remove: How many of the lowest frequencies to set
                to zero; ``0`` leaves ``inv_freq`` untouched.

        Raises:
            ValueError: If ``num_freqs_to_remove`` is negative.
        """
        if num_freqs_to_remove < 0:
            raise ValueError(
                f"num_freqs_to_remove must be non-negative, got {num_freqs_to_remove}"
            )
        with torch.no_grad():
            # Guard the zero case: a slice of [-0:] would select everything.
            if num_freqs_to_remove > 0:
                rope_module.inv_freq[-num_freqs_to_remove:] = 0
            # The cos/sin tables are derived from inv_freq, so refresh them.
            rope_module._set_cos_sin_cache(
                rope_module.max_seq_len,
                rope_module.inv_freq.device,
            )
        print(f"已移除{num_freqs_to_remove}个最低频率成分")
方法 | 参数量 | 长度外推 | 计算成本 | 2024年采用 |
---|---|---|---|---|
正弦编码 | 0 | 一般 | 最低 | 基础模型 |
学习编码 | O(L×d) | 差 | 低 | 较少使用 |
RoPE | 0 | 良好 | 低 | 主流(80%+) |
ALiBi | 0 | 优秀 | 最低 | BLOOM系列 |
xPos | 0 | 优秀 | 低 | 实验阶段 |
RoPE+PI | 0 | 极好 | 低 | Llama 3 |