Documentation Index
Fetch the complete documentation index at: https://docs.apiyi.com/llms.txt
Use this file to discover all available pages before exploring further.
概念定义
解码器(Decoder)是Transformer架构中负责序列生成的组件,通过因果掩码实现自回归生成,确保每个位置只能关注之前的位置,是现代生成式大语言模型的基础架构。
详细解释
Transformer解码器是生成式AI革命的核心。与编码器的双向理解不同,解码器采用单向(从左到右)的注意力机制,通过因果掩码(Causal Masking)防止模型”偷看”未来信息。这种设计使其天然适合文本生成任务:每次预测下一个token时,只能依赖已生成的内容。
解码器架构的成功始于GPT系列。从GPT-1的117M参数到GPT-4的万亿级参数,解码器-only架构已成为大语言模型的主流选择。2024年的研究表明,在零样本泛化任务上,因果解码器模型配合自回归语言建模目标展现出最优性能,这解释了为什么ChatGPT、Claude、LLaMA等顶级模型都采用这一架构。
标准解码器包含三个关键组件:掩码自注意力(防止信息泄露)、编码器-解码器交叉注意力(在seq2seq任务中)、前馈网络。而在GPT等decoder-only模型中,去除了交叉注意力,形成更简洁高效的架构。
工作原理
🏗️ 解码器架构详解
🔒 因果掩码机制
解码器的核心特征是因果掩码(Causal Masking),确保生成过程的自回归特性:
因果掩码原理:生成序列 “The cat sat on”时:
- 预测 “cat” 时只能看到 “The”
- 预测 “sat” 时只能看到 “The cat”
- 预测 “on” 时只能看到 “The cat sat”
这确保了模型无法”偷看”未来信息,保持生成的合理性
🔄 解码器层结构
解码器有两种主要架构设计:
🏗️ 标准解码器 (Seq2Seq)
三个子层结构:
- 掩码自注意力: 处理目标序列,防止信息泄露
- 编码器-解码器注意力: 关注源序列信息
- 前馈网络: 非线性特征变换
应用: 机器翻译、文档摘要等 ⚡ Decoder-Only (GPT)
简化结构:
- 掩码自注意力: 处理输入序列
- 前馈网络: 特征提取
优势:应用: ChatGPT、Claude等
🎯 自回归生成过程
- 输入prompt: “What is artificial”
- 添加位置编码
- 通过embedding层转换为向量
- 掩码自注意力:每个位置只能看到之前的token
- 多层解码器逐步提取语义特征
- 生成上下文感知的表示
- 线性投影到词汇表大小
- Softmax生成概率分布
- 采样策略选择下一个token(如”intelligence”)
- 将新token添加到序列:“What is artificial intelligence”
- 重复步骤2-3,继续生成
- 直到遇到结束token或达到最大长度
⚖️ 编码器 vs 解码器对比
| 特性 | 编码器 | 解码器 |
|---|
| 注意力方向 | 双向 | 单向(因果) |
| 主要任务 | 理解、分类、表示学习 | 生成、续写、对话 |
| 典型模型 | BERT、RoBERTa | GPT、LLaMA、Claude |
| 预训练目标 | 掩码语言模型(MLM) | 自回归语言模型(ALM) |
| 推理方式 | 并行处理 | 序列生成 |
| 应用场景 | 搜索、分类、问答理解 | 聊天、创作、代码生成 |
🔬 现代解码器优化
2024年解码器发展趋势:
- 架构简化: Decoder-only成为主流
- 效率优化: Flash Attention、分组查询注意力
- 参数扩展: 从千亿到万亿参数规模
- 多模态融合: 统一处理文本、图像、音频
🎯 解码器的核心机制:
- 因果掩码:确保单向信息流,防止信息泄露
- 自回归生成:逐个token生成,每步基于之前的输出
- 下一个token预测:自监督训练目标
- 简化架构:decoder-only去除了交叉注意力
实际应用
基础解码器实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class TransformerDecoder(nn.Module):
"""
GPT风格的Decoder-Only Transformer
"""
def __init__(
self,
vocab_size=50257, # GPT-2词汇表大小
d_model=768, # 隐藏维度
n_layers=12, # 层数
n_heads=12, # 注意力头数
d_ff=3072, # FFN维度
max_seq_len=1024, # 最大序列长度
dropout=0.1
):
super().__init__()
# Token和位置嵌入
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = nn.Embedding(max_seq_len, d_model)
# 解码器层堆栈
self.layers = nn.ModuleList([
DecoderLayer(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
# 输出投影
self.ln_f = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"causal_mask",
self.generate_causal_mask(max_seq_len)
)
def generate_causal_mask(self, size):
"""
生成因果掩码矩阵
"""
mask = torch.triu(torch.ones(size, size), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask
def forward(self, input_ids, past_key_values=None):
batch_size, seq_len = input_ids.shape
# 位置ID
position_ids = torch.arange(seq_len, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# 嵌入
token_emb = self.token_embedding(input_ids)
pos_emb = self.position_embedding(position_ids)
x = self.dropout(token_emb + pos_emb)
# 因果掩码
causal_mask = self.causal_mask[:seq_len, :seq_len]
# 通过解码器层
presents = []
for i, layer in enumerate(self.layers):
past = past_key_values[i] if past_key_values else None
x, present = layer(x, causal_mask, past)
presents.append(present)
# 最终层归一化和输出
x = self.ln_f(x)
logits = self.lm_head(x)
return logits, presents
解码器层实现
class DecoderLayer(nn.Module):
"""
单个解码器层
"""
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
# 掩码自注意力
self.self_attn = MaskedSelfAttention(d_model, n_heads, dropout)
# 前馈网络(GPT风格:使用GELU激活)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
# 层归一化(Pre-LN for stability)
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None, past_key_value=None):
# Pre-LN + 自注意力 + 残差
residual = x
x = self.ln1(x)
attn_out, present = self.self_attn(x, mask, past_key_value)
x = residual + self.dropout(attn_out)
# Pre-LN + FFN + 残差
residual = x
x = self.ln2(x)
ffn_out = self.ffn(x)
x = residual + ffn_out
return x, present
掩码自注意力实现
class MaskedSelfAttention(nn.Module):
"""
带KV缓存的掩码自注意力
"""
def __init__(self, d_model, n_heads, dropout=0.1):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.d_head = d_model // n_heads
# QKV投影
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.scale = self.d_head ** -0.5
def forward(self, x, mask=None, past_key_value=None):
batch_size, seq_len, d_model = x.shape
# 计算Q, K, V
q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.d_head)
k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.d_head)
v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.d_head)
# 转置为(batch, heads, seq, d_head)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# KV缓存(用于推理加速)
if past_key_value is not None:
past_k, past_v = past_key_value
k = torch.cat([past_k, k], dim=2)
v = torch.cat([past_v, v], dim=2)
present = (k, v)
# 注意力计算
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
# 应用因果掩码
if mask is not None:
attn_scores = attn_scores + mask
attn_weights = F.softmax(attn_scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# 应用注意力权重
attn_out = torch.matmul(attn_weights, v)
# 重塑输出
attn_out = attn_out.transpose(1, 2).contiguous()
attn_out = attn_out.view(batch_size, seq_len, d_model)
return self.out_proj(attn_out), present
文本生成实现
class GPTGenerator:
"""
使用解码器进行文本生成
"""
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
self.model.eval()
@torch.no_grad()
def generate(
self,
prompt,
max_length=100,
temperature=0.8,
top_p=0.9,
repetition_penalty=1.1
):
"""
自回归文本生成
"""
# 编码输入
input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
past_key_values = None
for _ in range(max_length):
# 前向传播
logits, past_key_values = self.model(
input_ids[:, -1:] if past_key_values else input_ids,
past_key_values=past_key_values
)
# 获取最后一个token的logits
next_token_logits = logits[:, -1, :]
# 应用温度
next_token_logits = next_token_logits / temperature
# 应用重复惩罚
for token_id in set(input_ids[0].tolist()):
next_token_logits[0, token_id] /= repetition_penalty
# Top-p采样
next_token = self.top_p_sampling(next_token_logits, top_p)
# 拼接新token
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
# 检查是否生成结束token
if next_token.item() == self.tokenizer.eos_token_id:
break
return self.tokenizer.decode(input_ids[0])
def top_p_sampling(self, logits, top_p):
"""
Top-p (nucleus) 采样
"""
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# 找到累积概率超过top_p的位置
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# 将要移除的token的logits设为-inf
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[0, indices_to_remove] = float('-inf')
# 采样
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
return next_token
训练循环
def train_decoder_model(model, dataloader, epochs=10):
"""
训练解码器模型(因果语言建模)
"""
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
for epoch in range(epochs):
total_loss = 0
for batch in dataloader:
input_ids = batch['input_ids']
# 前向传播
logits, _ = model(input_ids[:, :-1])
# 计算损失(标签是输入左移一位)
labels = input_ids[:, 1:]
loss = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
labels.reshape(-1)
)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")
性能对比
| 模型 | 架构 | 参数量 | 上下文长度 | 特点 |
|---|
| GPT-2 | Decoder-only | 1.5B | 1024 | 首个大规模decoder模型 |
| GPT-3 | Decoder-only | 175B | 2048 | Few-shot能力涌现 |
| GPT-4 | Decoder-only | ~1.7T | 128k | 多模态,长上下文 |
| LLaMA-3 | Decoder-only | 70B | 8192 | 开源,GQA优化 |
| Claude-3 | Decoder-only | - | 200k | 超长上下文 |
相关概念
延伸阅读