The Transformer encoder is responsible for understanding and representing the input, and it is the core component of understanding-oriented models such as BERT.
import torch
import torch.nn as nn
class TransformerEncoder(nn.Module):
"""
标准Transformer编码器实现
"""
def __init__(
self,
vocab_size=30522, # BERT词汇表大小
d_model=768, # Base:768, Large:1024
n_layers=12, # Base:12, Large:24
n_heads=12, # Base:12, Large:16
d_ff=3072, # 4 * d_model
max_seq_len=512,
dropout=0.1
):
super().__init__()
# 三种嵌入
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = nn.Embedding(max_seq_len, d_model)
self.segment_embedding = nn.Embedding(2, d_model) # 句子A/B
# 编码器层堆栈
self.layers = nn.ModuleList([
EncoderLayer(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
    def forward(self, input_ids, segment_ids=None, attention_mask=None):
        seq_len = input_ids.size(1)
        # Generate position IDs
        position_ids = torch.arange(seq_len, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        # Sum the three embeddings
        embeddings = self.token_embedding(input_ids)
        embeddings = embeddings + self.position_embedding(position_ids)
        if segment_ids is not None:
            embeddings = embeddings + self.segment_embedding(segment_ids)
        x = self.dropout(embeddings)
        # Convert a BERT-style mask (1 = real token, 0 = padding) into the
        # key_padding_mask expected by nn.MultiheadAttention (True = ignore)
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        # Pass through the encoder layers
        for layer in self.layers:
            x = layer(x, key_padding_mask)
        return self.norm(x)
class EncoderLayer(nn.Module):
"""
单个编码器层:自注意力 + FFN
"""
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
# 多头自注意力
self.self_attention = nn.MultiheadAttention(
d_model, n_heads, dropout=dropout, batch_first=True
)
# 前馈网络
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(), # BERT使用GELU激活
nn.Dropout(dropout),
nn.Linear(d_ff, d_model)
)
# 层归一化
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
    def forward(self, x, key_padding_mask=None):
        # Self-attention + residual connection
        attn_out, _ = self.self_attention(
            x, x, x,
            key_padding_mask=key_padding_mask,  # True marks padding positions to ignore
            need_weights=False
        )
        x = self.norm1(x + self.dropout(attn_out))
        # FFN + residual connection
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x
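# Quick shape check for the encoder defined above. This is an illustrative
# sketch with made-up dummy inputs (batch of 2, sequence length 16), not part
# of the original listing: token IDs of shape [batch, seq_len] should come
# back as hidden states of shape [batch, seq_len, d_model].
def _encoder_shape_check():
    encoder = TransformerEncoder()
    dummy_ids = torch.randint(0, 30522, (2, 16))       # random token IDs
    dummy_mask = torch.ones(2, 16, dtype=torch.long)   # 1 = real token, 0 = padding
    hidden = encoder(dummy_ids, attention_mask=dummy_mask)
    assert hidden.shape == (2, 16, 768)
    return hidden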
class MaskedLanguageModel(nn.Module):
"""
BERT的MLM预训练头
"""
def __init__(self, encoder, vocab_size):
super().__init__()
self.encoder = encoder
self.mlm_head = nn.Linear(encoder.d_model, vocab_size)
def forward(self, input_ids, labels=None):
# 创建掩码
masked_input, mask_indices = self.create_masks(input_ids)
# 编码
encoder_output = self.encoder(masked_input)
# 只预测被掩码的位置
masked_output = encoder_output[mask_indices]
predictions = self.mlm_head(masked_output)
if labels is not None:
loss = nn.CrossEntropyLoss()(predictions, labels[mask_indices])
return loss, predictions
return predictions
    def create_masks(self, input_ids, mask_prob=0.15):
        """
        Create the MLM mask (BERT's 80/10/10 scheme).
        """
        device = input_ids.device
        # Randomly select 15% of the positions
        mask_indices = torch.rand(input_ids.shape, device=device) < mask_prob
        # Never mask special tokens
        special_tokens = [0, 101, 102, 103]  # [PAD], [CLS], [SEP], [MASK]
        for token_id in special_tokens:
            mask_indices &= (input_ids != token_id)
        masked_input = input_ids.clone()
        # 80% of the selected positions are replaced with [MASK]
        mask_token = 103
        replace_mask = torch.rand(input_ids.shape, device=device) < 0.8
        masked_input[mask_indices & replace_mask] = mask_token
        # 10% are replaced with a random word (half of the remaining 20%)
        random_words = torch.randint(
            104, self.encoder.vocab_size,
            input_ids.shape,
            device=device
        )
        random_mask = (torch.rand(input_ids.shape, device=device) < 0.5) & ~replace_mask
        masked_input[mask_indices & random_mask] = random_words[mask_indices & random_mask]
        # The final 10% are left unchanged (they already hold the original word)
        return masked_input, mask_indices
class BERTForClassification(nn.Module):
"""
使用BERT编码器进行分类
"""
def __init__(self, encoder, num_classes):
super().__init__()
self.encoder = encoder
self.classifier = nn.Sequential(
nn.Linear(encoder.d_model, encoder.d_model),
nn.Tanh(),
nn.Dropout(0.1),
nn.Linear(encoder.d_model, num_classes)
)
def forward(self, input_ids, attention_mask=None):
# 获取编码器输出
encoder_output = self.encoder(input_ids, attention_mask)
# 使用[CLS]的表示进行分类
cls_output = encoder_output[:, 0, :] # [batch_size, d_model]
# 分类
logits = self.classifier(cls_output)
return logits
# Usage example
def sentiment_analysis_example():
    """
    Sentiment analysis example.
    """
    # Initialize the model
    encoder = TransformerEncoder()
    model = BERTForClassification(encoder, num_classes=3)  # positive / neutral / negative
    model.eval()
    # Input processing (tokenize() stands in for a real tokenizer)
    text = "This movie is absolutely fantastic!"
    input_ids = tokenize(text)  # [CLS] This movie ... [SEP]
    # Predict
    with torch.no_grad():
        logits = model(input_ids)
        prediction = torch.argmax(logits, dim=-1)
    return prediction  # with trained weights this should be "positive"
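# The tokenize() call above is a placeholder. One way to provide real input IDs
# is the Hugging Face tokenizer; this helper is an assumption of the sketch and
# an extra dependency, not part of the original example.
def tokenize(text):
    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # Returns a LongTensor of shape [1, seq_len] including [CLS] and [SEP]
    return tokenizer(text, return_tensors="pt")["input_ids"]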
A comparison of common encoder-only pre-trained models:

| Feature | BERT-Base | BERT-Large | RoBERTa | ELECTRA |
|---|---|---|---|---|
| Layers | 12 | 24 | 24 | 12 |
| Hidden size | 768 | 1024 | 1024 | 768 |
| Attention heads | 12 | 16 | 16 | 12 |
| Parameters | 110M | 340M | 355M | 110M |
| MLM accuracy | 84.3% | 86.7% | 88.5% | 89.0% |
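The same TransformerEncoder class defined above can be configured to match the larger variants in this table; a minimal sketch for the BERT-Large column (d_ff=4096 follows the 4 * d_model convention noted in the constructor and is not listed in the table itself):

large_encoder = TransformerEncoder(
    d_model=1024,   # hidden size
    n_layers=24,    # layers
    n_heads=16,     # attention heads
    d_ff=4096       # 4 * d_model
)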