The Transformer decoder handles autoregressive generation and is the core architecture behind generative large language models such as GPT.
import torch
import torch.nn as nn
import torch.nn.functional as F
class TransformerDecoder(nn.Module):
"""
GPT风格的Decoder-Only Transformer
"""
def __init__(
self,
        vocab_size=50257,   # GPT-2 vocabulary size
        d_model=768,        # hidden dimension
        n_layers=12,        # number of layers
        n_heads=12,         # number of attention heads
        d_ff=3072,          # FFN inner dimension
        max_seq_len=1024,   # maximum sequence length
dropout=0.1
):
super().__init__()
        # Token and position embeddings
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = nn.Embedding(max_seq_len, d_model)
        # Stack of decoder layers
self.layers = nn.ModuleList([
DecoderLayer(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
        # Final layer norm and output projection
self.ln_f = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"causal_mask",
self.generate_causal_mask(max_seq_len)
)
def generate_causal_mask(self, size):
"""
生成因果掩码矩阵
"""
mask = torch.triu(torch.ones(size, size), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask
def forward(self, input_ids, past_key_values=None):
batch_size, seq_len = input_ids.shape
        # Position IDs (offset by the cached length when a KV cache is used)
        past_length = past_key_values[0][0].size(2) if past_key_values else 0
        position_ids = torch.arange(
            past_length, past_length + seq_len, device=input_ids.device
        )
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
        # Embeddings
        token_emb = self.token_embedding(input_ids)
        pos_emb = self.position_embedding(position_ids)
        x = self.dropout(token_emb + pos_emb)
        # Causal mask: rows for the new positions, columns for all visible positions
        causal_mask = self.causal_mask[
            past_length:past_length + seq_len, :past_length + seq_len
        ]
        # Pass through the decoder layers
presents = []
for i, layer in enumerate(self.layers):
past = past_key_values[i] if past_key_values else None
x, present = layer(x, causal_mask, past)
presents.append(present)
        # Final layer norm and output head
x = self.ln_f(x)
logits = self.lm_head(x)
return logits, presents
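To make the masking concrete, here is what generate_causal_mask produces for a length-4 sequence: positions above the diagonal get -inf, so after the softmax a token can only attend to itself and earlier tokens. This is a minimal sketch that needs only torch.

mask = torch.triu(torch.ones(4, 4), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])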
class DecoderLayer(nn.Module):
"""
单个解码器层
"""
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
        # Masked self-attention
self.self_attn = MaskedSelfAttention(d_model, n_heads, dropout)
        # Feed-forward network (GPT-style: GELU activation)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
        # Layer normalization (Pre-LN for training stability)
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None, past_key_value=None):
        # Pre-LN + self-attention + residual
residual = x
x = self.ln1(x)
attn_out, present = self.self_attn(x, mask, past_key_value)
x = residual + self.dropout(attn_out)
        # Pre-LN + FFN + residual
residual = x
x = self.ln2(x)
ffn_out = self.ffn(x)
x = residual + ffn_out
return x, present
class MaskedSelfAttention(nn.Module):
"""
带KV缓存的掩码自注意力
"""
def __init__(self, d_model, n_heads, dropout=0.1):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.d_head = d_model // n_heads
        # Q/K/V projections
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.scale = self.d_head ** -0.5
def forward(self, x, mask=None, past_key_value=None):
batch_size, seq_len, d_model = x.shape
        # Compute Q, K, V
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.d_head)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.d_head)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.d_head)
        # Transpose to (batch, heads, seq, d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # KV cache (speeds up incremental decoding)
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present = (k, v)
        # Attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Apply the causal mask
        if mask is not None:
            attn_scores = attn_scores + mask
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # Weight the values by the attention weights
        attn_out = torch.matmul(attn_weights, v)
        # Reshape back to (batch, seq, d_model)
        attn_out = attn_out.transpose(1, 2).contiguous()
        attn_out = attn_out.view(batch_size, seq_len, d_model)
return self.out_proj(attn_out), present
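With all three model classes defined, a quick smoke test can confirm the output shapes and that incremental decoding with the KV cache reproduces the full forward pass. The tiny hyperparameters below are illustrative only, and eval() disables dropout so the two paths are comparable.

torch.manual_seed(0)
tiny = TransformerDecoder(vocab_size=100, d_model=32, n_layers=2,
                          n_heads=4, d_ff=64, max_seq_len=16)
tiny.eval()

input_ids = torch.randint(0, 100, (1, 8))
with torch.no_grad():
    # Full forward pass over the whole prompt
    full_logits, presents = tiny(input_ids)
    print(full_logits.shape)      # torch.Size([1, 8, 100])
    print(presents[0][0].shape)   # cached keys: torch.Size([1, 4, 8, 8])

    # Feed the same tokens one at a time, reusing the cache
    past, step_logits = None, []
    for t in range(input_ids.size(1)):
        out, past = tiny(input_ids[:, t:t + 1], past_key_values=past)
        step_logits.append(out)
    print(torch.allclose(full_logits, torch.cat(step_logits, dim=1), atol=1e-5))  # True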
class GPTGenerator:
"""
使用解码器进行文本生成
"""
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
self.model.eval()
@torch.no_grad()
def generate(
self,
prompt,
max_length=100,
temperature=0.8,
top_p=0.9,
repetition_penalty=1.1
):
"""
自回归文本生成
"""
# 编码输入
input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
past_key_values = None
        for _ in range(max_length):
            # Forward pass (only the newest token once a KV cache exists)
            logits, past_key_values = self.model(
                input_ids[:, -1:] if past_key_values else input_ids,
                past_key_values=past_key_values
            )
            # Logits of the last position
            next_token_logits = logits[:, -1, :]
            # Apply temperature
            next_token_logits = next_token_logits / temperature
            # Apply repetition penalty (make already-seen tokens less likely)
            for token_id in set(input_ids[0].tolist()):
                if next_token_logits[0, token_id] > 0:
                    next_token_logits[0, token_id] /= repetition_penalty
                else:
                    next_token_logits[0, token_id] *= repetition_penalty
            # Top-p sampling
            next_token = self.top_p_sampling(next_token_logits, top_p)
            # Append the new token (next_token already has shape (1, 1))
            input_ids = torch.cat([input_ids, next_token], dim=1)
            # Stop on the end-of-sequence token
            if next_token.item() == self.tokenizer.eos_token_id:
                break
return self.tokenizer.decode(input_ids[0])
def top_p_sampling(self, logits, top_p):
"""
Top-p (nucleus) 采样
"""
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Find positions where the cumulative probability exceeds top_p
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
        # Set the logits of removed tokens to -inf
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[0, indices_to_remove] = float('-inf')
        # Sample from the filtered distribution
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
return next_token
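A usage sketch, assuming the Hugging Face transformers package is available for a compatible tokenizer (an assumption, not part of the original code). Since the model below is randomly initialized, the output will be noise; the point is only to exercise the generation loop.

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")   # matches vocab_size=50257
model = TransformerDecoder()                        # untrained, default configuration
generator = GPTGenerator(model, tokenizer)
print(generator.generate("The transformer decoder", max_length=20))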
def train_decoder_model(model, dataloader, epochs=10):
"""
训练解码器模型(因果语言建模)
"""
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
for epoch in range(epochs):
total_loss = 0
for batch in dataloader:
input_ids = batch['input_ids']
            # Forward pass
logits, _ = model(input_ids[:, :-1])
            # Compute the loss (labels are the inputs shifted left by one position)
labels = input_ids[:, 1:]
loss = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
labels.reshape(-1)
)
            # Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")
| Model | Architecture | Parameters | Context length | Highlights |
|---|---|---|---|---|
| GPT-2 | Decoder-only | 1.5B | 1024 | Early large-scale decoder-only model |
| GPT-3 | Decoder-only | 175B | 2048 | Emergent few-shot ability |
| GPT-4 | Decoder-only | ~1.7T (estimated) | 128k | Multimodal, long context |
| LLaMA-3 | Decoder-only | 70B | 8192 | Open weights, GQA optimization |
| Claude-3 | Decoder-only | - | 200k | Very long context |
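As a rough cross-check on the parameter column, the defaults of the TransformerDecoder above (12 layers, d_model=768) match GPT-2 small rather than the 1.5B XL model in the table, and the count can be read off the module directly. Note that lm_head is not weight-tied to the token embedding here, unlike GPT-2, so the total is higher than GPT-2 small's ~124M.

model = TransformerDecoder()  # default GPT-2-small-like configuration
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.0f}M parameters")  # ~163M; ~124M if the LM head were tied to the embedding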