Self-attention is the Transformer's core innovation: it lets the model directly model the relationship between any two positions within a sequence.
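For reference alongside the code below, scaled dot-product attention (as introduced in "Attention Is All You Need") is

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

where $d_k$ is the key dimension; dividing by $\sqrt{d_k}$ keeps the dot products from growing with dimension and saturating the softmax.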
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SelfAttention(nn.Module):
    """
    Standard self-attention implementation
    """
    def __init__(self, d_model=512, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        # Linear projections for Q, K, V
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        # Output projection
        self.out_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        # 1. Linear projections to obtain Q, K, V
        Q = self.q_linear(x)  # [batch, seq_len, d_model]
        K = self.k_linear(x)
        V = self.v_linear(x)
        # 2. Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # scores: [batch, seq_len, seq_len]
        # 3. Apply the mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # 4. Softmax normalization
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # 5. Weighted aggregation of the value vectors
        output = torch.matmul(attn_weights, V)
        # 6. Output projection
        output = self.out_linear(output)
        return output, attn_weights
```
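A minimal usage sketch for the SelfAttention module above, just to make the tensor shapes concrete (the batch size and sequence length here are arbitrary):

```python
# Smoke test for the SelfAttention module defined above
attn = SelfAttention(d_model=512, dropout=0.1)
x = torch.randn(2, 16, 512)                      # [batch, seq_len, d_model]
out, weights = attn(x)
print(out.shape)      # torch.Size([2, 16, 512])
print(weights.shape)  # torch.Size([2, 16, 16]), one weight per query/key pair
```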
```python
class FlashAttention3(nn.Module):
    """
    Flash Attention 3 with FP8 and asynchronous computation
    Note: this is a conceptual example; a production version needs a CUDA kernel
    """
    def __init__(self, d_model=512, block_size=64, use_fp8=True):
        super().__init__()
        self.d_model = d_model
        self.block_size = block_size
        self.use_fp8 = use_fp8
        # QKV projection (FP8-friendly, no bias)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        if use_fp8:
            # Hadamard matrix (used to spread out outliers before quantization)
            self.register_buffer(
                'hadamard_matrix',
                self.generate_hadamard_matrix(d_model)
            )

    def generate_hadamard_matrix(self, dim):
        """Generate a normalized Hadamard matrix for incoherent processing"""
        def hadamard(n):
            if n == 1:
                return torch.tensor([[1.0]])
            h_n_minus_1 = hadamard(n // 2)
            top = torch.cat([h_n_minus_1, h_n_minus_1], dim=1)
            bottom = torch.cat([h_n_minus_1, -h_n_minus_1], dim=1)
            return torch.cat([top, bottom], dim=0) / math.sqrt(2)
        # Round dim up to the nearest power of two
        n = 2 ** math.ceil(math.log2(dim))
        H = hadamard(n)
        # Note: the slice is only orthogonal when dim is itself a power of two
        return H[:dim, :dim]

    def apply_incoherent_processing(self, x, random_signs=None):
        """
        Apply a random-sign + Hadamard transform to reduce quantization error.
        Q and K must share the same random signs so that
        (Q S H)(K S H)^T = Q K^T, i.e. the attention scores are unchanged.
        """
        if random_signs is None:
            random_signs = torch.randint(0, 2, (x.shape[-1],),
                                         device=x.device) * 2 - 1
        x = x * random_signs
        x = torch.matmul(x, self.hadamard_matrix)
        return x, random_signs

    def flash_attention_kernel(self, Q, K, V, block_size):
        """
        Core Flash Attention algorithm (tiled online softmax).
        Mathematically equivalent to softmax(QK^T / sqrt(d)) V, but never
        materializes the full seq_len x seq_len score matrix.
        """
        batch_size, seq_len, d_model = Q.shape
        O = torch.zeros_like(Q)
        # Process one block of queries at a time
        for i in range(0, seq_len, block_size):
            Q_block = Q[:, i:i+block_size]
            q_len = Q_block.shape[1]
            # Running statistics for the online softmax of this query block
            O_block = torch.zeros_like(Q_block)
            L_block = torch.zeros(batch_size, q_len, device=Q.device)  # normalizer
            M_block = torch.full((batch_size, q_len), -float('inf'),
                                 device=Q.device)                      # running max
            for j in range(0, seq_len, block_size):
                K_block = K[:, j:j+block_size]
                V_block = V[:, j:j+block_size]
                # Attention scores for this tile (FP8 would be used here with hardware support)
                S = torch.matmul(Q_block, K_block.transpose(-2, -1)) / math.sqrt(d_model)
                # Update the running row-wise maximum (numerical stability)
                M_new = torch.maximum(M_block, S.max(dim=-1)[0])
                # Rescale previously accumulated statistics by exp(M_old - M_new)
                scale = torch.exp(M_block - M_new)
                P = torch.exp(S - M_new.unsqueeze(-1))
                L_block = L_block * scale + P.sum(dim=-1)
                O_block = O_block * scale.unsqueeze(-1) + torch.matmul(P, V_block)
                M_block = M_new
            # Normalize once, after all key blocks have been processed
            O[:, i:i+block_size] = O_block / L_block.unsqueeze(-1)
        return O

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        # QKV projection
        qkv = self.qkv_proj(x)
        Q, K, V = qkv.chunk(3, dim=-1)
        if self.use_fp8:
            # Incoherent processing before FP8 quantization.
            # Q and K share the same random signs, so Q K^T is preserved
            # and no inverse transform is needed on the output.
            Q, signs = self.apply_incoherent_processing(Q)
            K, _ = self.apply_incoherent_processing(K, signs)
            # FP8 quantization would happen here (requires hardware support)
            # Q_fp8 = quantize_to_fp8(Q)
            # K_fp8 = quantize_to_fp8(K)
        # Flash Attention computation
        output = self.flash_attention_kernel(Q, K, V, self.block_size)
        return output
```
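Because the kernel above is just a tiled reformulation of standard attention, a quick sanity check is to compare it against the naive computation (a small sketch with use_fp8=False, so no Hadamard transform is involved):

```python
# Verify that the tiled online-softmax kernel matches naive attention
fa = FlashAttention3(d_model=512, block_size=64, use_fp8=False)
x = torch.randn(2, 256, 512)
Q, K, V = fa.qkv_proj(x).chunk(3, dim=-1)

tiled = fa.flash_attention_kernel(Q, K, V, fa.block_size)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(512)
naive = torch.matmul(F.softmax(scores, dim=-1), V)
print(torch.allclose(tiled, naive, atol=1e-5))   # expected: True
```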
```python
class CausalSelfAttention(nn.Module):
    """
    Self-attention with a causal mask (for autoregressive generation)
    """
    def __init__(self, d_model=512, max_seq_len=1024):
        super().__init__()
        self.d_model = d_model
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # Precompute the causal mask (scores are [batch, q_len, k_len], so keep it 3D)
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(max_seq_len, max_seq_len))
            .view(1, max_seq_len, max_seq_len)
        )

    def forward(self, x, use_cache=False, past_kv=None):
        batch_size, seq_len, _ = x.shape
        # QKV projection
        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.d_model, dim=-1)
        # KV cache handling (inference optimization)
        if use_cache and past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)
        # Attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1))
        attn_scores = attn_scores / math.sqrt(self.d_model)
        # Apply the causal mask; with a KV cache the current queries sit at the end
        # of the sequence, so the mask rows are offset by the cached length
        k_len = k.size(1)
        causal_mask = self.causal_mask[:, k_len - seq_len:k_len, :k_len]
        attn_scores = attn_scores.masked_fill(
            causal_mask == 0,
            float('-inf')
        )
        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        output = self.out_proj(output)
        if use_cache:
            return output, (k, v)
        return output
```
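A minimal sketch of incremental decoding with the KV cache above (token embeddings and sampling are omitted; the inputs are random placeholders):

```python
# Prefill the prompt once, then feed one position at a time, reusing the cache
attn = CausalSelfAttention(d_model=512, max_seq_len=1024)
prompt = torch.randn(1, 8, 512)
out, past_kv = attn(prompt, use_cache=True)
for _ in range(4):
    new_token = torch.randn(1, 1, 512)   # stand-in for the next token's embedding
    out, past_kv = attn(new_token, use_cache=True, past_kv=past_kv)
print(past_kv[0].shape)  # torch.Size([1, 12, 512]): 8 prompt + 4 generated positions
```

The LongContextAttention class below references a RotaryEmbedding module that is not defined in this section. The following is one possible minimal sketch, written under the assumption that its interface is rope(q, k) returning the rotated (q, k) pair; substitute the real implementation if it is defined elsewhere in the article:

```python
class RotaryEmbedding(nn.Module):
    """Minimal RoPE sketch; assumed interface: rope(q, k) -> (q_rot, k_rot)."""
    def __init__(self, dim, max_seq_len, base=10000.0):
        super().__init__()
        self.max_seq_len = max_seq_len
        # Standard RoPE inverse frequencies, one per channel pair
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _rotate(self, x, cos, sin):
        # Treat channels as interleaved (even, odd) pairs and rotate each pair
        x1, x2 = x[..., 0::2], x[..., 1::2]
        rot = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
        return rot.flatten(-2)   # re-interleave back to the original channel layout

    def forward(self, q, k):
        # Angles are computed lazily so max_seq_len=1M does not require a huge buffer
        t = torch.arange(q.shape[1], device=q.device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)        # [seq_len, dim/2]
        cos, sin = freqs.cos().to(q.dtype), freqs.sin().to(q.dtype)
        return self._rotate(q, cos, sin), self._rotate(k, cos, sin)
```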
```python
class LongContextAttention(nn.Module):
    """
    Attention for very long sequences
    """
    def __init__(
        self,
        d_model=512,
        max_seq_len=1_000_000,  # million-token scale
        use_flash_attn=True,
        use_sliding_window=False,
        window_size=4096
    ):
        super().__init__()
        self.d_model = d_model
        self.use_flash_attn = use_flash_attn
        self.use_sliding_window = use_sliding_window
        self.window_size = window_size
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # RoPE positional encoding (scales to long sequences)
        self.rope = RotaryEmbedding(d_model, max_seq_len)

    def sliding_window_attention(self, q, k, v, window_size):
        """
        Sliding-window (local) attention.
        Reference implementation: loops over positions one at a time, so it is
        easy to read but slow; production code would compute this in blocks.
        """
        batch_size, seq_len, d_model = q.shape
        output = torch.zeros_like(q)
        for i in range(seq_len):
            # Window boundaries around position i
            start = max(0, i - window_size // 2)
            end = min(seq_len, i + window_size // 2 + 1)
            # Local attention over the window
            q_i = q[:, i:i+1]
            k_window = k[:, start:end]
            v_window = v[:, start:end]
            scores = torch.matmul(q_i, k_window.transpose(-2, -1))
            scores = scores / math.sqrt(d_model)
            attn_weights = F.softmax(scores, dim=-1)
            output[:, i] = torch.matmul(attn_weights, v_window).squeeze(1)
        return output

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        # QKV projection
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)
        # Apply RoPE
        q, k = self.rope(q, k)
        if self.use_sliding_window and seq_len > self.window_size:
            # Sliding window for very long sequences
            output = self.sliding_window_attention(q, k, v, self.window_size)
        elif self.use_flash_attn:
            # flash_attn expects [batch, seq, nheads, head_dim] tensors in fp16/bf16
            # on GPU; the model dim is treated as a single head here
            from flash_attn import flash_attn_func
            output = flash_attn_func(
                q.unsqueeze(2).half(), k.unsqueeze(2).half(), v.unsqueeze(2).half()
            ).squeeze(2).to(q.dtype)
        else:
            # Standard attention
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_model)
            attn_weights = F.softmax(scores, dim=-1)
            output = torch.matmul(attn_weights, v)
        return self.out_proj(output)
```
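A small sketch that exercises the sliding-window path (a short sequence and window, just to run the code; with use_sliding_window=True the cost grows as O(N×W) instead of O(N²)):

```python
# Exercise the local-attention path of LongContextAttention on a toy input
long_attn = LongContextAttention(
    d_model=512, max_seq_len=4096,
    use_flash_attn=False, use_sliding_window=True, window_size=64
)
x = torch.randn(1, 256, 512)
out = long_attn(x)
print(out.shape)  # torch.Size([1, 256, 512])
```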
| Implementation | Memory complexity | Time complexity | Max sequence length | GPU utilization |
| --- | --- | --- | --- | --- |
| Standard attention | O(N²) | O(N²) | ~4K | 20-30% |
| Flash Attention | O(N) | O(N²) | ~32K | 40-50% |
| Flash Attention 2 | O(N) | O(N²) | ~128K | 35% (H100) |
| Flash Attention 3 | O(N) | O(N²) | ~1M | 75% (H100) |
| Sliding window | O(N×W) | O(N×W) | Unlimited | 30-40% |
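As a rough illustration of why the O(N²) memory term dominates: at a 32K sequence length, a single fp16 attention-score matrix already occupies 32768 × 32768 × 2 bytes ≈ 2 GiB per head per batch element, which is exactly what the tiled Flash Attention variants avoid materializing:

```python
# Back-of-the-envelope memory for one full attention-score matrix in fp16
seq_len = 32_768
bytes_per_elem = 2                      # fp16
score_bytes = seq_len * seq_len * bytes_per_elem
print(score_bytes / 2**30, "GiB")       # 2.0 GiB per head, per batch element
```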