概念定义
Transformer是一种基于自注意力机制的神经网络架构,通过并行处理序列中所有位置的信息,突破了循环神经网络的序列依赖限制,成为现代大语言模型的基石。
详细解释
Transformer架构由Google在2017年论文”Attention is All You Need”中提出,彻底改变了自然语言处理领域。其核心创新是完全抛弃了循环和卷积结构,仅依赖注意力机制来捕获输入和输出之间的依赖关系。这种设计不仅大幅提升了训练效率,还显著改善了长距离依赖的建模能力。
2024年的Transformer已经从原始设计演化出众多变体。现代架构如Llama 3采用了预归一化、分组查询注意力、旋转位置编码等优化技术。GPT-4o作为最新的多模态模型,展示了Transformer在处理文本、图像、音频等多种模态上的统一能力。Flash Attention等技术进一步将注意力计算效率提升了数个数量级。
Transformer的成功不仅限于NLP领域。Vision Transformer (ViT)在图像分类任务上超越了CNN,Sora等模型将其应用于视频生成,展现了这一架构的普适性。从BERT的双向编码到GPT的自回归生成,从单一模态到多模态融合,Transformer已成为AI时代的通用架构。
工作原理
📥 输入处理层
输入序列 → Token嵌入 + 位置编码
Token嵌入 :将词汇映射到d_model=512维向量空间
位置编码 :使用PE(pos, 2i)公式提供序列位置信息
🔄 编码器堆栈 (N=6层)
每个编码器层包含:
多头自注意力
输入 : Q=K=V (同一序列)
功能 : 并行计算所有位置的相关性
输出 : 注意力加权的特征表示
前馈神经网络
结构 : Linear → ReLU → Linear
功能 : 非线性变换增强表达能力
参数 : FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
残差连接 : 每个子层都包含残差连接和层归一化
Add & Norm: LayerNorm(x + SubLayer(x))
防止梯度消失,稳定训练过程
🎯 解码器堆栈 (N=6层)
解码器层包含三个子层:
掩码多头自注意力
防止看到未来信息(因果掩码)
确保生成的自回归特性
编码器-解码器注意力
Q来自解码器,K和V来自编码器
实现源序列和目标序列的交互
前馈网络
📤 输出生成
线性层 + Softmax → 输出概率分布将隐状态映射到词汇表大小,通过softmax得到下一个token的概率
⚡ 2024年架构优化
🔥 Flash Attention
IO优化 : 减少GPU内存访问
内存效率 : 提升100倍内存利用率
速度提升 : 推理速度提升10倍
长序列 : 支持100k+序列长度
🌀 旋转位置编码 (RoPE)
相对位置 : 编码token间的相对距离
外推能力 : 训练长度外的序列处理
广泛采用 : Llama/GPT-4/Claude等模型标配
📊 预归一化 (Pre-LN)
RMSNorm : 替代LayerNorm,计算更高效
训练稳定性 : 梯度流更稳定
收敛速度 : 训练收敛更快
🎯 分组查询注意力 (GQA)
KV缓存优化 : 减少键值对缓存大小
推理加速 : 显著提升推理效率
现代采用 : GPT-4o、Llama 3等使用
🔬 典型参数配置
模型 隐层维度 注意力头数 层数 参数量 特殊技术 GPT-3 12,288 96 96 175B 标准架构 GPT-4 ~16,384 ~128 ~120 估计1.7T MoE专家混合 Llama-3 8,192 64 80 70B GQA(8组) + RoPE Claude-3 未公开 未公开 未公开 估计200B+ 多模态统一
效率提升 : Flash Attention 2使得100k+序列长度处理成为现实,推动了长上下文模型的发展
🌐 2024多模态统一架构
文本 + 图像 + 音频的端到端处理
单一神经网络架构,无需模态特定组件
实时多模态交互能力
Vision Transformer - 图像领域突破
基于时空Transformer的视频生成
Patch-based视频表示
扩散模型与Transformer的深度结合
自注意力机制 :并行计算序列中所有位置的相关性
多头注意力 :从不同表示子空间捕获信息
位置编码 :为模型提供序列顺序信息
编码器-解码器 :分离理解和生成任务
实际应用
import torch
import torch.nn as nn
import math
class TransformerBlock ( nn . Module ):
"""
标准Transformer块实现
"""
def __init__ ( self , d_model = 512 , n_heads = 8 , d_ff = 2048 , dropout = 0.1 ):
super (). __init__ ()
# 多头注意力
self .attention = nn.MultiheadAttention(
d_model, n_heads, dropout = dropout
)
# 前馈网络
self .feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model)
)
# 层归一化
self .norm1 = nn.LayerNorm(d_model)
self .norm2 = nn.LayerNorm(d_model)
self .dropout = nn.Dropout(dropout)
def forward ( self , x , mask = None ):
# 自注意力 + 残差连接
attn_out, _ = self .attention(x, x, x, attn_mask = mask)
x = self .norm1(x + self .dropout(attn_out))
# 前馈网络 + 残差连接
ff_out = self .feed_forward(x)
x = self .norm2(x + self .dropout(ff_out))
return x
2024年优化实现
class ModernTransformer ( nn . Module ):
"""
包含2024年优化的Transformer
"""
def __init__ ( self , config ):
super (). __init__ ()
# 使用RMSNorm代替LayerNorm (Llama风格)
self .norm = RMSNorm(config.d_model)
# 分组查询注意力 (GQA)
self .attention = GroupedQueryAttention(
d_model = config.d_model,
n_heads = config.n_heads,
n_kv_heads = config.n_kv_heads, # KV头数量少于Q头
use_flash_attn = True # 使用Flash Attention
)
# SwiGLU激活的FFN (GPT-4/Llama风格)
self .feed_forward = SwiGLUFFN(
d_model = config.d_model,
d_ff = config.d_ff
)
# 旋转位置编码
self .rope = RotaryPositionalEncoding(
d_model = config.d_model,
max_seq_len = config.max_seq_len
)
def forward ( self , x , kv_cache = None ):
# 预归一化 (Pre-LN)
x_norm = self .norm(x)
# 应用RoPE
x_with_pos = self .rope(x_norm)
# GQA with Flash Attention
attn_out, new_kv_cache = self .attention(
x_with_pos,
kv_cache = kv_cache,
use_flash = True
)
# 残差连接
x = x + attn_out
# FFN with pre-norm
x = x + self .feed_forward( self .norm(x))
return x, new_kv_cache
位置编码演进
# 1. 原始正弦位置编码 (2017)
def sinusoidal_position_encoding ( seq_len , d_model ):
pe = torch.zeros(seq_len, d_model)
position = torch.arange(seq_len).unsqueeze( 1 )
div_term = torch.exp(torch.arange( 0 , d_model, 2 ) *
- (math.log( 10000.0 ) / d_model))
pe[:, 0 :: 2 ] = torch.sin(position * div_term)
pe[:, 1 :: 2 ] = torch.cos(position * div_term)
return pe
# 2. 旋转位置编码 RoPE (2024主流)
class RotaryPositionalEncoding :
def __init__ ( self , dim , max_seq_len = 8192 , base = 10000 ):
self .dim = dim
self .max_seq_len = max_seq_len
self .base = base
def forward ( self , q , k ):
# 应用旋转矩阵编码相对位置
seq_len = q.shape[ 1 ]
# 计算旋转角度
theta = self .compute_theta(seq_len)
# 旋转查询和键
q_rot = self .rotate(q, theta)
k_rot = self .rotate(k, theta)
return q_rot, k_rot
class MultiModalTransformer :
"""
GPT-4o风格的多模态Transformer
"""
def __init__ ( self , config ):
# 统一的Transformer主干
self .transformer = ModernTransformer(config)
# 不同模态的编码器
self .text_encoder = TextTokenizer(config.vocab_size)
self .image_encoder = VisionTransformer(config.image_size)
self .audio_encoder = AudioTransformer(config.audio_dim)
# 统一的嵌入空间
self .project_text = nn.Linear(config.text_dim, config.d_model)
self .project_image = nn.Linear(config.image_dim, config.d_model)
self .project_audio = nn.Linear(config.audio_dim, config.d_model)
def forward ( self , text = None , image = None , audio = None ):
embeddings = []
if text is not None :
text_emb = self .project_text( self .text_encoder(text))
embeddings.append(text_emb)
if image is not None :
image_emb = self .project_image( self .image_encoder(image))
embeddings.append(image_emb)
if audio is not None :
audio_emb = self .project_audio( self .audio_encoder(audio))
embeddings.append(audio_emb)
# 拼接所有模态
multi_modal_input = torch.cat(embeddings, dim = 1 )
# 通过统一的Transformer处理
output = self .transformer(multi_modal_input)
return output
性能对比
架构特性 原始Transformer (2017) GPT-3 (2020) GPT-4/Llama 3 (2024) 注意力机制 标准多头 标准多头 GQA/Flash Attention 位置编码 正弦编码 学习编码 RoPE 归一化 Post-LN Pre-LN RMSNorm + Pre-LN 激活函数 ReLU GELU SwiGLU 最大序列长度 512 2048 128k+ 推理速度 基准 1x 10x+ (Flash Attn)
相关概念
延伸阅读