手撕代码
记录一些大模型算法中常常会考察的手撕代码。
MHA
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
"""
多头注意力模块(无dropout,支持mask)
"""
def __init__(self, d_model, num_heads):
"""
初始化参数和线性层
Args:
d_model: 模型总特征维度
num_heads: 注意力头数
"""
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
# 每个头的维度,必须能整除
self.head_dim = d_model // num_heads
# 四个线性层:分别用于生成Q、K、V和最终输出
self.Wq = nn.Linear(d_model, d_model)
self.Wk = nn.Linear(d_model, d_model)
self.Wv = nn.Linear(d_model, d_model)
self.Wo = nn.Linear(d_model, d_model)
def forward(self, query, key, value, mask=None):
"""
前向传播
Args:
query: 查询张量 [batch, q_len, d_model]
key: 键张量 [batch, k_len, d_model]
value: 值张量 [batch, k_len, d_model] (通常key和value同源)
mask: 注意力掩码 [batch, 1, q_len, k_len] 或广播兼容形状,为0的位置被屏蔽
Returns:
output: [batch, q_len, d_model]
"""
batch = query.size(0)
# ========== 1. 线性投影 ==========
# 将输入映射到 d_model 维空间(未分头)
Q = self.Wq(query) # [batch, q_len, d_model]
K = self.Wk(key) # [batch, k_len, d_model]
V = self.Wv(value) # [batch, k_len, d_model]
# ========== 2. 分头(reshape + transpose) ==========
# view: [batch, seq_len, d_model] -> [batch, seq_len, num_heads, head_dim]
# transpose(1,2): -> [batch, num_heads, seq_len, head_dim]
# 作用:将“头”的维度提前,让每个头可以独立并行计算注意力
Q = Q.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)
# 现在形状: [batch, num_heads, seq_len, head_dim]
# ========== 3. 缩放点积注意力 ==========
# 计算 Q·K^T / sqrt(d_k)
# torch.matmul: 对最后两维做矩阵乘法
# Q: [batch, heads, q_len, head_dim]
# K.transpose(-2,-1): [batch, heads, head_dim, k_len]
# scores: [batch, heads, q_len, k_len]
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
# 应用掩码:将 mask==0 的位置设为极小值(-1e9),softmax后概率接近0
if mask is not None:
# mask 形状通常为 [batch, 1, q_len, k_len],会自动广播到 heads 维度
scores = scores.masked_fill(mask == 0, -1e9)
# softmax 在最后一个维度(k_len)上归一化,得到注意力权重
attn = F.softmax(scores, dim=-1) # [batch, heads, q_len, k_len]
# 加权求和:attn · V
# attn: [batch, heads, q_len, k_len]
# V: [batch, heads, k_len, head_dim]
# context: [batch, heads, q_len, head_dim]
context = torch.matmul(attn, V)
# ========== 4. 合并多头 ==========
# 先转置回 [batch, q_len, heads, head_dim]
# contiguous() 保证内存连续,之后 view 才能正确工作
# view 合并最后两维: [batch, q_len, heads*head_dim] = [batch, q_len, d_model]
context = context.transpose(1, 2)
context = context.contiguous()
context = context.view(batch, -1, self.d_model)
# ========== 5. 输出投影 ==========
# 通过线性层融合多头信息,得到最终输出
output = self.Wo(context) # [batch, q_len, d_model]
return output
SFT Loss
import torch
import torch.nn as nn
import torch.nn.functional as F
class SFTLoss(nn.Module):
"""
监督微调损失模块
只计算 response 部分(即模型需要生成的答案部分)的损失,
忽略 prompt 部分(用户输入的问题部分)。
"""
def forward(self, logits, labels, prompt_lengths):
"""
计算 SFT 损失
Args:
logits: 模型输出的原始分数(未经过 softmax)
形状: [batch_size, seq_len, vocab_size]
- batch_size: 一次处理多少个独立样本
- seq_len: 整个序列长度(prompt + response)
- vocab_size: 词表大小
labels: 真实的 token 索引
形状: [batch_size, seq_len]
每个位置的数值范围 0 到 vocab_size-1
prompt_lengths: 每个样本中 prompt 部分的长度
形状: [batch_size],例如 [3, 5] 表示第一个样本前3个token是prompt
Returns:
loss: 标量(0维张量),平均交叉熵损失
"""
# ========== 1. 复制 labels,避免修改原始张量 ==========
# labels.clone() 创建一个深拷贝,后续的修改不会影响原 labels
masked_labels = labels.clone() # 形状: [batch_size, seq_len]
# ========== 2. 构造 mask 矩阵,标记哪些位置属于 prompt ==========
batch_size, seq_len = labels.shape # 例如 batch_size=2, seq_len=10
# torch.arange(seq_len) 生成一个一维张量: [0, 1, 2, ..., seq_len-1]
# 例如 seq_len=10 → [0,1,2,3,4,5,6,7,8,9]
# 然后 .to(labels.device) 确保张量在同一个设备上(CPU或GPU)
indices = torch.arange(seq_len, device=labels.device) # 形状: [seq_len]
# prompt_lengths 形状: [batch_size],例如 [3, 5]
# unsqueeze(1) 在维度1插入一个维度,形状变为 [batch_size, 1]
# 然后和 indices [seq_len] 进行广播比较:
# 对于每个 batch 样本 i,比较 indices 中的每个位置 j 是否小于 prompt_lengths[i]
# 结果是一个布尔张量 mask,形状 [batch_size, seq_len]
# 例如:prompt_lengths[0]=3,则 mask[0,0]=True, mask[0,1]=True, mask[0,2]=True, 之后都是 False
mask = indices < prompt_lengths.unsqueeze(1) # 形状: [batch_size, seq_len]
# 说明:< 运算符会逐元素比较,返回布尔值(True/False)
# ========== 3. 将 prompt 位置的 label 设为 -100 ==========
# PyTorch 的交叉熵损失函数 F.cross_entropy 有一个参数 ignore_index
# 当 label 等于 ignore_index 时,该位置的损失会被忽略(不计入总损失)
# 这里我们将所有属于 prompt 的 token 标签设为 -100(默认的 ignore_index)
# masked_labels[mask] 表示从 masked_labels 中选择所有 mask==True 的位置
# 将这些位置的值赋值为 -100
masked_labels[mask] = -100 # 形状不变: [batch_size, seq_len]
# 注意:response 部分的 label 保持不变(仍是真实 token 索引)
# ========== 4. 移位操作 (Shift) ==========
# 语言模型的训练方式是:给定前 t 个 token,预测第 t+1 个 token
# 因此我们需要将 logits 和 labels 进行错位:
# - logits 去掉最后一个 token(因为没有下一个 token 作为目标)
# - labels 去掉第一个 token(因为第一个 token 没有前一个输入来预测它)
# logits[:, :-1, :] 表示取所有 batch,序列维度上从索引0到倒数第二个(不包括最后一个)
# 例如 seq_len=10,则取索引 0~8,共9个位置
# .contiguous() 是让内存连续,因为切片操作可能产生不连续的张量,后续 view 需要连续内存
shift_logits = logits[:, :-1, :].contiguous() # 形状: [batch_size, seq_len-1, vocab_size]
# masked_labels[:, 1:] 表示取所有 batch,序列维度上从索引1到最后一个
# 例如 seq_len=10,取索引1~9,共9个位置
# 这样 shift_logits 的第 i 个位置对应预测 labels 的第 i+1 个位置
shift_labels = masked_labels[:, 1:].contiguous() # 形状: [batch_size, seq_len-1]
# ========== 5. 展平并计算交叉熵损失 ==========
# F.cross_entropy 要求输入形状为 (N, C) 和 (N,),其中 N 是样本总数,C 是类别数
# 所以需要将 [batch_size, seq_len-1, vocab_size] 展平成 [batch_size*(seq_len-1), vocab_size]
# 将 [batch_size, seq_len-1] 展平成 [batch_size*(seq_len-1)]
# shift_logits.size(-1) 获取最后一个维度的大小,即 vocab_size
# shift_logits.view(-1, vocab_size) 将前两个维度合并为一维,保留 vocab_size 维
flattened_logits = shift_logits.view(-1, shift_logits.size(-1)) # 形状: [N, vocab_size]
# shift_labels.view(-1) 将二维展平为一维
flattened_labels = shift_labels.view(-1) # 形状: [N]
# 调用交叉熵损失函数
# 参数 ignore_index=-100 使得所有值为 -100 的标签不参与损失计算(即被忽略)
# F.cross_entropy 内部会先对 logits 做 log_softmax,然后计算负对数似然
loss = F.cross_entropy(flattened_logits, flattened_labels, ignore_index=-100)
# 返回标量损失值
return loss