Documentation Index
Fetch the complete documentation index at: https://docs.apiyi.com/llms.txt
Use this file to discover all available pages before exploring further.
概念定义
注意力机制(Attention Mechanism)是一种让模型动态聚焦于输入序列中相关部分的技术,通过计算注意力权重来决定不同位置信息的重要程度,是现代深度学习特别是Transformer架构的核心组件。
详细解释
什么是注意力机制?
注意力机制的灵感来自人类的视觉注意力——我们在观察场景时会选择性地关注某些区域而忽略其他部分。在深度学习中,这一机制使模型能够动态地为输入的不同部分分配不同的”注意力”权重。
核心思想
- 选择性聚焦:确定在特定上下文中哪些元素最重要
- 动态权重:根据当前任务自适应调整关注点
- 全局视野:可以直接建立长距离依赖关系
- 并行计算:摆脱了RNN的顺序限制
发展历程
- 2014年:首次应用于机器翻译(Bahdanau注意力)
- 2017年:Transformer提出自注意力机制
- 2022年:Flash Attention优化计算效率
- 2024年:Flash Attention 3和DCFormer等新进展
形象比喻想象你在阅读一篇文章:
- 传统RNN:像逐字阅读,容易忘记开头内容
- 注意力机制:像快速浏览全文,同时关注多个重要部分
- 多头注意力:像多个专家同时阅读,各自关注不同方面
注意力机制让模型拥有了”一目十行”的能力,可以同时理解文本的全局关系。
数学原理
注意力计算公式
Attention(Q, K, V) = softmax(QK^T / √d_k)V
其中:
- Q(Query):查询向量,代表当前关注的位置
- K(Key):键向量,代表被比较的位置
- V(Value):值向量,代表实际信息内容
- d_k:键向量的维度,用于缩放防止梯度消失
核心类型
自注意力(Self-Attention)
自注意力让序列中的每个位置都能关注到序列中的所有其他位置,建立全局依赖关系。
import torch
import torch.nn.functional as F
class SelfAttention(torch.nn.Module):
def __init__(self, embed_dim):
super().__init__()
self.embed_dim = embed_dim
# 三个线性变换矩阵
self.W_q = torch.nn.Linear(embed_dim, embed_dim)
self.W_k = torch.nn.Linear(embed_dim, embed_dim)
self.W_v = torch.nn.Linear(embed_dim, embed_dim)
def forward(self, x):
batch_size, seq_len, _ = x.shape
# 计算Q、K、V
Q = self.W_q(x) # (batch, seq_len, embed_dim)
K = self.W_k(x)
V = self.W_v(x)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embed_dim ** 0.5)
# 应用softmax获得注意力权重
attn_weights = F.softmax(scores, dim=-1)
# 加权求和
output = torch.matmul(attn_weights, V)
return output, attn_weights
多头注意力(Multi-Head Attention)
多头注意力通过并行运行多个注意力头,让模型能够从不同角度理解输入信息。
class MultiHeadAttention(torch.nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# 确保embed_dim可以被num_heads整除
assert embed_dim % num_heads == 0
self.W_q = torch.nn.Linear(embed_dim, embed_dim)
self.W_k = torch.nn.Linear(embed_dim, embed_dim)
self.W_v = torch.nn.Linear(embed_dim, embed_dim)
self.W_o = torch.nn.Linear(embed_dim, embed_dim)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# 线性变换并分割成多头
Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
# 转置以便于批量计算
Q = Q.transpose(1, 2) # (batch, num_heads, seq_len, head_dim)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
if mask is not None:
scores.masked_fill_(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
context = torch.matmul(attn_weights, V)
# 合并多头
context = context.transpose(1, 2).contiguous().view(
batch_size, seq_len, self.embed_dim
)
# 最终线性变换
output = self.W_o(context)
return output
多头注意力的优势
- 多角度理解:不同的头可以关注不同类型的信息(语法、语义、位置等)
- 并行计算:多个头可以并行处理,提高效率
- 表达能力强:相比单头注意力有更强的建模能力
- 稳定性好:某个头失效不会严重影响整体性能
最新进展(2024)
Flash Attention 3
2024年发布的Flash Attention 3专门针对NVIDIA H100 GPU优化,带来了革命性的性能提升:
关键创新
- 硬件特定优化:充分利用Hopper架构的异步特性
- 操作重叠:计算和数据移动并行进行
- FP8精度支持:使用Hadamard变换处理异常值
- 性能提升:达到230 TFLOPs/s,是Flash Attention的2倍
内存优化
# Flash Attention将内存复杂度从O(N²)降到O(N)
# 传统注意力
memory_traditional = seq_length ** 2 # 二次方增长
# Flash Attention
memory_flash = seq_length # 线性增长
# 对于64K上下文窗口
# 传统:~16GB内存
# Flash:~256MB内存
动态组合多头注意力(DCMHA)是2024年的重要突破:
核心改进
- 动态组合:注意力头之间可以动态交互
- 参数效率:不增加参数量的情况下提升性能
- 即插即用:可直接替换标准MHA模块
- 性能提升:计算性能提升高达2倍
class DCMHAttention(torch.nn.Module):
"""动态组合多头注意力(简化示例)"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.mha = MultiHeadAttention(embed_dim, num_heads)
self.dynamic_weight = torch.nn.Parameter(
torch.ones(num_heads, num_heads) / num_heads
)
def forward(self, x):
# 标准多头注意力
attn_output = self.mha(x)
# 动态组合不同头的输出
# 实际实现更复杂,这里仅作示意
combined_output = self.apply_dynamic_weights(attn_output)
return combined_output
实际应用
机器翻译中的注意力
class TranslationAttention:
"""展示注意力在翻译中的作用"""
def visualize_attention(self, source_text, target_text, attention_weights):
"""可视化源语言和目标语言之间的注意力关系"""
import matplotlib.pyplot as plt
import seaborn as sns
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
attention_weights,
xticklabels=source_text.split(),
yticklabels=target_text.split(),
cmap='Blues',
ax=ax
)
ax.set_xlabel('源语言')
ax.set_ylabel('目标语言')
ax.set_title('翻译注意力权重可视化')
return fig
文本生成中的因果注意力
def create_causal_mask(seq_len):
"""创建因果注意力掩码,防止看到未来信息"""
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
return mask == 0 # True表示可以关注,False表示屏蔽
# 使用示例
seq_len = 5
mask = create_causal_mask(seq_len)
print(mask)
# tensor([[ True, False, False, False, False],
# [ True, True, False, False, False],
# [ True, True, True, False, False],
# [ True, True, True, True, False],
# [ True, True, True, True, True]])
长文本处理优化
class SlidingWindowAttention:
"""滑动窗口注意力,用于处理超长文本"""
def __init__(self, window_size=512, stride=256):
self.window_size = window_size
self.stride = stride
def process_long_text(self, text_embedding, model):
"""分窗口处理长文本"""
total_len = text_embedding.shape[1]
outputs = []
for start in range(0, total_len - self.window_size + 1, self.stride):
end = start + self.window_size
window = text_embedding[:, start:end, :]
# 处理当前窗口
window_output = model(window)
outputs.append(window_output)
# 合并窗口结果(这里需要处理重叠部分)
return self.merge_windows(outputs)
性能优化技巧
注意力计算优化
1. 使用Flash Attention
# 标准注意力
def standard_attention(Q, K, V):
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
weights = F.softmax(scores, dim=-1)
return torch.matmul(weights, V)
# Flash Attention (伪代码)
def flash_attention(Q, K, V):
# 分块计算,避免存储完整的注意力矩阵
# 实际使用需要专门的CUDA kernel
return flash_attn_func(Q, K, V)
2. 稀疏注意力模式
class SparseAttention:
"""稀疏注意力,减少计算复杂度"""
def __init__(self, sparsity_pattern='local'):
self.pattern = sparsity_pattern
def create_sparse_mask(self, seq_len):
if self.pattern == 'local':
# 只关注局部窗口
window_size = 128
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
start = max(0, i - window_size // 2)
end = min(seq_len, i + window_size // 2)
mask[i, start:end] = 1
return mask
注意事项
- 内存消耗:标准注意力的内存需求是O(N²),长序列需要特别注意
- 数值稳定性:使用缩放因子√d_k防止softmax饱和
- 位置信息:自注意力本身不包含位置信息,需要额外的位置编码
- 计算精度:FP16/BF16训练时要注意数值溢出问题
变体与扩展
交叉注意力(Cross-Attention)
用于编码器-解码器架构,让解码器关注编码器的输出:
class CrossAttention(torch.nn.Module):
"""解码器中的交叉注意力"""
def forward(self, decoder_input, encoder_output):
# Q来自解码器,K和V来自编码器
Q = self.W_q(decoder_input)
K = self.W_k(encoder_output)
V = self.W_v(encoder_output)
return self.attention(Q, K, V)
相对位置注意力
考虑相对位置信息的注意力机制:
class RelativePositionAttention:
"""T5等模型使用的相对位置注意力"""
def __init__(self, max_distance=128):
self.max_distance = max_distance
self.rel_pos_bias = torch.nn.Embedding(
2 * max_distance + 1,
num_heads
)
def get_relative_position(self, seq_len):
"""计算相对位置矩阵"""
positions = torch.arange(seq_len)
rel_pos = positions[:, None] - positions[None, :]
# 裁剪到最大距离
rel_pos = rel_pos.clamp(-self.max_distance, self.max_distance)
# 转换为正数索引
rel_pos = rel_pos + self.max_distance
return rel_pos
相关概念
延伸阅读