模块 07 · 约 22 分钟 · 让模型真的"学会"

训练:从乱码到莎士比亚

架构再漂亮,参数随机初始化时也只能输出乱码。这一节讲训练: 模型怎么从一堆随机数变成会写文章的 LLM。核心就一句话 ——不断让模型猜下一个 token,猜错就调整参数

① 直觉:训练 = 反复"猜下一个词,错了就调"

想象你在教一个完全不懂中文的人写中文。你打开一本《红楼梦》,让他玩这样的游戏:

  1. 你给他看前 100 个字:「却说宝玉自下学后…」
  2. 他猜第 101 个字是什么
  3. 你告诉他正确答案,并打分(猜中"思"打满分,猜中其他字按相关性打分)
  4. 根据分数调整他的"猜测策略"(脑子里的参数)
  5. 移到下一个字,重复 —— 几百万次

训练 LLM 就是一模一样的事,只是参与者是几十亿参数的神经网络,文本量是几十 TB。 让他做同一件极其简单的任务(猜下一个 token)反复几亿次,副作用是他学会了语法、事实、推理、代码 —— 所有这些复杂能力,都是"猜下一个 token 越来越准"的涌现

② 互动:眼睁睁看 loss 下降,输出从乱码变人话

下面这个演示展示一条典型的训练 loss 曲线,并给出每个 loss 水平模型大概会输出什么样的文字:

246810loss训练步数 →
step
0 / 200
loss
10.00
perplexity
~400
loss ≈ 10.0 时模型大概会输出:
今 的 是 不 我 子 上 时 出 子

左侧 step 越大,loss 越低,模型的输出也越像人话。 loss 从 10 降到 3 是"学会基本语法和词汇",从 3 降到 2 是"学会逻辑和长程依赖"。 真实大模型训练几十万到几百万 step,曲线形状基本就是这样。

点「▶ 开始训练」,眼看 loss 从 10 降到 2。注意输出的变化: loss=10 是随机乱码;loss=4 时常用字开始出现;loss=2.5 已经能写出完整句子; loss=2.0 就开始有逻辑和上下文了。GPT-3 训练完 loss 约 1.7。

③ 损失函数:交叉熵

刚才说"给猜测打分",怎么打?答案是 交叉熵。 它衡量"模型预测的分布"和"真实答案"差多少。一个直观写法:

L=1Ni=1NlogPθ(tit<i)\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log P_\theta(t_i \mid t_{<i})
↓ 对应的 Python 实现(可以直接改、直接运行)

跑下面这段感受 loss 的两种极端:

python
直接编辑这段代码即可。输入 np. 看自动提示,⌘/Ctrl + Enter运行。

④ 反向传播:「调整参数」具体怎么做

模型有了 loss,但 loss 是一个数字,怎么用它"调整"几亿个参数?答案是反向传播 + 梯度下降

核心是链式法则。loss 是一个关于所有参数 θ 的函数,对每个参数 θi\theta_i 求偏导:

Lθi\frac{\partial \mathcal{L}}{\partial \theta_i}
↓ 对应的 Python 实现(可以直接改、直接运行)

手算这些偏导对几亿参数显然不现实。所幸 PyTorch / JAX 等框架提供自动微分: 你只需要写前向计算(model(x) → logits → loss),框架在背后自动构建计算图, 调用 loss.backward() 时反向遍历这张图,链式法则会被精确地算出来。

所以训练 LLM 时你不需要自己写求导。你只需要确保前向是正确的, PyTorch 会替你处理"上亿个参数的梯度怎么算"。

⑤ 优化器:从 SGD 到 AdamW

有了梯度,最简单的更新法是 SGD(随机梯度下降):

θt+1=θtηL(θt)\theta_{t+1} = \theta_t - \eta \cdot \nabla \mathcal{L}(\theta_t)
↓ 对应的 Python 实现(可以直接改、直接运行)

所以现代 LLM 训练几乎都用 AdamW。它做了两件事:

  • 给每个参数维护自己的步长。Adam 维护"梯度滑动平均 m"(一阶矩)和"梯度平方滑动平均 v"(二阶矩), 更新时用 m/vm / \sqrt{v}。结果:梯度一直很大的参数自动减小步长,梯度小的参数加大步长。
  • 解耦权重衰减。"AdamW" 的 W 就是 Weight decay。 它把 L2 正则化从梯度里拆出来,直接乘到参数上:θ ← (1 - λ·lr)·θ - lr·m/√v。 这样正则化才不会被自适应步长扭曲。

实践中除了 lr,还要设 β1=0.9,β2=0.95\beta_1=0.9, \beta_2=0.95(GPT-3 这么设),ε=108\varepsilon = 10^{-8}λ=0.1\lambda = 0.1

⑥ 学习率:训练里最关键的超参

所有训练超参里,学习率(lr)最玄学也最重要:

  • 太大 → 一步把参数打飞,loss 直接 NaN
  • 太小 → 训练慢得绝望,几个月也降不下来
  • 正好 → 平稳下降

而且最优 lr 还会变 —— 训练初期梯度大,应该用小 lr;中期可以放大;后期接近收敛要再减小。 所以大模型训练一定要用 learning rate schedule(学习率调度):

η(t)={ηmaxttwt<twηmin+12(ηmaxηmin)(1+cos(πttwTtw))ttw\eta(t) = \begin{cases} \eta_{\max} \cdot \frac{t}{t_w} & t < t_w \\ \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})(1 + \cos(\pi \cdot \frac{t - t_w}{T - t_w})) & t \ge t_w \end{cases}
↓ 对应的 Python 实现(可以直接改、直接运行)

跑一下,打印每个 step 的 lr:

python
直接编辑这段代码即可。输入 np. 看自动提示,⌘/Ctrl + Enter运行。

⑦ 完整训练循环(伪代码)

把上面全部串起来,LLM 训练循环就这 7 步:

python
直接编辑这段代码即可。输入 np. 看自动提示,⌘/Ctrl + Enter运行。

⑧ 几个训练里常踩的坑

  • 梯度爆炸:某些 step 梯度突然变得很大,参数被打飞,loss → NaN。 解决:梯度裁剪 clip_grad_norm_(params, max_norm=1.0)
  • 梯度消失:层太深梯度传不回去,前几层学不动。 解决:残差连接 + 合适的初始化 + pre-norm(上一节讲过的 Transformer 设计就是为了这个)。
  • 学习率太高:loss 一开始下降然后突然炸到 NaN,或者长期在某个值附近震荡。 解决:用 warmup,从小到大;上限别超过 5e-4。
  • 过拟合:训练 loss 很低但验证 loss 高。 但大语言模型在大语料上反而很少过拟合 — 通常只过一遍数据就停了。

⑨ Scaling Laws:算力、数据、参数怎么分配

Kaplan et al. 2020 和 Chinchilla 论文(2022)发现一个有用的规律:在固定计算预算 C 下, 模型参数量 N 和训练数据量 D 应该同时按 √C 增长,最优分配大致是 D / N ≈ 20。

  • GPT-3:175B 参数,但只训了 300B token → D/N ≈ 1.7,欠训练
  • LLaMA-2 7B:训了 2T token → D/N ≈ 285,远远超额训练 → 用更少参数达到更强性能

这就是为什么 2023 年之后大家开始训练"更小但喂更多数据"的模型 — Scaling Law 改变了大家对"参数等于一切"的认知。

⑩ 小测验

Q1.LLM 训练的目标函数是什么?
Q2.为什么训练 Transformer 几乎都用 AdamW,而不是普通 SGD?

⑪ 延伸阅读