Develop a deep understanding of the attention mechanism, the core technique behind the Transformer, and master the principles and applications of self-attention and multi-head attention.
Attention(Q, K, V) = softmax(QK^T / √d_k)V
import torch
import torch.nn.functional as F

class SelfAttention(torch.nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Three linear projection matrices
        self.W_q = torch.nn.Linear(embed_dim, embed_dim)
        self.W_k = torch.nn.Linear(embed_dim, embed_dim)
        self.W_v = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # Compute Q, K, V
        Q = self.W_q(x)  # (batch, seq_len, embed_dim)
        K = self.W_k(x)
        V = self.W_v(x)

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embed_dim ** 0.5)

        # Apply softmax to obtain attention weights
        attn_weights = F.softmax(scores, dim=-1)

        # Weighted sum of the values
        output = torch.matmul(attn_weights, V)

        return output, attn_weights
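To see the module in action, here is a minimal usage sketch with random inputs; the batch size, sequence length, and embedding dimension below are arbitrary values chosen for illustration:

# Usage sketch (dimensions chosen arbitrarily for illustration)
x = torch.randn(2, 10, 64)           # (batch=2, seq_len=10, embed_dim=64)
self_attn = SelfAttention(embed_dim=64)
output, attn_weights = self_attn(x)
print(output.shape)                  # torch.Size([2, 10, 64])
print(attn_weights.shape)            # torch.Size([2, 10, 10]), each row sums to 1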
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Make sure embed_dim is divisible by num_heads
        assert embed_dim % num_heads == 0

        self.W_q = torch.nn.Linear(embed_dim, embed_dim)
        self.W_k = torch.nn.Linear(embed_dim, embed_dim)
        self.W_v = torch.nn.Linear(embed_dim, embed_dim)
        self.W_o = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Project and split into multiple heads
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Transpose for batched attention computation
        Q = Q.transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Compute attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)

        # Concatenate the heads back together
        context = context.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.embed_dim
        )

        # Final output projection
        output = self.W_o(context)

        return output
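A quick sanity check of the multi-head module, again with arbitrary illustrative dimensions (embed_dim must be divisible by num_heads):

# Usage sketch: 8 heads over a 64-dimensional embedding (64 / 8 = 8 dims per head)
x = torch.randn(2, 10, 64)
mha = MultiHeadAttention(embed_dim=64, num_heads=8)
out = mha(x)
print(out.shape)  # torch.Size([2, 10, 64])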
# Flash Attention reduces the memory complexity of attention from O(N²) to O(N)

seq_length = 64 * 1024  # e.g. a 64K-token context window

# Traditional attention materializes the full N x N attention matrix
memory_traditional = seq_length ** 2  # grows quadratically

# Flash Attention processes the matrix in tiles and never stores it in full
memory_flash = seq_length             # grows linearly

# For a 64K context window (rough figures):
#   traditional: ~16 GB of memory
#   Flash:       ~256 MB of memory
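As a rough back-of-envelope check of the quadratic term (my own estimate, assuming a single fp32 attention matrix and ignoring batch and head counts):

# Memory needed to materialize one full 64K x 64K attention matrix in fp32
seq_length = 64 * 1024
bytes_fp32 = 4
attn_matrix_bytes = seq_length ** 2 * bytes_fp32
print(f"{attn_matrix_bytes / 1024**3:.1f} GB")  # 16.0 GB, matching the ~16 GB figure above
# Flash Attention never materializes this matrix; its extra state grows only with seq_length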
class DCMHAttention(torch.nn.Module):
    """Dynamically composable multi-head attention (simplified example)"""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.mha = MultiHeadAttention(embed_dim, num_heads)
        self.dynamic_weight = torch.nn.Parameter(
            torch.ones(num_heads, num_heads) / num_heads
        )

    def apply_dynamic_weights(self, attn_output):
        # Split the output into head-sized chunks and re-mix them with a
        # learned head-to-head weight matrix (purely illustrative)
        batch_size, seq_len, _ = attn_output.shape
        heads = attn_output.view(batch_size, seq_len, self.num_heads, self.head_dim)
        mixed = torch.einsum('bsnd,nm->bsmd', heads, self.dynamic_weight)
        return mixed.reshape(batch_size, seq_len, -1)

    def forward(self, x):
        # Standard multi-head attention
        attn_output = self.mha(x)

        # Dynamically compose the outputs of the different heads
        # (the real implementation is more involved; this is only a sketch)
        combined_output = self.apply_dynamic_weights(attn_output)

        return combined_output
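Usage mirrors the plain multi-head module; the dimensions below are again arbitrary:

x = torch.randn(2, 10, 64)
dcmha = DCMHAttention(embed_dim=64, num_heads=8)
out = dcmha(x)
print(out.shape)  # torch.Size([2, 10, 64])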
class TranslationAttention:
    """Illustrates the role of attention in machine translation"""
    def visualize_attention(self, source_text, target_text, attention_weights):
        """Visualize the attention between source-language and target-language tokens"""
        import matplotlib.pyplot as plt
        import seaborn as sns

        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(
            attention_weights,
            xticklabels=source_text.split(),
            yticklabels=target_text.split(),
            cmap='Blues',
            ax=ax
        )
        ax.set_xlabel('Source language')
        ax.set_ylabel('Target language')
        ax.set_title('Translation attention weights')
        return fig
def create_causal_mask(seq_len):
    """Create a causal attention mask that blocks access to future positions"""
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask == 0  # True = may attend, False = masked out

# Example
seq_len = 5
mask = create_causal_mask(seq_len)
print(mask)
# tensor([[ True, False, False, False, False],
#         [ True,  True, False, False, False],
#         [ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])
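The boolean mask plugs straight into the MultiHeadAttention module defined earlier, since its forward pass masks out positions where mask == 0; the (seq_len, seq_len) mask broadcasts over the batch and head dimensions:

# Causal self-attention with the modules defined above (illustrative dimensions)
x = torch.randn(2, 5, 64)
mha = MultiHeadAttention(embed_dim=64, num_heads=8)
causal_mask = create_causal_mask(seq_len=5)  # (5, 5), broadcast over batch and heads
out = mha(x, mask=causal_mask)
print(out.shape)  # torch.Size([2, 5, 64])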
class SlidingWindowAttention:
    """Sliding-window attention for processing very long texts"""
    def __init__(self, window_size=512, stride=256):
        self.window_size = window_size
        self.stride = stride

    def process_long_text(self, text_embedding, model):
        """Process a long text window by window"""
        total_len = text_embedding.shape[1]
        outputs = []

        for start in range(0, total_len - self.window_size + 1, self.stride):
            end = start + self.window_size
            window = text_embedding[:, start:end, :]

            # Process the current window
            window_output = model(window)
            outputs.append(window_output)

        # Merge the window outputs (overlapping regions need to be handled)
        return self.merge_windows(outputs)

    def merge_windows(self, outputs):
        # Average overlapping positions of consecutive windows (one simple strategy)
        batch_size, _, dim = outputs[0].shape
        total_len = self.stride * (len(outputs) - 1) + self.window_size
        merged = torch.zeros(batch_size, total_len, dim, device=outputs[0].device)
        counts = torch.zeros(total_len, device=outputs[0].device)
        for i, window_output in enumerate(outputs):
            start = i * self.stride
            end = start + self.window_size
            merged[:, start:end, :] += window_output
            counts[start:end] += 1
        return merged / counts.view(1, -1, 1)
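A usage sketch, reusing MultiHeadAttention as the per-window model; the sequence length and window sizes are small arbitrary values chosen for illustration:

long_input = torch.randn(1, 1024, 64)            # pretend this is a very long document
window_model = MultiHeadAttention(embed_dim=64, num_heads=8)
sliding = SlidingWindowAttention(window_size=256, stride=128)
out = sliding.process_long_text(long_input, window_model)
print(out.shape)  # torch.Size([1, 1024, 64]); trailing tokens without a full window are dropped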
import math

# Standard attention
def standard_attention(Q, K, V):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)

# Flash Attention (pseudocode)
def flash_attention(Q, K, V):
    # Computes attention in tiles so the full attention matrix is never stored
    # Requires dedicated CUDA kernels, e.g. from the flash-attn package:
    #   from flash_attn import flash_attn_func
    return flash_attn_func(Q, K, V)
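If you cannot install the dedicated kernels, PyTorch 2.x ships `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to a FlashAttention-style fused kernel on supported GPUs; a minimal sketch with illustrative shapes:

# PyTorch's built-in fused attention (may use a FlashAttention backend when available)
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1024, 64)   # (batch, num_heads, seq_len, head_dim)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 1024, 64])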
class SparseAttention:
    """Sparse attention for reducing computational complexity"""
    def __init__(self, sparsity_pattern='local'):
        self.pattern = sparsity_pattern

    def create_sparse_mask(self, seq_len):
        if self.pattern == 'local':
            # Each position only attends to a local window around it
            window_size = 128
            mask = torch.zeros(seq_len, seq_len)
            for i in range(seq_len):
                start = max(0, i - window_size // 2)
                end = min(seq_len, i + window_size // 2)
                mask[i, start:end] = 1
            return mask
        raise ValueError(f"Unsupported sparsity pattern: {self.pattern}")
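The resulting 0/1 mask can be fed to the MultiHeadAttention module above in the same way as the causal mask:

sparse = SparseAttention(sparsity_pattern='local')
sparse_mask = sparse.create_sparse_mask(seq_len=256)  # (256, 256), 1 = attend, 0 = blocked
x = torch.randn(2, 256, 64)
mha = MultiHeadAttention(embed_dim=64, num_heads=8)
out = mha(x, mask=sparse_mask)
print(out.shape)  # torch.Size([2, 256, 64])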
class CrossAttention(torch.nn.Module):
    """Cross-attention in the decoder"""
    def __init__(self, embed_dim):
        super().__init__()
        self.W_q = torch.nn.Linear(embed_dim, embed_dim)
        self.W_k = torch.nn.Linear(embed_dim, embed_dim)
        self.W_v = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, decoder_input, encoder_output):
        # Q comes from the decoder; K and V come from the encoder
        Q = self.W_q(decoder_input)
        K = self.W_k(encoder_output)
        V = self.W_v(encoder_output)

        # Scaled dot-product attention over the encoder states
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (Q.size(-1) ** 0.5)
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, V)
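Because the queries and the keys/values come from different sequences, the source and target lengths may differ; a short usage sketch:

encoder_output = torch.randn(2, 12, 64)  # 12 source tokens
decoder_input = torch.randn(2, 7, 64)    # 7 target tokens generated so far
cross_attn = CrossAttention(embed_dim=64)
out = cross_attn(decoder_input, encoder_output)
print(out.shape)  # torch.Size([2, 7, 64]), one context vector per target position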
class RelativePositionAttention(torch.nn.Module):
    """Relative-position attention bias, as used by T5 and similar models"""
    def __init__(self, num_heads, max_distance=128):
        super().__init__()
        self.max_distance = max_distance
        # One learned bias per (relative distance, head) pair
        self.rel_pos_bias = torch.nn.Embedding(
            2 * max_distance + 1, num_heads
        )

    def get_relative_position(self, seq_len):
        """Compute the matrix of relative positions"""
        positions = torch.arange(seq_len)
        rel_pos = positions[:, None] - positions[None, :]

        # Clip to the maximum distance
        rel_pos = rel_pos.clamp(-self.max_distance, self.max_distance)

        # Shift to non-negative indices
        rel_pos = rel_pos + self.max_distance

        return rel_pos
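To use the bias, look up the embedding for each relative position and add it to the attention scores before the softmax; a sketch with small illustrative sizes:

rel_attn = RelativePositionAttention(num_heads=8, max_distance=128)
rel_pos = rel_attn.get_relative_position(seq_len=5)  # (5, 5) indices in [0, 2 * max_distance]
bias = rel_attn.rel_pos_bias(rel_pos)                # (5, 5, 8), one bias per head
bias = bias.permute(2, 0, 1).unsqueeze(0)            # (1, 8, 5, 5), broadcastable over batch
scores = torch.randn(2, 8, 5, 5)                     # e.g. raw attention scores
weights = F.softmax(scores + bias, dim=-1)           # bias shifts attention by relative distance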