Understand perplexity, the core metric for evaluating language models, and learn how to compute it and apply it in practice.
PPL(W) = P(w₁, w₂, ..., w_N)^(-1/N)
PPL = exp(-1/N × Σᵢ log P(wᵢ|w₁, ..., wᵢ₋₁))
PPL = 2^H(P,Q) (cross-entropy measured in bits) = e^CE (cross-entropy measured in nats)
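A quick numerical check (using made-up per-token probabilities, purely for illustration) shows that the inverse-geometric-mean form and the exp-of-average-negative-log-likelihood form give the same value:

import numpy as np

# Hypothetical probabilities a model assigns to the 4 tokens of a sentence
token_probs = np.array([0.2, 0.1, 0.5, 0.25])

# Form 1: inverse geometric mean of the joint probability
ppl_geometric = np.prod(token_probs) ** (-1 / len(token_probs))

# Form 2: exp of the average negative log-likelihood (cross-entropy in nats)
ppl_exp = np.exp(-np.mean(np.log(token_probs)))

print(ppl_geometric, ppl_exp)  # both ≈ 4.47 -- the two forms are equivalent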
import torch
import numpy as np

def calculate_perplexity(model, tokenizer, text):
    """Compute the perplexity of a single text."""
    # Tokenize
    tokens = tokenizer.encode(text, return_tensors='pt')
    # Forward pass
    with torch.no_grad():
        outputs = model(tokens, labels=tokens)
        loss = outputs.loss  # cross-entropy loss (mean over tokens, in nats)
    # Perplexity = exp(loss)
    perplexity = torch.exp(loss)
    return perplexity.item()

# Batch computation
def calculate_perplexity_batch(model, dataloader):
    """Compute the average perplexity over a dataset."""
    total_loss = 0
    total_tokens = 0
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            outputs = model(**batch)
            loss = outputs.loss
            # Accumulate loss weighted by the number of real (non-padding) tokens
            seq_len = batch['attention_mask'].sum().item()
            total_loss += loss.item() * seq_len
            total_tokens += seq_len
    # Average loss
    avg_loss = total_loss / total_tokens
    # Perplexity
    perplexity = np.exp(avg_loss)
    return perplexity
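For instance, a minimal usage sketch, assuming the Hugging Face transformers library is available; 'gpt2' is just an example checkpoint, any causal LM works:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'gpt2'  # example checkpoint, swap in your own model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()

ppl = calculate_perplexity(model, tokenizer, "The quick brown fox jumps over the lazy dog.")
print(f"Perplexity: {ppl:.2f}")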
# Example: compare the perplexity of different models on the same text
def compare_models_perplexity(models, test_text):
    """Compare the perplexity of several models on the same text."""
    results = {}
    for model_name, model in models.items():
        # Accumulate the predicted probability of each token given its context
        log_probs = []
        tokens = test_text.split()
        for i in range(1, len(tokens)):
            context = ' '.join(tokens[:i])
            target = tokens[i]
            # predict_next_token_prob is assumed to return P(target | context)
            prob = model.predict_next_token_prob(context, target)
            log_probs.append(np.log(prob))
        # Perplexity = exp of the negative average log-probability
        avg_log_prob = sum(log_probs) / len(log_probs)
        perplexity = np.exp(-avg_log_prob)
        results[model_name] = perplexity
    return results

# Example output
# {
#     'GPT-2': 35.2,    # good
#     'GPT-3': 20.1,    # better
#     'Random': 50000   # very poor
# }
class ModelSelector:
    """Model selector based on perplexity."""

    def __init__(self, candidate_models):
        self.models = candidate_models

    def select_best_model(self, validation_data):
        """Select the model with the lowest perplexity."""
        best_ppl = float('inf')
        best_model = None
        for name, model in self.models.items():
            ppl = self.evaluate_perplexity(model, validation_data)
            print(f"{name}: PPL = {ppl:.2f}")
            if ppl < best_ppl:
                best_ppl = ppl
                best_model = name
        return best_model, best_ppl

    def evaluate_perplexity(self, model, data):
        """Evaluate a model's perplexity on a list of texts."""
        total_loss = 0
        total_count = 0
        for text in data:
            # model.score is assumed to return the total log-probability of the text
            loss = -model.score(text) / len(text.split())
            total_loss += loss
            total_count += 1
        return np.exp(total_loss / total_count)
class PerplexityMonitor:
    """Monitor perplexity during training."""

    def __init__(self, patience=5):
        self.best_ppl = float('inf')
        self.patience = patience
        self.wait = 0
        self.history = []

    def update(self, epoch, train_ppl, val_ppl):
        """Record perplexity and decide whether to keep training."""
        self.history.append({
            'epoch': epoch,
            'train_ppl': train_ppl,
            'val_ppl': val_ppl
        })
        # Early-stopping check
        if val_ppl < self.best_ppl:
            self.best_ppl = val_ppl
            self.wait = 0
            return True   # save the model
        else:
            self.wait += 1
            if self.wait >= self.patience:
                print(f"Early stopping at epoch {epoch}")
                return False  # stop training
        return None  # keep training

    def plot_history(self):
        """Plot the perplexity curves."""
        import matplotlib.pyplot as plt
        epochs = [h['epoch'] for h in self.history]
        train_ppls = [h['train_ppl'] for h in self.history]
        val_ppls = [h['val_ppl'] for h in self.history]
        plt.figure(figsize=(10, 6))
        plt.plot(epochs, train_ppls, label='Train PPL')
        plt.plot(epochs, val_ppls, label='Val PPL')
        plt.xlabel('Epoch')
        plt.ylabel('Perplexity')
        plt.yscale('log')  # log scale
        plt.legend()
        plt.title('Training Progress')
        plt.grid(True)
        plt.show()
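A sketch of how the monitor can drive a training loop; train_one_epoch is a placeholder for your own training routine, and calculate_perplexity_batch is the function defined earlier:

monitor = PerplexityMonitor(patience=3)

for epoch in range(1, 51):
    train_ppl = train_one_epoch(model, train_loader)          # placeholder: returns train PPL
    val_ppl = calculate_perplexity_batch(model, val_loader)   # defined above

    decision = monitor.update(epoch, train_ppl, val_ppl)
    if decision is True:
        torch.save(model.state_dict(), 'best_model.pt')  # new best validation PPL
    elif decision is False:
        break  # patience exhausted, stop training

monitor.plot_history()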
def evaluate_data_quality(model, tokenizer, datasets):
    """Use perplexity to assess the quality of datasets."""
    quality_scores = {}
    for name, dataset in datasets.items():
        ppls = []
        for text in dataset:
            ppl = calculate_perplexity(model, tokenizer, text)
            ppls.append(ppl)
        # Summary statistics
        quality_scores[name] = {
            'mean_ppl': np.mean(ppls),
            'std_ppl': np.std(ppls),
            'median_ppl': np.median(ppls),
            'outliers': sum(1 for p in ppls if p > np.mean(ppls) + 2 * np.std(ppls))
        }
    return quality_scores

# Usage example
datasets = {
    'wikipedia': wiki_texts,
    'reddit': reddit_texts,
    'arxiv': arxiv_texts
}
quality = evaluate_data_quality(model, tokenizer, datasets)
# Possible results:
# Wikipedia: mean PPL = 30, evenly distributed
# Reddit:    mean PPL = 45, high variance
# ArXiv:     mean PPL = 60, many domain-specific terms
# A lower perplexity does not necessarily mean better task performance
model_a_ppl = 15.2  # lower perplexity
model_b_ppl = 18.5  # slightly higher perplexity

# But on the actual downstream task
model_a_accuracy = 0.82
model_b_accuracy = 0.89
# Model B actually performs better!
# Different tokenizers make perplexities incomparable
gpt_tokenizer_ppl = 25.3
bert_tokenizer_ppl = 31.2
# These cannot be compared directly!
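One common workaround (a sketch, not the only option) is to re-normalize the summed loss by a tokenizer-independent unit such as words or bytes; the helper names below are illustrative:

def word_normalized_perplexity(total_loss_nats, num_words):
    """Re-normalize a token-level loss by word count so models with
    different tokenizers can be compared on the same text."""
    return np.exp(total_loss_nats / num_words)

def bits_per_byte(total_loss_nats, num_bytes):
    """Tokenizer-independent alternative: bits of cross-entropy per UTF-8 byte."""
    return total_loss_nats / (num_bytes * np.log(2))

# total_loss_nats is the summed (not averaged) NLL over all predicted tokens,
# e.g. loss.item() * num_tokens accumulated as in the earlier functions.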
def conditional_perplexity(model, context, target):
    """Compute the conditional perplexity of `target` given `context`."""
    # Only the target portion contributes to the perplexity;
    # compute_loss is assumed to return the summed (not averaged) NLL in nats
    full_text = context + target
    context_loss = model.compute_loss(context)
    full_loss = model.compute_loss(full_text)
    target_loss = full_loss - context_loss
    target_tokens = len(tokenize(target))
    return np.exp(target_loss / target_tokens)
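A minimal sketch of the same idea for a Hugging Face causal LM: positions labeled -100 are ignored by the loss, so masking the context tokens leaves only the target tokens in the cross-entropy average. The function name is illustrative.

def conditional_perplexity_hf(model, tokenizer, context, target):
    """Conditional PPL of `target` given `context` for a Hugging Face causal LM."""
    context_ids = tokenizer(context, return_tensors='pt')['input_ids']
    full_ids = tokenizer(context + target, return_tensors='pt')['input_ids']

    labels = full_ids.clone()
    labels[:, :context_ids.size(1)] = -100  # exclude context positions from the loss

    with torch.no_grad():
        loss = model(full_ids, labels=labels).loss  # mean NLL over target tokens
    return torch.exp(loss).item()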
class DomainAdaptivePerplexity:
    """Perplexity calculation that accounts for domain characteristics."""

    def __init__(self, domain_weights):
        self.domain_weights = domain_weights

    def calculate(self, model, tokenizer, text, domain):
        base_ppl = calculate_perplexity(model, tokenizer, text)
        # Adjust according to the domain
        adjusted_ppl = base_ppl * self.domain_weights.get(domain, 1.0)
        return adjusted_ppl
class ComprehensiveEvaluator:
    """Comprehensive evaluator that combines multiple metrics."""

    def evaluate(self, model, test_data):
        results = {}
        # 1. Perplexity (fluency)
        results['perplexity'] = self.calculate_perplexity(model, test_data)
        # 2. BLEU (translation quality)
        if 'translation' in test_data:
            results['bleu'] = self.calculate_bleu(model, test_data)
        # 3. ROUGE (summarization quality)
        if 'summarization' in test_data:
            results['rouge'] = self.calculate_rouge(model, test_data)
        # 4. Accuracy (classification tasks)
        if 'classification' in test_data:
            results['accuracy'] = self.calculate_accuracy(model, test_data)
        # 5. Human evaluation
        results['human_eval'] = self.collect_human_scores(model, test_data)
        return results
# Correlations reported in the literature between perplexity and other metrics
correlations = {
    'PPL vs BLEU': -0.65,      # negative correlation, but far from perfect
    'PPL vs Human': -0.72,     # moderately correlated with human ratings
    'PPL vs Accuracy': -0.45   # weaker correlation
}
def quick_perplexity_test(model_name, test_file):
    """Quickly measure a model's perplexity on a text file."""
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()

    # Read the test data
    with open(test_file, 'r') as f:
        texts = f.readlines()

    # Compute perplexity
    total_loss = 0
    total_tokens = 0
    for text in texts:
        inputs = tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs['input_ids'])
            loss = outputs.loss
        total_loss += loss.item() * inputs['input_ids'].size(1)
        total_tokens += inputs['input_ids'].size(1)

    perplexity = np.exp(total_loss / total_tokens)
    print(f"Model: {model_name}")
    print(f"Perplexity: {perplexity:.2f}")
    return perplexity
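An example call; 'gpt2' and 'test.txt' are placeholders for your own checkpoint and evaluation file:

ppl = quick_perplexity_test('gpt2', 'test.txt')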