概念定义

自监督学习(Self-supervised Learning)是一种机器学习范式,通过从数据本身构建监督信号来学习表示,无需人工标注。它是大语言模型预训练的核心技术,通过设计预测任务让模型学习语言的结构和语义。

详细解释

什么是自监督学习?

自监督学习巧妙地将无监督问题转化为监督问题:通过遮挡或变换输入数据的一部分,让模型预测被遮挡的内容,从而创造出“免费”的监督信号。
核心思想
  • 自创标签:从数据本身生成训练标签
  • 预测任务:设计合理的预测目标
  • 表示学习:学习数据的内在表示
  • 无需标注:充分利用大规模无标签数据
与其他学习范式的关系
  • 监督学习:使用人工标注的标签
  • 无监督学习:发现数据中的模式和结构
  • 自监督学习:创造监督信号,介于两者之间
形象比喻
自监督学习就像一个学生通过做填空题来学习:
  • 传统监督学习:老师直接告诉答案(人工标注)
  • 无监督学习:学生自己观察和总结规律(聚类、降维)
  • 自监督学习:学生做填空题,通过预测空白处来学习(掩码预测)
例如:看到“天空是_色的”,通过预测“蓝”字来学习颜色和物体的关系。
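
下面用一小段示意代码把“自创标签”的过程具体化:从一句原始文本中随机挖掉一个字,被挖掉的字本身就充当监督信号(按字挖空只是演示用的简化假设,实际预训练通常在子词/token级别进行):

import random

def make_cloze_example(text):
    """从原始文本构造一个填空式训练样本(字级别的简化示意)"""
    chars = list(text)
    pos = random.randrange(len(chars))   # 随机选择一个位置挖空
    label = chars[pos]                   # 被挖掉的字就是“免费”的标签
    chars[pos] = "_"
    return "".join(chars), label, pos

masked_text, label, pos = make_cloze_example("天空是蓝色的")
print(masked_text, "->", label)  # 可能输出: 天空是_色的 -> 蓝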

发展历程

早期探索(2010-2017)
  • 词向量模型(Word2Vec、GloVe)
  • 基于上下文的表示学习
  • 简单的预测任务
深度学习时代(2018-2020)
  • BERT的掩码语言建模(MLM)
  • GPT的因果语言建模(CLM)
  • 视觉领域的对比学习
大模型时代(2020至今)
  • 更大规模的预训练
  • 多模态自监督学习
  • 推理能力的涌现

技术原理

掩码语言建模(MLM)

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertForMaskedLM
import random

class MaskedLanguageModeling:
    """掩码语言建模实现"""
    
    def __init__(self, tokenizer, mask_prob=0.15):
        self.tokenizer = tokenizer
        self.mask_prob = mask_prob
        self.mask_token_id = tokenizer.mask_token_id
        
    def create_masked_input(self, text):
        """创建掩码输入"""
        # 分词
        tokens = self.tokenizer.tokenize(text)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        
        # 创建标签(原始token_ids)
        labels = token_ids.copy()
        
        # 掩码策略
        masked_indices = []
        for i, token_id in enumerate(token_ids):
            if random.random() < self.mask_prob:
                masked_indices.append(i)
                
                # BERT掩码策略
                rand = random.random()
                if rand < 0.8:
                    # 80%: 替换为[MASK]
                    token_ids[i] = self.mask_token_id
                elif rand < 0.9:
                    # 10%: 替换为随机token
                    token_ids[i] = random.randint(0, self.tokenizer.vocab_size - 1)
                # 10%: 保持不变
        
        # 创建attention mask
        attention_mask = [1] * len(token_ids)
        
        # 只计算被掩码位置的损失
        loss_mask = [1 if i in masked_indices else 0 for i in range(len(labels))]
        
        return {
            'input_ids': torch.tensor(token_ids),
            'attention_mask': torch.tensor(attention_mask),
            'labels': torch.tensor(labels),
            'loss_mask': torch.tensor(loss_mask)
        }
    
    def mlm_loss(self, model, batch):
        """计算MLM损失"""
        outputs = model(
            input_ids=batch['input_ids'].unsqueeze(0),           # 增加batch维度
            attention_mask=batch['attention_mask'].unsqueeze(0),
            labels=batch['labels'].unsqueeze(0)
        )
        
        # 获取预测logits
        logits = outputs.logits
        
        # 只计算掩码位置的损失
        active_loss = batch['loss_mask'].view(-1) == 1
        active_logits = logits.view(-1, logits.shape[-1])[active_loss]
        active_labels = batch['labels'].view(-1)[active_loss]
        
        loss = F.cross_entropy(active_logits, active_labels)
        
        return loss, logits
    
    def predict_masked_tokens(self, model, text):
        """预测掩码位置的词"""
        model.eval()
        
        # 创建掩码输入
        batch = self.create_masked_input(text)
        
        with torch.no_grad():
            outputs = model(
                input_ids=batch['input_ids'].unsqueeze(0),
                attention_mask=batch['attention_mask'].unsqueeze(0)
            )
            
            predictions = torch.argmax(outputs.logits, dim=-1)
        
        # 解码预测结果
        predicted_tokens = self.tokenizer.convert_ids_to_tokens(
            predictions[0].tolist()
        )
        
        # 显示结果
        original_tokens = self.tokenizer.convert_ids_to_tokens(
            batch['labels'].tolist()
        )
        
        results = []
        for i, (orig, pred) in enumerate(zip(original_tokens, predicted_tokens)):
            if batch['loss_mask'][i] == 1:
                results.append({
                    'position': i,
                    'original': orig,
                    'predicted': pred,
                    'correct': orig == pred
                })
        
        return results

# 使用示例
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForMaskedLM.from_pretrained('bert-base-chinese')

mlm = MaskedLanguageModeling(tokenizer)
text = "今天天气很好,适合出门散步。"

# 创建掩码输入并预测
results = mlm.predict_masked_tokens(model, text)
print("掩码预测结果:")
for result in results:
    print(f"位置{result['position']}: {result['original']} -> {result['predicted']} ({'✓' if result['correct'] else '✗'})")

因果语言建模(CLM)

class CausalLanguageModeling:
    """因果语言建模(GPT风格)"""
    
    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def prepare_clm_data(self, texts):
        """准备CLM训练数据"""
        dataset = []
        
        for text in texts:
            # 分词
            tokens = self.tokenizer.encode(text, truncation=True, max_length=self.max_length)
            
            if len(tokens) > 1:
                # Hugging Face的因果语言模型会在内部把labels右移一位再计算损失,
                # 因此这里labels直接等于input_ids,避免手动偏移造成二次错位
                input_ids = tokens
                labels = tokens.copy()
                
                dataset.append({
                    'input_ids': torch.tensor(input_ids),
                    'labels': torch.tensor(labels)
                })
        
        return dataset
    
    def clm_loss(self, model, batch):
        """计算CLM损失"""
        outputs = model(
            input_ids=batch['input_ids'].unsqueeze(0),  # 增加batch维度
            labels=batch['labels'].unsqueeze(0)
        )
        
        return outputs.loss
    
    def generate_text(self, model, prompt, max_length=100, temperature=0.8):
        """生成文本"""
        model.eval()
        
        # 编码提示
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
        
        # 生成
        with torch.no_grad():
            for _ in range(max_length):
                outputs = model(input_ids=input_ids)
                logits = outputs.logits[0, -1, :] / temperature
                
                # 采样下一个token
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                
                # 添加到序列
                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
                
                # 检查结束符
                if next_token.item() == self.tokenizer.eos_token_id:
                    break
        
        # 解码生成的文本
        generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return generated_text[len(prompt):]  # 返回新生成的部分
    
    def perplexity_evaluation(self, model, test_texts):
        """困惑度评估"""
        model.eval()
        total_loss = 0
        total_tokens = 0
        
        with torch.no_grad():
            for text in test_texts:
                tokens = self.tokenizer.encode(text, return_tensors='pt')
                
                if tokens.shape[1] > 1:
                    # labels与input_ids相同,模型内部会自动右移以计算下一词预测损失
                    outputs = model(input_ids=tokens, labels=tokens)
                    loss = outputs.loss
                    
                    n_predicted = tokens.shape[1] - 1  # 实际参与预测的token数
                    total_loss += loss.item() * n_predicted
                    total_tokens += n_predicted
        
        avg_loss = total_loss / total_tokens
        perplexity = torch.exp(torch.tensor(avg_loss))
        
        return perplexity.item()

# 训练示例
def train_clm_model(model, clm, dataset, num_epochs=3):
    """训练CLM模型(clm为CausalLanguageModeling实例)"""
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    
    for epoch in range(num_epochs):
        total_loss = 0
        
        for batch in dataset:
            # 前向传播
            loss = clm.clm_loss(model, batch)
            total_loss += loss.item()
            
            # 反向传播
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        
        avg_loss = total_loss / len(dataset)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
    
    return model
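
一个假设性的调用示意(以英文GPT-2和一条示例文本演示完整调用链,模型名称与超参数仅作演示):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2')

clm = CausalLanguageModeling(gpt2_tokenizer)
train_texts = ["Self-supervised learning creates labels from the data itself."]
clm_dataset = clm.prepare_clm_data(train_texts)

trained_model = train_clm_model(gpt2_model, clm, clm_dataset, num_epochs=1)
print(clm.generate_text(trained_model, "Self-supervised learning", max_length=20))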

对比学习

class ContrastiveLearning:
    """对比学习实现"""
    
    def __init__(self, encoder, temperature=0.07):
        self.encoder = encoder
        self.temperature = temperature
        
    def create_positive_pairs(self, texts):
        """创建正样本对"""
        pairs = []
        
        for text in texts:
            # 数据增强策略
            augmented_versions = [
                self.random_mask_augment(text),
                self.synonym_replacement(text),
                self.random_deletion(text),
                self.back_translation(text)
            ]
            
            # 创建正样本对
            for i in range(len(augmented_versions)):
                for j in range(i+1, len(augmented_versions)):
                    pairs.append((augmented_versions[i], augmented_versions[j]))
        
        return pairs
    
    def random_mask_augment(self, text, mask_ratio=0.1):
        """随机掩码增强"""
        words = text.split()
        n_mask = max(1, int(len(words) * mask_ratio))
        
        mask_indices = random.sample(range(len(words)), n_mask)
        for idx in mask_indices:
            words[idx] = '[MASK]'
        
        return ' '.join(words)
    
    def synonym_replacement(self, text):
        """同义词替换"""
        # 简化实现,实际可以使用WordNet或同义词词典
        synonyms = {
            '好': ['棒', '不错', '优秀'],
            '坏': ['差', '糟糕', '不好'],
            '大': ['巨大', '庞大', '巨型'],
            '小': ['微小', '细小', '迷你']
        }
        
        words = text.split()
        for i, word in enumerate(words):
            if word in synonyms and random.random() < 0.3:
                words[i] = random.choice(synonyms[word])
        
        return ' '.join(words)
    
    def random_deletion(self, text, delete_prob=0.1):
        """随机删除"""
        words = text.split()
        if len(words) == 1:
            return text
        
        new_words = []
        for word in words:
            if random.random() > delete_prob:
                new_words.append(word)
        
        if len(new_words) == 0:
            return random.choice(words)
        
        return ' '.join(new_words)
    
    def back_translation(self, text):
        """回译(简化版)"""
        # 实际实现中会使用机器翻译API
        # 这里只是示意
        return text + " [back-translated]"
    
    def simcse_loss(self, embeddings1, embeddings2):
        """SimCSE对比损失"""
        # 归一化嵌入
        embeddings1 = F.normalize(embeddings1, dim=1)
        embeddings2 = F.normalize(embeddings2, dim=1)
        
        batch_size = embeddings1.shape[0]
        
        # 拼接嵌入
        embeddings = torch.cat([embeddings1, embeddings2], dim=0)
        
        # 计算相似度矩阵
        sim_matrix = torch.matmul(embeddings, embeddings.T) / self.temperature
        
        # 创建标签(对角线为正样本)
        labels = torch.arange(batch_size * 2)
        labels = torch.cat([labels[batch_size:], labels[:batch_size]])
        
        # 掩码自己
        mask = torch.eye(batch_size * 2, dtype=torch.bool)
        sim_matrix = sim_matrix.masked_fill(mask, -1e9)
        
        # 计算损失
        loss = F.cross_entropy(sim_matrix, labels)
        
        return loss
    
    def momentum_update(self, online_encoder, momentum_encoder, tau=0.999):
        """动量更新(MoCo风格)"""
        for online_param, momentum_param in zip(
            online_encoder.parameters(),
            momentum_encoder.parameters()
        ):
            momentum_param.data = tau * momentum_param.data + (1 - tau) * online_param.data

# 训练对比学习模型
def train_contrastive_model(model, texts, num_epochs=10):
    """训练对比学习模型"""
    contrastive = ContrastiveLearning(model.encoder)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
    
    for epoch in range(num_epochs):
        # 创建正样本对
        positive_pairs = contrastive.create_positive_pairs(texts)
        
        total_loss = 0
        # 说明:此处逐对计算仅作示意;实际SimCSE训练应将多对样本组成批次,
        # 以便利用批内其他样本作为负例(in-batch negatives)
        for text1, text2 in positive_pairs:
            # 编码
            emb1 = model.encode(text1)
            emb2 = model.encode(text2)
            
            # 对比损失
            loss = contrastive.simcse_loss(emb1.unsqueeze(0), emb2.unsqueeze(0))
            total_loss += loss.item()
            
            # 反向传播
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        
        avg_loss = total_loss / len(positive_pairs)
        print(f"Epoch {epoch+1}, Contrastive Loss: {avg_loss:.4f}")
    
    return model
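
用随机向量快速验证 simcse_loss 的输入输出形状(此处 encoder 仅作占位,不参与计算):

cl = ContrastiveLearning(encoder=None)
emb_a = torch.randn(8, 256)   # 同一批8个句子第一组增广视图的嵌入
emb_b = torch.randn(8, 256)   # 第二组增广视图的嵌入
loss = cl.simcse_loss(emb_a, emb_b)
print(loss.item())  # 标量损失,随机输入下通常接近ln(15)≈2.7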

2024年最新技术

推理模型与思维链

class ReasoningModel:
    """推理模型(2024年OpenAI o1风格)"""
    
    def __init__(self, base_model, reasoning_steps=10):
        # 假定base_model提供文本级接口:generate(prompt, max_length, temperature, stop)
        self.base_model = base_model
        self.reasoning_steps = reasoning_steps
        
    def generate_reasoning_chain(self, query):
        """生成推理链"""
        reasoning_chain = []
        current_context = f"问题: {query}\n让我一步步思考:\n"
        
        for step in range(self.reasoning_steps):
            # 生成推理步骤
            step_prompt = f"{current_context}\n第{step+1}步思考:"
            
            step_reasoning = self.base_model.generate(
                step_prompt,
                max_length=100,
                temperature=0.7,
                stop=["\n第", "\n最终答案"]
            )
            
            reasoning_chain.append(step_reasoning)
            current_context += f"\n第{step+1}步思考: {step_reasoning}"
            
            # 检查是否完成推理
            if self.is_reasoning_complete(step_reasoning, query):
                break
        
        return reasoning_chain
    
    def is_reasoning_complete(self, reasoning, original_query):
        """判断推理是否完成"""
        completion_indicators = [
            "因此", "所以", "综上所述", "答案是",
            "可以得出", "结论是"
        ]
        
        return any(indicator in reasoning for indicator in completion_indicators)
    
    def generate_final_answer(self, query, reasoning_chain):
        """基于推理链生成最终答案"""
        full_reasoning = "\n".join([
            f"第{i+1}步思考: {step}" 
            for i, step in enumerate(reasoning_chain)
        ])
        
        final_prompt = f"""
        问题: {query}
        
        推理过程:
        {full_reasoning}
        
        基于以上推理,最终答案是:
        """
        
        final_answer = self.base_model.generate(
            final_prompt,
            max_length=200,
            temperature=0.3
        )
        
        return final_answer
    
    def self_supervised_reasoning_training(self, reasoning_dataset):
        """自监督推理训练"""
        optimizer = torch.optim.AdamW(self.base_model.parameters(), lr=1e-5)
        
        for epoch in range(5):
            total_loss = 0
            
            for item in reasoning_dataset:
                query = item['query']
                correct_reasoning = item['reasoning']
                correct_answer = item['answer']
                
                # 生成推理链
                generated_reasoning = self.generate_reasoning_chain(query)
                generated_answer = self.generate_final_answer(query, generated_reasoning)
                
                # 计算推理损失(与正确推理的相似度)
                # 注:reasoning_similarity_loss与answer_loss为示意性的占位方法,
                # 需根据具体模型(例如基于嵌入相似度或token级交叉熵)自行实现
                reasoning_loss = self.reasoning_similarity_loss(
                    generated_reasoning, correct_reasoning
                )
                
                # 计算答案损失
                answer_loss = self.answer_loss(generated_answer, correct_answer)
                
                # 总损失
                total_loss_item = reasoning_loss + answer_loss
                total_loss += total_loss_item.item()
                
                # 反向传播
                total_loss_item.backward()
                optimizer.step()
                optimizer.zero_grad()
            
            avg_loss = total_loss / len(reasoning_dataset)
            print(f"Reasoning Epoch {epoch+1}, Loss: {avg_loss:.4f}")
        
        return self.base_model

多模态自监督学习

class MultimodalSelfSupervised:
    """多模态自监督学习"""
    
    def __init__(self, vision_encoder, text_encoder):
        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder
        
    def clip_style_training(self, image_text_pairs):
        """CLIP风格的对比训练"""
        optimizer = torch.optim.AdamW(
            list(self.vision_encoder.parameters()) + 
            list(self.text_encoder.parameters()),
            lr=3e-4
        )
        
        for epoch in range(10):
            total_loss = 0
            
            for batch in image_text_pairs:
                images = batch['images']
                texts = batch['texts']
                
                # 编码
                image_features = self.vision_encoder(images)
                text_features = self.text_encoder(texts)
                
                # 归一化
                image_features = F.normalize(image_features, dim=1)
                text_features = F.normalize(text_features, dim=1)
                
                # 计算相似度矩阵(为简化示意,省略了CLIP中可学习的温度参数logit_scale)
                logits_per_image = torch.matmul(image_features, text_features.T)
                logits_per_text = logits_per_image.T
                
                # 创建标签(对角线为正样本)
                batch_size = images.shape[0]
                labels = torch.arange(batch_size)
                
                # 对比损失
                loss_img = F.cross_entropy(logits_per_image, labels)
                loss_txt = F.cross_entropy(logits_per_text, labels)
                loss = (loss_img + loss_txt) / 2
                
                total_loss += loss.item()
                
                # 反向传播
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            
            avg_loss = total_loss / len(image_text_pairs)
            print(f"Multimodal Epoch {epoch+1}, Loss: {avg_loss:.4f}")
        
        return self.vision_encoder, self.text_encoder
    
    def masked_image_modeling(self, images, mask_ratio=0.75):
        """掩码图像建模(MAE风格)"""
        batch_size, channels, height, width = images.shape
        
        # 将图像分割为patches
        patch_size = 16
        num_patches = (height // patch_size) * (width // patch_size)
        
        # 随机掩码
        len_keep = int(num_patches * (1 - mask_ratio))
        noise = torch.rand(batch_size, num_patches)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        
        # 保留可见patches
        ids_keep = ids_shuffle[:, :len_keep]
        
        # 提取patches
        patches = self.extract_patches(images, patch_size)
        x_masked = torch.gather(
            patches, 
            dim=1,
            index=ids_keep.unsqueeze(-1).repeat(1, 1, patches.shape[-1])
        )
        
        # 编码
        encoded = self.vision_encoder.encode_patches(x_masked)
        
        # 解码(重构所有patches)
        decoded = self.vision_encoder.decode_patches(encoded, ids_restore)
        
        # 只计算掩码位置的损失
        mask = torch.ones([batch_size, num_patches])
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        
        loss = (decoded - patches) ** 2
        loss = loss.mean(dim=-1)  # 每个patch的平均损失
        loss = (loss * mask).sum() / mask.sum()  # 只计算掩码位置
        
        return loss
    
    def extract_patches(self, images, patch_size):
        """提取图像patches"""
        batch_size, channels, height, width = images.shape
        num_patches_h = height // patch_size
        num_patches_w = width // patch_size
        
        patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        patches = patches.contiguous().view(
            batch_size, channels, 
            num_patches_h * num_patches_w, 
            patch_size, patch_size
        )
        patches = patches.permute(0, 2, 1, 3, 4).contiguous()
        patches = patches.view(batch_size, -1, channels * patch_size * patch_size)
        
        return patches
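
用随机张量快速验证 extract_patches 的形状变换(vision_encoder、text_encoder在此处仅作占位):

images = torch.randn(2, 3, 224, 224)          # 2张224x224的RGB图像
mm = MultimodalSelfSupervised(vision_encoder=None, text_encoder=None)
patches = mm.extract_patches(images, patch_size=16)
print(patches.shape)  # torch.Size([2, 196, 768]):196个patch,每个展平为768维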

高效预训练技术

class EfficientPretraining:
    """高效预训练技术(2024年优化)"""
    
    def __init__(self, model):
        self.model = model
        
    def gradient_checkpointing(self, forward_fn, *args):
        """梯度检查点(节省显存)
        
        说明:此处手写autograd.Function仅为演示原理,
        实际训练建议直接使用torch.utils.checkpoint.checkpoint
        """
        class CheckpointFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, forward_fn, *args):
                ctx.forward_fn = forward_fn
                ctx.args = args
                
                # 不保存中间激活,只保存输入
                with torch.no_grad():
                    outputs = forward_fn(*args)
                
                return outputs
            
            @staticmethod
            def backward(ctx, *grad_outputs):
                # 重新计算前向传播以获得中间激活
                args = ctx.args
                for arg in args:
                    arg.requires_grad_(True)
                
                with torch.enable_grad():
                    outputs = ctx.forward_fn(*args)
                
                # 计算梯度
                torch.autograd.backward(outputs, grad_outputs)
                
                return (None,) + tuple(arg.grad for arg in args)
        
        return CheckpointFunction.apply(forward_fn, *args)
    
    def mixed_precision_training(self, model, data_loader):
        """混合精度训练"""
        from torch.cuda.amp import autocast, GradScaler
        
        scaler = GradScaler()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        
        for epoch in range(10):
            for batch in data_loader:
                optimizer.zero_grad()
                
                # 自动混合精度
                with autocast():
                    outputs = model(**batch)
                    loss = outputs.loss
                
                # 缩放损失并反向传播
                scaler.scale(loss).backward()
                
                # 更新参数
                scaler.step(optimizer)
                scaler.update()
        
        return model
    
    def dynamic_batching(self, dataset, max_tokens=8192):
        """动态批处理(按token数量)"""
        # 按序列长度排序
        sorted_dataset = sorted(dataset, key=lambda x: len(x['input_ids']))
        
        batches = []
        current_batch = []
        current_tokens = 0
        
        for item in sorted_dataset:
            item_tokens = len(item['input_ids'])
            
            # 检查是否超过限制
            if current_tokens + item_tokens > max_tokens and current_batch:
                batches.append(current_batch)
                current_batch = [item]
                current_tokens = item_tokens
            else:
                current_batch.append(item)
                current_tokens += item_tokens
        
        # 添加最后一个批次
        if current_batch:
            batches.append(current_batch)
        
        return batches
    
    def curriculum_learning(self, model, dataset, num_epochs=10):
        """课程学习(从简单到困难)"""
        # 按难度排序数据
        def difficulty_score(item):
            # 简单的难度评估:序列长度 + 词汇复杂度
            text_len = len(item['input_ids'])
            vocab_diversity = len(set(item['input_ids']))
            return text_len + vocab_diversity
        
        sorted_dataset = sorted(dataset, key=difficulty_score)
        
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
        
        for epoch in range(num_epochs):
            # 逐渐增加数据难度
            data_ratio = min(1.0, (epoch + 1) / num_epochs * 1.5)
            data_size = int(len(sorted_dataset) * data_ratio)
            current_data = sorted_dataset[:data_size]
            
            print(f"Epoch {epoch+1}: Using {data_size}/{len(sorted_dataset)} samples")
            
            total_loss = 0
            for batch in current_data:
                outputs = model(**batch)
                loss = outputs.loss
                
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                
                total_loss += loss.item()
            
            avg_loss = total_loss / len(current_data)
            print(f"Average Loss: {avg_loss:.4f}")
        
        return model
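
一个假设性的动态批处理调用示意(用随机长度的伪样本演示按token预算分批,model参数仅作占位):

dummy_dataset = [
    {'input_ids': list(range(random.randint(10, 200)))}
    for _ in range(100)
]
ep = EfficientPretraining(model=None)
token_batches = ep.dynamic_batching(dummy_dataset, max_tokens=512)
print(f"共划分出{len(token_batches)}个批次,首批包含{len(token_batches[0])}条样本")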

自监督学习最佳实践
  1. 预训练任务设计:选择合适的预测任务,平衡难度和信息量
  2. 数据增强策略:使用多样化的增强方法提高模型鲁棒性
  3. 负采样技巧:在对比学习中使用困难负样本(示意代码见本列表之后)
  4. 学习率调度:使用warmup和衰减策略
  5. 批次大小优化:根据任务特点调整批次大小
  6. 多任务学习:结合多种自监督任务提升效果
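
针对第3条“负采样技巧”,下面给出带显式困难负样本的InfoNCE损失的简化示意(困难负样本的挖掘方式与编码器本身均为假设,此处不展开):

def info_nce_with_hard_negatives(anchor, positive, hard_negatives, temperature=0.07):
    """带困难负样本的InfoNCE损失(简化示意)
    anchor:         [batch, dim]    锚点句向量
    positive:       [batch, dim]    对应正样本的句向量
    hard_negatives: [batch, k, dim] 为每个锚点挖掘的k个困难负样本
    """
    anchor = F.normalize(anchor, dim=-1)
    positive = F.normalize(positive, dim=-1)
    hard_negatives = F.normalize(hard_negatives, dim=-1)
    
    # 正样本相似度: [batch, 1]
    pos_sim = (anchor * positive).sum(dim=-1, keepdim=True)
    
    # 困难负样本相似度: [batch, k]
    neg_sim = torch.einsum('bd,bkd->bk', anchor, hard_negatives)
    
    # 拼接为[batch, 1+k]的logits做softmax分类,正样本固定在第0列
    logits = torch.cat([pos_sim, neg_sim], dim=1) / temperature
    labels = torch.zeros(anchor.size(0), dtype=torch.long)
    
    return F.cross_entropy(logits, labels)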

实际应用

预训练流程

class PretrainingPipeline:
    """完整预训练流程"""
    
    def __init__(self, model_config):
        self.model = self.build_model(model_config)
        
    def build_model(self, config):
        """构建模型"""
        if config['model_type'] == 'bert':
            from transformers import BertConfig, BertForMaskedLM
            model_config = BertConfig(**config['model_params'])
            model = BertForMaskedLM(model_config)
        elif config['model_type'] == 'gpt':
            from transformers import GPT2Config, GPT2LMHeadModel
            model_config = GPT2Config(**config['model_params'])
            model = GPT2LMHeadModel(model_config)
        
        return model
    
    def prepare_data(self, raw_texts, tokenizer):
        """数据预处理"""
        processed_data = []
        
        for text in raw_texts:
            # 清理文本
            cleaned_text = self.clean_text(text)
            if not cleaned_text:
                continue
            
            # 分词(return_tensors='pt'自带batch维度,便于逐条送入模型)
            encoding = tokenizer(
                cleaned_text,
                truncation=True,
                max_length=512,
                padding='max_length',
                return_tensors='pt'
            )
            
            # 简化处理:labels直接复制input_ids,padding位置置为-100不计损失
            # (CLM由模型内部右移;MLM实际训练时应先用前文的MaskedLanguageModeling做掩码)
            labels = encoding['input_ids'].clone()
            labels[encoding['attention_mask'] == 0] = -100
            
            processed_data.append({
                'input_ids': encoding['input_ids'],
                'attention_mask': encoding['attention_mask'],
                'labels': labels
            })
        
        return processed_data
    
    def clean_text(self, text):
        """文本清理"""
        import re
        
        # 移除特殊字符
        text = re.sub(r'[^\w\s\u4e00-\u9fff]', '', text)
        
        # 移除多余空格
        text = re.sub(r'\s+', ' ', text).strip()
        
        # 过滤过短文本(中文没有空格分词,这里直接按字符数过滤)
        if len(text) < 5:
            return ""
        
        return text
    
    def train(self, dataset, num_epochs=3, save_steps=1000):
        """训练模型"""
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=5e-5,
            weight_decay=0.01
        )
        
        # 学习率调度
        from transformers import get_linear_schedule_with_warmup
        
        total_steps = len(dataset) * num_epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(0.1 * total_steps),
            num_training_steps=total_steps
        )
        
        step = 0
        for epoch in range(num_epochs):
            total_loss = 0
            
            for batch in dataset:
                # 前向传播
                outputs = self.model(**batch)
                loss = outputs.loss
                
                # 反向传播
                loss.backward()
                
                # 梯度裁剪
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                
                # 更新参数
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                
                total_loss += loss.item()
                step += 1
                
                # 保存检查点
                if step % save_steps == 0:
                    self.save_checkpoint(step)
                
                # 记录日志
                if step % 100 == 0:
                    print(f"Step {step}, Loss: {loss.item():.4f}")
            
            avg_loss = total_loss / len(dataset)
            print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
        
        return self.model
    
    def save_checkpoint(self, step):
        """保存检查点"""
        checkpoint = {
            'step': step,
            'model_state_dict': self.model.state_dict(),
            'config': self.model.config
        }
        
        torch.save(checkpoint, f'checkpoint_step_{step}.pt')
        print(f"Checkpoint saved at step {step}")
    
    def evaluate(self, eval_dataset):
        """评估模型"""
        self.model.eval()
        total_loss = 0
        
        with torch.no_grad():
            for batch in eval_dataset:
                outputs = self.model(**batch)
                loss = outputs.loss
                total_loss += loss.item()
        
        avg_loss = total_loss / len(eval_dataset)
        perplexity = torch.exp(torch.tensor(avg_loss))
        
        print(f"Evaluation - Loss: {avg_loss:.4f}, Perplexity: {perplexity:.2f}")
        
        return {'loss': avg_loss, 'perplexity': perplexity.item()}

# 使用示例
config = {
    'model_type': 'bert',
    'model_params': {
        'vocab_size': 30522,
        'hidden_size': 768,
        'num_hidden_layers': 12,
        'num_attention_heads': 12,
        'max_position_embeddings': 512
    }
}

pipeline = PretrainingPipeline(config)

# 准备数据
raw_texts = ["这是第一段文本...", "这是第二段文本..."]
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
dataset = pipeline.prepare_data(raw_texts, tokenizer)

# 训练
trained_model = pipeline.train(dataset, num_epochs=3)

下游任务适配

class DownstreamAdaptation:
    """下游任务适配"""
    
    def __init__(self, pretrained_model):
        # 假定pretrained_model为编码器(如BertModel),前向输出包含last_hidden_state
        self.pretrained_model = pretrained_model
        
    def text_classification_adaptation(self, num_classes=2):
        """文本分类适配"""
        # 冻结预训练参数
        for param in self.pretrained_model.parameters():
            param.requires_grad = False
        
        # 添加分类头
        classifier = nn.Sequential(
            nn.Linear(self.pretrained_model.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )
        
        # 组合模型
        class ClassificationModel(nn.Module):
            def __init__(self, pretrained, classifier):
                super().__init__()
                self.pretrained = pretrained
                self.classifier = classifier
                
            def forward(self, input_ids, attention_mask=None):
                # 获取预训练表示
                outputs = self.pretrained(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                
                # 使用CLS token的表示
                cls_representation = outputs.last_hidden_state[:, 0, :]
                
                # 分类
                logits = self.classifier(cls_representation)
                
                return logits
        
        return ClassificationModel(self.pretrained_model, classifier)
    
    def named_entity_recognition_adaptation(self, num_labels=9):
        """命名实体识别适配"""
        # 只微调最后几层
        layers_to_finetune = 2
        
        for i, layer in enumerate(self.pretrained_model.encoder.layer):
            if i < len(self.pretrained_model.encoder.layer) - layers_to_finetune:
                for param in layer.parameters():
                    param.requires_grad = False
        
        # 添加NER头
        ner_head = nn.Linear(
            self.pretrained_model.config.hidden_size,
            num_labels
        )
        
        class NERModel(nn.Module):
            def __init__(self, pretrained, ner_head):
                super().__init__()
                self.pretrained = pretrained
                self.ner_head = ner_head
                
            def forward(self, input_ids, attention_mask=None, labels=None):
                outputs = self.pretrained(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                
                # 获取每个token的表示
                sequence_output = outputs.last_hidden_state
                
                # NER预测
                logits = self.ner_head(sequence_output)
                
                loss = None
                if labels is not None:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
                
                return {'loss': loss, 'logits': logits}
        
        return NERModel(self.pretrained_model, ner_head)
    
    def question_answering_adaptation(self):
        """问答任务适配"""
        # 添加QA头(开始和结束位置预测)
        qa_head = nn.Linear(
            self.pretrained_model.config.hidden_size,
            2  # start和end位置
        )
        
        class QAModel(nn.Module):
            def __init__(self, pretrained, qa_head):
                super().__init__()
                self.pretrained = pretrained
                self.qa_head = qa_head
                
            def forward(self, input_ids, attention_mask=None, 
                       start_positions=None, end_positions=None):
                outputs = self.pretrained(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                
                sequence_output = outputs.last_hidden_state
                logits = self.qa_head(sequence_output)
                
                start_logits, end_logits = logits.split(1, dim=-1)
                start_logits = start_logits.squeeze(-1)
                end_logits = end_logits.squeeze(-1)
                
                total_loss = None
                if start_positions is not None and end_positions is not None:
                    loss_fct = nn.CrossEntropyLoss()
                    start_loss = loss_fct(start_logits, start_positions)
                    end_loss = loss_fct(end_logits, end_positions)
                    total_loss = (start_loss + end_loss) / 2
                
                return {
                    'loss': total_loss,
                    'start_logits': start_logits,
                    'end_logits': end_logits
                }
        
        return QAModel(self.pretrained_model, qa_head)
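
一个假设性的使用示意(以BertModel作为编码器,复用前文加载的bert-base-chinese分词器;训练数据与微调循环从略):

from transformers import BertModel

encoder = BertModel.from_pretrained('bert-base-chinese')
adapter = DownstreamAdaptation(encoder)

# 文本分类:冻结编码器,仅训练新增的分类头
cls_model = adapter.text_classification_adaptation(num_classes=3)

inputs = tokenizer("这部电影非常精彩", return_tensors='pt')
logits = cls_model(inputs['input_ids'], inputs['attention_mask'])
print(logits.shape)  # torch.Size([1, 3])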

评估与分析

class PretrainingEvaluator:
    """预训练评估器"""
    
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        
    def intrinsic_evaluation(self, test_texts):
        """内在评估(困惑度等)"""
        results = {}
        
        # 困惑度
        clm = CausalLanguageModeling(self.tokenizer)
        perplexity = clm.perplexity_evaluation(self.model, test_texts)
        results['perplexity'] = perplexity
        
        # 掩码预测准确率(如果是BERT类模型)
        if hasattr(self.model, 'cls'):
            mlm = MaskedLanguageModeling(self.tokenizer)
            mask_accuracy = self.evaluate_mask_prediction(mlm, test_texts)
            results['mask_accuracy'] = mask_accuracy
        
        return results
    
    def evaluate_mask_prediction(self, mlm, test_texts, num_samples=100):
        """评估掩码预测准确率"""
        correct = 0
        total = 0
        
        for text in test_texts[:num_samples]:
            predictions = mlm.predict_masked_tokens(self.model, text)
            
            for pred in predictions:
                total += 1
                if pred['correct']:
                    correct += 1
        
        accuracy = correct / total if total > 0 else 0
        return accuracy
    
    def probing_tasks(self, probing_datasets):
        """探测任务评估"""
        results = {}
        
        for task_name, dataset in probing_datasets.items():
            print(f"Evaluating {task_name}...")
            
            # 提取表示
            representations = []
            labels = []
            
            for item in dataset:
                text = item['text']
                label = item['label']
                
                # 获取模型表示
                inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
                
                with torch.no_grad():
                    outputs = self.model(**inputs)
                    # 使用CLS token或平均池化
                    if hasattr(outputs, 'last_hidden_state'):
                        representation = outputs.last_hidden_state[:, 0, :].squeeze()
                    else:
                        representation = outputs.hidden_states[-1][:, 0, :].squeeze()
                
                representations.append(representation.numpy())
                labels.append(label)
            
            # 训练线性探测器
            from sklearn.linear_model import LogisticRegression
            from sklearn.metrics import accuracy_score
            from sklearn.model_selection import train_test_split
            
            X_train, X_test, y_train, y_test = train_test_split(
                representations, labels, test_size=0.2, random_state=42
            )
            
            probe = LogisticRegression(max_iter=1000)
            probe.fit(X_train, y_train)
            
            y_pred = probe.predict(X_test)
            accuracy = accuracy_score(y_test, y_pred)
            
            results[task_name] = accuracy
            print(f"{task_name} accuracy: {accuracy:.3f}")
        
        return results
    
    def downstream_transfer_evaluation(self, downstream_tasks):
        """下游任务迁移评估"""
        results = {}
        
        for task_name, task_data in downstream_tasks.items():
            print(f"Evaluating transfer to {task_name}...")
            
            # 适配下游任务
            adapter = DownstreamAdaptation(self.model)
            
            if task_name == 'classification':
                adapted_model = adapter.text_classification_adaptation(
                    task_data['num_classes']
                )
            elif task_name == 'ner':
                adapted_model = adapter.named_entity_recognition_adaptation(
                    task_data['num_labels']
                )
            elif task_name == 'qa':
                adapted_model = adapter.question_answering_adaptation()
            
            # 微调和评估
            task_performance = self.finetune_and_evaluate(
                adapted_model, task_data
            )
            
            results[task_name] = task_performance
        
        return results
    
    def finetune_and_evaluate(self, model, task_data):
        """微调并评估"""
        # 简化的微调过程
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
        
        # 训练
        model.train()
        for epoch in range(3):
            for batch in task_data['train']:
                outputs = model(**batch)
                loss = outputs['loss'] if isinstance(outputs, dict) else outputs.loss
                
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
        
        # 评估
        model.eval()
        total_correct = 0
        total_samples = 0
        
        with torch.no_grad():
            for batch in task_data['test']:
                outputs = model(**batch)
                predictions = torch.argmax(outputs['logits'], dim=-1)
                
                correct = (predictions == batch['labels']).sum().item()
                total_correct += correct
                total_samples += batch['labels'].size(0)
        
        accuracy = total_correct / total_samples
        return accuracy

自监督学习注意事项
  1. 计算资源:预训练需要大量计算资源和时间
  2. 数据质量:高质量的预训练数据对模型性能至关重要
  3. 任务设计:预训练任务应与下游任务相关
  4. 过拟合风险:避免在预训练阶段过拟合
  5. 评估方法:需要多种评估方式验证预训练效果
  6. 资源管理:合理使用混合精度和梯度检查点节省显存

未来发展趋势

新兴自监督任务

class EmergingSelfSupervisedTasks:
    """新兴自监督任务"""
    
    def __init__(self):
        pass
    
    def token_deletion_prediction(self, model, text, deletion_ratio=0.1):
        """Token删除预测任务"""
        words = text.split()
        n_delete = max(1, int(len(words) * deletion_ratio))
        
        # 随机删除tokens
        delete_indices = random.sample(range(len(words)), n_delete)
        deleted_words = [words[i] for i in delete_indices]
        
        # 创建删除后的文本
        remaining_words = [word for i, word in enumerate(words) if i not in delete_indices]
        deleted_text = ' '.join(remaining_words)
        
        # 预测被删除的tokens
        prediction_prompt = f"原文本删除了一些词,请预测被删除的词:\n删除后:{deleted_text}\n被删除的词:"
        
        predicted_words = model.generate(prediction_prompt, max_length=50)
        
        return {
            'original': text,
            'deleted': deleted_text,
            'deleted_words': deleted_words,
            'predicted_words': predicted_words
        }
    
    def sentence_order_prediction(self, model, paragraphs):
        """句子顺序预测任务"""
        shuffled_tasks = []
        
        for paragraph in paragraphs:
            sentences = paragraph.split('。')
            sentences = [s.strip() + '。' for s in sentences if s.strip()]
            
            if len(sentences) > 2:
                # 打乱顺序
                original_order = list(range(len(sentences)))
                shuffled_order = original_order.copy()
                random.shuffle(shuffled_order)
                
                shuffled_sentences = [sentences[i] for i in shuffled_order]
                shuffled_text = ' '.join(shuffled_sentences)
                
                shuffled_tasks.append({
                    'shuffled_text': shuffled_text,
                    'correct_order': original_order,
                    'shuffled_order': shuffled_order
                })
        
        return shuffled_tasks
    
    def next_sentence_prediction_v2(self, model, document_pairs):
        """改进的下一句预测任务"""
        tasks = []
        
        for doc1, doc2 in document_pairs:
            sentences1 = doc1.split('。')
            sentences2 = doc2.split('。')
            
            # 正样本:连续句子
            if len(sentences1) > 1:
                pos_sent1 = sentences1[0]
                pos_sent2 = sentences1[1]
                tasks.append({
                    'sentence1': pos_sent1,
                    'sentence2': pos_sent2,
                    'label': 1  # 连续
                })
            
            # 负样本:不同文档的句子
            if sentences2:
                neg_sent1 = sentences1[0] if sentences1 else ""
                neg_sent2 = random.choice(sentences2)
                tasks.append({
                    'sentence1': neg_sent1,
                    'sentence2': neg_sent2,
                    'label': 0  # 不连续
                })
        
        return tasks
    
    def entity_linking_prediction(self, model, texts_with_entities):
        """实体链接预测任务"""
        linking_tasks = []
        
        for item in texts_with_entities:
            text = item['text']
            entities = item['entities']
            
            for entity in entities:
                # 掩码实体
                masked_text = text.replace(entity['mention'], '[ENTITY]')
                
                # 预测实体类型或描述
                linking_tasks.append({
                    'masked_text': masked_text,
                    'entity_mention': entity['mention'],
                    'entity_type': entity['type'],
                    'entity_description': entity.get('description', '')
                })
        
        return linking_tasks

可解释性增强

class InterpretableSelfSupervised:
    """可解释的自监督学习"""
    
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        
    def attention_pattern_analysis(self, text):
        """注意力模式分析"""
        inputs = self.tokenizer(text, return_tensors='pt')
        
        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)
            attentions = outputs.attentions
        
        # 分析注意力模式
        attention_patterns = []
        
        for layer_idx, layer_attention in enumerate(attentions):
            # layer_attention: [batch, heads, seq_len, seq_len]
            layer_patterns = {}
            
            # 平均所有头的注意力
            avg_attention = layer_attention.mean(dim=1).squeeze(0)  # [seq_len, seq_len]
            
            # 分析自注意力模式
            self_attention = torch.diag(avg_attention)
            cross_attention = avg_attention - torch.diag(torch.diag(avg_attention))
            
            layer_patterns['self_attention'] = self_attention.tolist()
            layer_patterns['cross_attention'] = cross_attention.tolist()
            layer_patterns['max_attention_position'] = torch.argmax(avg_attention, dim=1).tolist()
            
            attention_patterns.append(layer_patterns)
        
        return attention_patterns
    
    def gradient_based_interpretation(self, text, target_token_idx):
        """基于梯度的解释"""
        inputs = self.tokenizer(text, return_tensors='pt')
        
        # 整数型input_ids无法求梯度,改为对词嵌入(inputs_embeds)求梯度
        embedding_layer = self.model.get_input_embeddings()
        inputs_embeds = embedding_layer(inputs['input_ids'])
        inputs_embeds.retain_grad()
        
        # 前向传播(用inputs_embeds代替input_ids)
        outputs = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=inputs['attention_mask']
        )
        
        # 获取目标token的logit
        target_logit = outputs.logits[0, target_token_idx, :]
        target_score = torch.max(target_logit)
        
        # 反向传播计算梯度
        target_score.backward()
        
        # 每个token的重要性:嵌入梯度的L2范数
        input_gradients = inputs_embeds.grad.norm(dim=-1)
        
        # 计算重要性分数
        importance_scores = input_gradients.squeeze(0).tolist()
        
        # 获取tokens
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze(0).tolist())
        
        token_importance = [
            {'token': token, 'importance': score}
            for token, score in zip(tokens, importance_scores)
        ]
        
        return token_importance
    
    def probing_representations(self, texts, probing_tasks):
        """表示探测"""
        import numpy as np
        
        probing_results = {}
        
        for task_name, task_data in probing_tasks.items():
            print(f"Probing for {task_name}...")
            
            # 提取所有层的表示
            layer_representations = {i: [] for i in range(self.model.config.num_hidden_layers)}
            labels = []
            
            for text, label in zip(task_data['texts'], task_data['labels']):
                inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
                
                with torch.no_grad():
                    outputs = self.model(**inputs, output_hidden_states=True)
                    hidden_states = outputs.hidden_states
                
                # 每一层的CLS表示
                for layer_idx, layer_hidden in enumerate(hidden_states[1:]):  # 跳过输入embeddings
                    cls_repr = layer_hidden[:, 0, :].squeeze().numpy()
                    layer_representations[layer_idx].append(cls_repr)
                
                labels.append(label)
            
            # 对每一层训练分类器
            layer_performance = {}
            
            for layer_idx in range(self.model.config.num_hidden_layers):
                X = np.array(layer_representations[layer_idx])
                y = np.array(labels)
                
                # 划分训练测试集
                from sklearn.model_selection import train_test_split
                from sklearn.linear_model import LogisticRegression
                from sklearn.metrics import accuracy_score
                
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y, test_size=0.2, random_state=42
                )
                
                # 训练分类器
                clf = LogisticRegression(max_iter=1000)
                clf.fit(X_train, y_train)
                
                # 评估
                y_pred = clf.predict(X_test)
                accuracy = accuracy_score(y_test, y_pred)
                
                layer_performance[layer_idx] = accuracy
            
            probing_results[task_name] = layer_performance
        
        return probing_results

相关概念

延伸阅读

推荐资源

最后更新:2024年12月