Concept Definition

The decoder is the component of the Transformer architecture responsible for sequence generation. It uses a causal mask to enable autoregressive generation, ensuring that each position can only attend to earlier positions, and it is the foundational architecture of modern generative large language models.
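
Formally, an autoregressive decoder factorizes the probability of a token sequence with the standard chain rule,

    p(x_1, \dots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_1, \dots, x_{t-1}),

so the prediction at step t may condition only on the tokens before it; the causal mask enforces exactly this constraint inside the attention layers.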

Detailed Explanation

The Transformer decoder is at the heart of the generative AI revolution. Unlike the encoder's bidirectional understanding, the decoder uses a unidirectional (left-to-right) attention mechanism, with a causal mask preventing the model from "peeking" at future information. This design makes it naturally suited to text generation: each prediction of the next token can rely only on the content generated so far.

The success of the decoder architecture began with the GPT series. From GPT-1's 117M parameters to GPT-4's trillion-scale parameters, the decoder-only architecture has become the mainstream choice for large language models. Research on pretraining objectives has shown that causal decoder models trained with an autoregressive language modeling objective achieve the strongest zero-shot generalization, which explains why leading models such as ChatGPT, Claude, and LLaMA all adopt this architecture.

A standard decoder contains three key components: masked self-attention (which prevents information leakage), encoder-decoder cross-attention (used in seq2seq tasks), and a feed-forward network. Decoder-only models such as GPT drop the cross-attention, giving a simpler and more efficient architecture.
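
To make the causal mask concrete, here is a minimal sketch (illustrative only; the 4-token size is arbitrary) that builds the mask with PyTorch and checks that position t can attend only to positions 0..t:

import torch

# Illustrative 4-token causal mask: -inf above the diagonal blocks "future" positions
seq_len = 4
mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])

# Added to the attention scores before softmax, the -inf entries become zero
# weights, so row t spreads its attention only over positions 0..t.
scores = torch.zeros(seq_len, seq_len)
weights = torch.softmax(scores + mask, dim=-1)
print(weights)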

How It Works

The decoder's core mechanisms:
  1. Causal masking: enforces a one-directional information flow and prevents leakage of future tokens
  2. Autoregressive generation: tokens are produced one at a time, each step conditioned on the outputs so far
  3. Next-token prediction: the self-supervised training objective (see the sketch after this list)
  4. Simplified architecture: decoder-only models drop the cross-attention
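
As noted in item 3, next-token prediction simply pairs each prefix of the sequence with the token that follows it. A minimal sketch (the token IDs below are made up for illustration):

import torch

# Toy token sequence; a real model would use tokenizer output here
tokens = torch.tensor([15496, 11, 995, 318, 1049])

inputs = tokens[:-1]   # what the model reads:   [15496, 11, 995, 318]
labels = tokens[1:]    # what it must predict:   [11, 995, 318, 1049]

# At position t the model predicts labels[t] from inputs[0..t];
# this is the same shift used later in train_decoder_model.
for t in range(len(inputs)):
    print(f"context {inputs[:t + 1].tolist()} -> target {labels[t].item()}")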

Practical Applications

Basic Decoder Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerDecoder(nn.Module):
    """
    GPT风格的Decoder-Only Transformer
    """
    def __init__(
        self,
        vocab_size=50257,     # GPT-2 vocabulary size
        d_model=768,          # hidden dimension
        n_layers=12,          # number of layers
        n_heads=12,           # number of attention heads
        d_ff=3072,            # feed-forward (FFN) dimension
        max_seq_len=1024,     # maximum sequence length
        dropout=0.1
    ):
        super().__init__()
        
        # Token and position embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        
        # Stack of decoder layers
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        
        # Final layer norm and output projection
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "causal_mask",
            self.generate_causal_mask(max_seq_len)
        )
    
    def generate_causal_mask(self, size):
        """
        生成因果掩码矩阵
        """
        mask = torch.triu(torch.ones(size, size), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask
    
    def forward(self, input_ids, past_key_values=None):
        batch_size, seq_len = input_ids.shape
        
        # Position IDs, offset by the length of the cached keys/values (if any)
        past_length = past_key_values[0][0].size(2) if past_key_values else 0
        position_ids = torch.arange(
            past_length, past_length + seq_len, device=input_ids.device
        )
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
        
        # Embeddings
        token_emb = self.token_embedding(input_ids)
        pos_emb = self.position_embedding(position_ids)
        x = self.dropout(token_emb + pos_emb)
        
        # Causal mask: rows for the new positions, columns for cache + new positions
        causal_mask = self.causal_mask[
            past_length:past_length + seq_len, :past_length + seq_len
        ]
        
        # Pass through the decoder layers, collecting the updated KV cache
        presents = []
        for i, layer in enumerate(self.layers):
            past = past_key_values[i] if past_key_values else None
            x, present = layer(x, causal_mask, past)
            presents.append(present)
        
        # Final layer norm and projection to vocabulary logits
        x = self.ln_f(x)
        logits = self.lm_head(x)
        
        return logits, presents

Decoder Layer Implementation

class DecoderLayer(nn.Module):
    """
    单个解码器层
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        
        # Masked self-attention
        self.self_attn = MaskedSelfAttention(d_model, n_heads, dropout)
        
        # Feed-forward network (GPT-style: GELU activation)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        
        # Layer normalization (Pre-LN for training stability)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None, past_key_value=None):
        # Pre-LN + self-attention + residual connection
        residual = x
        x = self.ln1(x)
        attn_out, present = self.self_attn(x, mask, past_key_value)
        x = residual + self.dropout(attn_out)
        
        # Pre-LN + feed-forward + residual connection
        residual = x
        x = self.ln2(x)
        ffn_out = self.ffn(x)
        x = residual + ffn_out
        
        return x, present

Masked Self-Attention Implementation

class MaskedSelfAttention(nn.Module):
    """
    带KV缓存的掩码自注意力
    """
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        
        # Q, K, V projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = self.d_head ** -0.5
    
    def forward(self, x, mask=None, past_key_value=None):
        batch_size, seq_len, d_model = x.shape
        
        # Compute Q, K, V
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.d_head)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.d_head)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.d_head)
        
        # Transpose to (batch, heads, seq, d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # KV cache (speeds up incremental decoding at inference time)
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        
        present = (k, v)
        
        # Scaled dot-product attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        
        # Apply the additive causal mask (-inf above the diagonal)
        if mask is not None:
            attn_scores = attn_scores + mask
        
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Weighted sum of the values
        attn_out = torch.matmul(attn_weights, v)
        
        # Reshape back to (batch, seq, d_model)
        attn_out = attn_out.transpose(1, 2).contiguous()
        attn_out = attn_out.view(batch_size, seq_len, d_model)
        
        return self.out_proj(attn_out), present
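
With the three classes above in place, a quick smoke test can verify shapes end to end (an illustrative sketch; the tiny hyperparameters are arbitrary and not meant to match any real model):

# Assumes TransformerDecoder, DecoderLayer and MaskedSelfAttention defined above
model = TransformerDecoder(
    vocab_size=1000, d_model=64, n_layers=2, n_heads=4,
    d_ff=256, max_seq_len=128
)

dummy_ids = torch.randint(0, 1000, (2, 16))       # batch of 2 sequences, 16 tokens each
logits, presents = model(dummy_ids)

print(logits.shape)          # torch.Size([2, 16, 1000])
print(len(presents))         # 2 cached (k, v) pairs, one per layer
print(presents[0][0].shape)  # torch.Size([2, 4, 16, 16]) -> (batch, heads, seq, d_head)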

Text Generation Implementation

class GPTGenerator:
    """
    使用解码器进行文本生成
    """
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.model.eval()
    
    @torch.no_grad()
    def generate(
        self,
        prompt,
        max_length=100,
        temperature=0.8,
        top_p=0.9,
        repetition_penalty=1.1
    ):
        """
        自回归文本生成
        """
        # 编码输入
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
        past_key_values = None
        
        for _ in range(max_length):
            # Forward pass (only the newest token once the KV cache exists)
            logits, past_key_values = self.model(
                input_ids[:, -1:] if past_key_values else input_ids,
                past_key_values=past_key_values
            )
            
            # Logits for the last position
            next_token_logits = logits[:, -1, :]
            
            # Temperature scaling
            next_token_logits = next_token_logits / temperature
            
            # Repetition penalty: make tokens that already appeared less likely
            for token_id in set(input_ids[0].tolist()):
                if next_token_logits[0, token_id] > 0:
                    next_token_logits[0, token_id] /= repetition_penalty
                else:
                    next_token_logits[0, token_id] *= repetition_penalty
            
            # Top-p sampling
            next_token = self.top_p_sampling(next_token_logits, top_p)
            
            # Append the new token (already shaped (1, 1), so no unsqueeze is needed)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            
            # Stop when the end-of-sequence token is generated
            if next_token.item() == self.tokenizer.eos_token_id:
                break
        
        return self.tokenizer.decode(input_ids[0])
    
    def top_p_sampling(self, logits, top_p):
        """
        Top-p (nucleus) 采样
        """
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        
        # Mark tokens for removal once the cumulative probability exceeds top_p
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        
        # Set the logits of the removed tokens to -inf
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[0, indices_to_remove] = float('-inf')
        
        # Sample from the renormalized distribution
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        
        return next_token
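
A possible way to drive GPTGenerator end to end, assuming the Hugging Face transformers tokenizer for GPT-2 (the model below is untrained, so the output will be random tokens until it has been trained):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")   # provides encode/decode and eos_token_id
model = TransformerDecoder()                        # default GPT-2-sized configuration

generator = GPTGenerator(model, tokenizer)
text = generator.generate("The Transformer decoder", max_length=50, temperature=0.8)
print(text)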

Training Loop

def train_decoder_model(model, dataloader, epochs=10):
    """
    训练解码器模型(因果语言建模)
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    
    for epoch in range(epochs):
        total_loss = 0
        
        for batch in dataloader:
            input_ids = batch['input_ids']
            
            # Forward pass on all tokens except the last
            logits, _ = model(input_ids[:, :-1])
            
            # Loss: labels are the inputs shifted left by one position
            labels = input_ids[:, 1:]
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1)
            )
            
            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")

Performance Comparison

| Model    | Architecture | Parameters | Context length | Notes                                |
|----------|--------------|------------|----------------|--------------------------------------|
| GPT-2    | Decoder-only | 1.5B       | 1024           | First large-scale decoder-only model |
| GPT-3    | Decoder-only | 175B       | 2048           | Emergence of few-shot ability        |
| GPT-4    | Decoder-only | ~1.7T      | 128k           | Multimodal, long context             |
| LLaMA-3  | Decoder-only | 70B        | 8192           | Open source, GQA optimization        |
| Claude-3 | Decoder-only | -          | 200k           | Very long context                    |

Related Concepts

Further Reading