Concept Definition

Self-attention is a mechanism that lets every position in a sequence attend directly to every other position. By scoring the similarity between query and key vectors and using those scores to weight the value vectors, it aggregates information globally and extracts features across the whole sequence.

Detailed Explanation

Self-attention is the heart of the Transformer architecture. Unlike an RNN, which must process a sequence step by step, self-attention lets the model process the entire sequence in parallel: every position can directly "see" and interact with every other position. This design not only addresses the long-range dependency problem but also greatly improves training efficiency.

In 2024, the release of Flash Attention 3 pushed self-attention efficiency further. Using asynchronous execution and FP8 low-precision support, FA3 reaches roughly 75% of theoretical peak FLOPs utilization on H100 GPUs, a qualitative leap over FA2's roughly 35% on the same hardware. Together with Flash-Decoding, which tackles low GPU utilization during small-batch inference, these advances make extremely long contexts and real-time applications practical for models such as Llama 3.

Mathematically, self-attention is a weighted-average process: each position's output is a weighted sum of the value vectors at all positions, with the weights determined by query-key dot-product similarity. This lets the model decide dynamically what to attend to, instead of relying on fixed convolution kernels or recurrent connections. From BERT's bidirectional understanding to GPT's causal generation, self-attention has shown remarkable flexibility.
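
Written out, this weighted average is the standard scaled dot-product attention (with d_k the key dimension):

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V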

How It Works

The core steps of self-attention (a minimal sketch follows this list):
  1. Linear projection: map the input to query (Q), key (K), and value (V) spaces
  2. Attention computation: dot product of Q with K, scaling, then Softmax normalization
  3. Value aggregation: weight V by the attention weights to produce the output
  4. Parallel processing: all positions are computed simultaneously, with no sequential dependency
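
The steps above can be written out in a few lines of plain PyTorch. This is a minimal single-head sketch with made-up sizes and random projection weights; it mirrors the SelfAttention module shown later:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, d_model = 2, 5, 16
x = torch.randn(batch, seq_len, d_model)

# 1. Linear projections to Q, K, V (random weights, purely for illustration)
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))
Q, K, V = x @ W_q, x @ W_k, x @ W_v

# 2. Scaled dot-product scores, normalized with Softmax over the key axis
scores = Q @ K.transpose(-2, -1) / d_model ** 0.5   # [batch, seq_len, seq_len]
weights = F.softmax(scores, dim=-1)

# 3. Aggregate the value vectors with the attention weights
output = weights @ V                                # [batch, seq_len, d_model]

# 4. No loop over positions was needed: everything is computed in parallel
print(output.shape)                                 # torch.Size([2, 5, 16])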

Practical Applications

Standard Self-Attention Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    """
    标准自注意力机制实现
    """
    def __init__(self, d_model=512, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        
        # Linear projections for Q, K, V
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        
        # Output projection
        self.out_linear = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(d_model)
    
    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        
        # 1. Linear projections produce Q, K, V
        Q = self.q_linear(x)  # [batch, seq_len, d_model]
        K = self.k_linear(x)
        V = self.v_linear(x)
        
        # 2. Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # scores: [batch, seq_len, seq_len]
        
        # 3. Apply the mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # 4. Softmax normalization
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # 5. Weighted aggregation of the value vectors
        output = torch.matmul(attn_weights, V)
        
        # 6. Output projection
        output = self.out_linear(output)
        
        return output, attn_weights
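
A quick usage sketch for the module above (batch size, sequence length, and the all-ones mask are made up for illustration):

attn = SelfAttention(d_model=512, dropout=0.1)
x = torch.randn(2, 10, 512)                       # [batch, seq_len, d_model]
padding_mask = torch.ones(2, 10, 10)              # 1 = attend, 0 = masked out
out, weights = attn(x, mask=padding_mask)
print(out.shape, weights.shape)                   # torch.Size([2, 10, 512]) torch.Size([2, 10, 10])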

Flash Attention 3 Optimized Implementation

class FlashAttention3(nn.Module):
    """
    Flash Attention 3 with FP8 and asynchronous computation
    注:这是概念示例,实际需要CUDA kernel实现
    """
    def __init__(self, d_model=512, block_size=64, use_fp8=True):
        super().__init__()
        self.d_model = d_model
        self.block_size = block_size
        self.use_fp8 = use_fp8
        
        # QKV projection (FP8-capable path)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        
        if use_fp8:
            # Hadamard matrix for incoherent processing (flattens outliers before FP8 quantization)
            self.register_buffer(
                'hadamard_matrix', 
                self.generate_hadamard_matrix(d_model)
            )
    
    def generate_hadamard_matrix(self, dim):
        """生成Hadamard矩阵用于incoherent processing"""
        def hadamard(n):
            if n == 1:
                return torch.tensor([[1.0]])
            h_n_minus_1 = hadamard(n // 2)
            top = torch.cat([h_n_minus_1, h_n_minus_1], dim=1)
            bottom = torch.cat([h_n_minus_1, -h_n_minus_1], dim=1)
            return torch.cat([top, bottom], dim=0) / math.sqrt(2)
        
        # Round dim up to the next power of two
        n = 2 ** math.ceil(math.log2(dim))
        H = hadamard(n)
        # Note: the sliced matrix is only orthogonal when dim is itself a power of two
        return H[:dim, :dim]
    
    def apply_incoherent_processing(self, x, signs):
        """
        Flatten outliers before low-precision quantization by applying a random
        sign flip followed by a Hadamard rotation. Because the rotation is
        orthogonal and the same signs are shared by Q and K, Q @ K^T is unchanged.
        """
        return torch.matmul(x * signs, self.hadamard_matrix)
    
    def flash_attention_kernel(self, Q, K, V, block_size):
        """
        Flash Attention core loop (tiled online softmax). Numerically equivalent
        to standard attention, but the full seq_len x seq_len score matrix is
        never materialized.
        """
        batch_size, seq_len, d_model = Q.shape
        
        # Output accumulator, softmax normalizer, and running row-wise max
        O = torch.zeros_like(Q)
        L = torch.zeros(batch_size, seq_len, device=Q.device)
        M = torch.full((batch_size, seq_len), -float('inf'), device=Q.device)
        
        # Tile over query blocks
        for i in range(0, seq_len, block_size):
            Q_block = Q[:, i:i+block_size]
            m_i = M[:, i:i+block_size].clone()
            l_i = L[:, i:i+block_size].clone()
            o_i = torch.zeros_like(Q_block)
            
            # Tile over key/value blocks
            for j in range(0, seq_len, block_size):
                K_block = K[:, j:j+block_size]
                V_block = V[:, j:j+block_size]
                
                # Attention scores for this tile (an FP8 matmul would go here if enabled)
                S_block = torch.matmul(Q_block, K_block.transpose(-2, -1))
                S_block = S_block / math.sqrt(d_model)
                
                # Update the running max for numerical stability
                m_new = torch.maximum(m_i, S_block.max(dim=-1)[0])
                
                # Rescale previous partial results, then add this tile's contribution
                alpha = torch.exp(m_i - m_new)
                P_block = torch.exp(S_block - m_new.unsqueeze(-1))
                l_i = l_i * alpha + P_block.sum(dim=-1)
                o_i = o_i * alpha.unsqueeze(-1) + torch.matmul(P_block, V_block)
                m_i = m_new
            
            # Normalize once per query block and store the running statistics
            O[:, i:i+block_size] = o_i / l_i.unsqueeze(-1)
            M[:, i:i+block_size] = m_i
            L[:, i:i+block_size] = l_i
        
        return O
    
    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        
        # QKV projection
        qkv = self.qkv_proj(x)
        Q, K, V = qkv.chunk(3, dim=-1)
        
        if self.use_fp8:
            # Incoherent processing before FP8 quantization. Q and K must share
            # the same random signs so that (Q D H)(K D H)^T == Q K^T and no
            # inverse transform is needed on the attention output.
            signs = (torch.randint(0, 2, (d_model,), device=x.device) * 2 - 1).float()
            Q = self.apply_incoherent_processing(Q, signs)
            K = self.apply_incoherent_processing(K, signs)
            
            # The actual FP8 quantization would happen here (requires hardware support)
            # Q_fp8 = quantize_to_fp8(Q)
            # K_fp8 = quantize_to_fp8(K)
        
        # Flash Attention computation
        output = self.flash_attention_kernel(Q, K, V, self.block_size)
        
        return output
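
Because the tiled kernel only reorganizes the softmax computation, its output should match standard attention. A small sanity check (FP8 preprocessing disabled so the comparison is exact up to floating-point error):

torch.manual_seed(0)
fa = FlashAttention3(d_model=64, block_size=16, use_fp8=False)
x = torch.randn(2, 48, 64)
out_tiled = fa(x)

# Reference: standard attention on the same Q, K, V
qkv = fa.qkv_proj(x)
Q, K, V = qkv.chunk(3, dim=-1)
ref = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(64), dim=-1) @ V

print(torch.allclose(out_tiled, ref, atol=1e-5))    # expected: True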

Causal Self-Attention (GPT-style)

class CausalSelfAttention(nn.Module):
    """
    带因果掩码的自注意力(用于自回归生成)
    """
    def __init__(self, d_model=512, max_seq_len=1024):
        super().__init__()
        self.d_model = d_model
        
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        
        # Causal mask: position i may only attend to positions <= i
        # (kept 3D to match the single-head [batch, seq, seq] scores below)
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(max_seq_len, max_seq_len))
                 .view(1, max_seq_len, max_seq_len)
        )
    
    def forward(self, x, use_cache=False, past_kv=None):
        batch_size, seq_len, _ = x.shape
        
        # QKV projection
        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.d_model, dim=-1)
        
        # KV cache handling (inference optimization)
        if use_cache and past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)
        
        # Attention scores over all (cached + new) positions
        attn_scores = torch.matmul(q, k.transpose(-2, -1))
        attn_scores = attn_scores / math.sqrt(self.d_model)
        
        # Apply the causal mask. With a KV cache the current queries correspond
        # to the *last* seq_len rows of the full mask, so offset by the cache length.
        total_len = k.size(1)
        causal_mask = self.causal_mask[:, total_len - seq_len:total_len, :total_len]
        attn_scores = attn_scores.masked_fill(
            causal_mask == 0,
            float('-inf')
        )
        
        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        output = self.out_proj(output)
        
        if use_cache:
            return output, (k, v)
        return output
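
A sketch of incremental decoding with the KV cache (random tensors stand in for real token embeddings):

decoder = CausalSelfAttention(d_model=512, max_seq_len=1024)

# Prefill: run the prompt once and keep its keys/values
prompt = torch.randn(1, 8, 512)
out, past_kv = decoder(prompt, use_cache=True)

# Decode: each new token attends to itself plus all cached positions
for _ in range(4):
    new_token = torch.randn(1, 1, 512)
    out, past_kv = decoder(new_token, use_cache=True, past_kv=past_kv)
    print(past_kv[0].shape)   # cached keys grow: [1, 9, 512], [1, 10, 512], ...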

Long-Sequence Optimization

class LongContextAttention(nn.Module):
    """
    支持超长序列的注意力机制
    """
    def __init__(
        self, 
        d_model=512,
        max_seq_len=1_000_000,  # million-token scale
        use_flash_attn=True,
        use_sliding_window=False,
        window_size=4096
    ):
        super().__init__()
        self.d_model = d_model
        self.use_flash_attn = use_flash_attn
        self.use_sliding_window = use_sliding_window
        self.window_size = window_size
        
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        
        # RoPE positional encoding for long sequences
        # (RotaryEmbedding is not defined here; a minimal sketch follows this class)
        self.rope = RotaryEmbedding(d_model, max_seq_len)
    
    def sliding_window_attention(self, q, k, v, window_size):
        """
        滑动窗口注意力(局部注意力)
        """
        batch_size, seq_len, d_model = q.shape
        output = torch.zeros_like(q)
        
        for i in range(seq_len):
            # Window boundaries around position i
            start = max(0, i - window_size // 2)
            end = min(seq_len, i + window_size // 2 + 1)
            
            # Local attention for this position
            q_i = q[:, i:i+1]
            k_window = k[:, start:end]
            v_window = v[:, start:end]
            
            scores = torch.matmul(q_i, k_window.transpose(-2, -1))
            scores = scores / math.sqrt(d_model)
            
            attn_weights = F.softmax(scores, dim=-1)
            output[:, i] = torch.matmul(attn_weights, v_window).squeeze(1)
        
        return output
    
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        # QKV projection
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)
        
        # Apply RoPE to queries and keys
        q, k = self.rope(q, k)
        
        if self.use_sliding_window and seq_len > self.window_size:
            # Sliding-window attention for very long sequences
            output = self.sliding_window_attention(q, k, v, self.window_size)
        elif self.use_flash_attn:
            # Flash Attention kernel from the flash-attn package; it expects
            # [batch, seq, heads, head_dim] half-precision CUDA tensors, so the
            # model dimension is treated here as a single head
            from flash_attn import flash_attn_func
            output = flash_attn_func(
                q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2)
            ).squeeze(2)
        else:
            # Standard attention
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_model)
            attn_weights = F.softmax(scores, dim=-1)
            output = torch.matmul(attn_weights, v)
        
        return self.out_proj(output)
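
LongContextAttention assumes a RotaryEmbedding module that is not defined in this snippet. Below is a minimal rotate-half RoPE sketch matching the q, k = self.rope(q, k) call above (no caching and no frequency scaling; max_seq_len is kept only for signature compatibility):

class RotaryEmbedding(nn.Module):
    """
    Minimal rotary position embedding (RoPE): rotates channel pairs by a
    position-dependent angle so that Q-K dot products encode relative position.
    """
    def __init__(self, dim, max_seq_len=1_000_000, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q, k):
        seq_len = q.shape[1]
        positions = torch.arange(seq_len, device=q.device).float()
        freqs = torch.outer(positions, self.inv_freq)    # [seq_len, dim/2]
        emb = torch.cat([freqs, freqs], dim=-1)          # [seq_len, dim]
        cos, sin = emb.cos().unsqueeze(0), emb.sin().unsqueeze(0)
        q = q * cos + self.rotate_half(q) * sin
        k = k * cos + self.rotate_half(k) * sin
        return q, k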

Performance Comparison

Implementation      | Memory complexity | Time complexity | Max sequence length | GPU utilization
Standard attention  | O(N²)             | O(N²)           | ~4K                 | 20-30%
Flash Attention     | O(N)              | O(N²)           | ~32K                | 40-50%
Flash Attention 2   | O(N)              | O(N²)           | ~128K               | 35% (H100)
Flash Attention 3   | O(N)              | O(N²)           | ~1M                 | 75% (H100)
Sliding window      | O(N×W)            | O(N×W)          | unbounded           | 30-40%
