跳转至

手撕代码

记录一些大模型算法中常常会考察的手撕代码。

MHA

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """
    多头注意力模块(无dropout,支持mask)
    """
    def __init__(self, d_model, num_heads):
        """
        初始化参数和线性层

        Args:
            d_model: 模型总特征维度
            num_heads: 注意力头数
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        # 每个头的维度,必须能整除
        self.head_dim = d_model // num_heads

        # 四个线性层:分别用于生成Q、K、V和最终输出
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """
        前向传播

        Args:
            query: 查询张量 [batch, q_len, d_model]
            key:   键张量   [batch, k_len, d_model]
            value: 值张量   [batch, k_len, d_model]  (通常key和value同源)
            mask:  注意力掩码 [batch, 1, q_len, k_len] 或广播兼容形状,为0的位置被屏蔽

        Returns:
            output: [batch, q_len, d_model]
        """
        batch = query.size(0)

        # ========== 1. 线性投影 ==========
        # 将输入映射到 d_model 维空间(未分头)
        Q = self.Wq(query)   # [batch, q_len, d_model]
        K = self.Wk(key)     # [batch, k_len, d_model]
        V = self.Wv(value)   # [batch, k_len, d_model]

        # ========== 2. 分头(reshape + transpose) ==========
        # view: [batch, seq_len, d_model] -> [batch, seq_len, num_heads, head_dim]
        # transpose(1,2): -> [batch, num_heads, seq_len, head_dim]
        # 作用:将“头”的维度提前,让每个头可以独立并行计算注意力
        Q = Q.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # 现在形状: [batch, num_heads, seq_len, head_dim]

        # ========== 3. 缩放点积注意力 ==========
        # 计算 Q·K^T / sqrt(d_k)
        # torch.matmul: 对最后两维做矩阵乘法
        # Q: [batch, heads, q_len, head_dim]
        # K.transpose(-2,-1): [batch, heads, head_dim, k_len]
        # scores: [batch, heads, q_len, k_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # 应用掩码:将 mask==0 的位置设为极小值(-1e9),softmax后概率接近0
        if mask is not None:
            # mask 形状通常为 [batch, 1, q_len, k_len],会自动广播到 heads 维度
            scores = scores.masked_fill(mask == 0, -1e9)

        # softmax 在最后一个维度(k_len)上归一化,得到注意力权重
        attn = F.softmax(scores, dim=-1)   # [batch, heads, q_len, k_len]

        # 加权求和:attn · V
        # attn: [batch, heads, q_len, k_len]
        # V:    [batch, heads, k_len, head_dim]
        # context: [batch, heads, q_len, head_dim]
        context = torch.matmul(attn, V)

        # ========== 4. 合并多头 ==========
        # 先转置回 [batch, q_len, heads, head_dim]
        # contiguous() 保证内存连续,之后 view 才能正确工作
        # view 合并最后两维: [batch, q_len, heads*head_dim] = [batch, q_len, d_model]
        context = context.transpose(1, 2)
        context = context.contiguous()
        context = context.view(batch, -1, self.d_model)

        # ========== 5. 输出投影 ==========
        # 通过线性层融合多头信息,得到最终输出
        output = self.Wo(context)   # [batch, q_len, d_model]

        return output

SFT Loss

import torch
import torch.nn as nn
import torch.nn.functional as F

class SFTLoss(nn.Module):
    """
    监督微调损失模块
    只计算 response 部分(即模型需要生成的答案部分)的损失,
    忽略 prompt 部分(用户输入的问题部分)。
    """

    def forward(self, logits, labels, prompt_lengths):
        """
        计算 SFT 损失

        Args:
            logits: 模型输出的原始分数(未经过 softmax)
                    形状: [batch_size, seq_len, vocab_size]
                    - batch_size: 一次处理多少个独立样本
                    - seq_len: 整个序列长度(prompt + response)
                    - vocab_size: 词表大小
            labels: 真实的 token 索引
                    形状: [batch_size, seq_len]
                    每个位置的数值范围 0 到 vocab_size-1
            prompt_lengths: 每个样本中 prompt 部分的长度
                            形状: [batch_size],例如 [3, 5] 表示第一个样本前3个token是prompt

        Returns:
            loss: 标量(0维张量),平均交叉熵损失
        """
        # ========== 1. 复制 labels,避免修改原始张量 ==========
        # labels.clone() 创建一个深拷贝,后续的修改不会影响原 labels
        masked_labels = labels.clone()   # 形状: [batch_size, seq_len]

        # ========== 2. 构造 mask 矩阵,标记哪些位置属于 prompt ==========
        batch_size, seq_len = labels.shape  # 例如 batch_size=2, seq_len=10

        # torch.arange(seq_len) 生成一个一维张量: [0, 1, 2, ..., seq_len-1]
        # 例如 seq_len=10 → [0,1,2,3,4,5,6,7,8,9]
        # 然后 .to(labels.device) 确保张量在同一个设备上(CPU或GPU)
        indices = torch.arange(seq_len, device=labels.device)  # 形状: [seq_len]

        # prompt_lengths 形状: [batch_size],例如 [3, 5]
        # unsqueeze(1) 在维度1插入一个维度,形状变为 [batch_size, 1]
        # 然后和 indices [seq_len] 进行广播比较:
        # 对于每个 batch 样本 i,比较 indices 中的每个位置 j 是否小于 prompt_lengths[i]
        # 结果是一个布尔张量 mask,形状 [batch_size, seq_len]
        # 例如:prompt_lengths[0]=3,则 mask[0,0]=True, mask[0,1]=True, mask[0,2]=True, 之后都是 False
        mask = indices < prompt_lengths.unsqueeze(1)   # 形状: [batch_size, seq_len]
        # 说明:< 运算符会逐元素比较,返回布尔值(True/False)

        # ========== 3. 将 prompt 位置的 label 设为 -100 ==========
        # PyTorch 的交叉熵损失函数 F.cross_entropy 有一个参数 ignore_index
        # 当 label 等于 ignore_index 时,该位置的损失会被忽略(不计入总损失)
        # 这里我们将所有属于 prompt 的 token 标签设为 -100(默认的 ignore_index)
        # masked_labels[mask] 表示从 masked_labels 中选择所有 mask==True 的位置
        # 将这些位置的值赋值为 -100
        masked_labels[mask] = -100   # 形状不变: [batch_size, seq_len]
        # 注意:response 部分的 label 保持不变(仍是真实 token 索引)

        # ========== 4. 移位操作 (Shift) ==========
        # 语言模型的训练方式是:给定前 t 个 token,预测第 t+1 个 token
        # 因此我们需要将 logits 和 labels 进行错位:
        #   - logits 去掉最后一个 token(因为没有下一个 token 作为目标)
        #   - labels 去掉第一个 token(因为第一个 token 没有前一个输入来预测它)

        # logits[:, :-1, :] 表示取所有 batch,序列维度上从索引0到倒数第二个(不包括最后一个)
        # 例如 seq_len=10,则取索引 0~8,共9个位置
        # .contiguous() 是让内存连续,因为切片操作可能产生不连续的张量,后续 view 需要连续内存
        shift_logits = logits[:, :-1, :].contiguous()   # 形状: [batch_size, seq_len-1, vocab_size]

        # masked_labels[:, 1:] 表示取所有 batch,序列维度上从索引1到最后一个
        # 例如 seq_len=10,取索引1~9,共9个位置
        # 这样 shift_logits 的第 i 个位置对应预测 labels 的第 i+1 个位置
        shift_labels = masked_labels[:, 1:].contiguous()  # 形状: [batch_size, seq_len-1]

        # ========== 5. 展平并计算交叉熵损失 ==========
        # F.cross_entropy 要求输入形状为 (N, C) 和 (N,),其中 N 是样本总数,C 是类别数
        # 所以需要将 [batch_size, seq_len-1, vocab_size] 展平成 [batch_size*(seq_len-1), vocab_size]
        # 将 [batch_size, seq_len-1] 展平成 [batch_size*(seq_len-1)]

        # shift_logits.size(-1) 获取最后一个维度的大小,即 vocab_size
        # shift_logits.view(-1, vocab_size) 将前两个维度合并为一维,保留 vocab_size 维
        flattened_logits = shift_logits.view(-1, shift_logits.size(-1))  # 形状: [N, vocab_size]

        # shift_labels.view(-1) 将二维展平为一维
        flattened_labels = shift_labels.view(-1)  # 形状: [N]

        # 调用交叉熵损失函数
        # 参数 ignore_index=-100 使得所有值为 -100 的标签不参与损失计算(即被忽略)
        # F.cross_entropy 内部会先对 logits 做 log_softmax,然后计算负对数似然
        loss = F.cross_entropy(flattened_logits, flattened_labels, ignore_index=-100)

        # 返回标量损失值
        return loss