LLM 训练过程概述

LLM 训练过程概述

在介绍交叉熵损失之前,我们先参考 Build a Large Language Model 一书梳理一下训练 LLM 的核心过程。笔者并非这个方向的专业人士,只能尝试从自己理解的角度来尽可能用大白话阐述这个过程在做什么、为什么这么做、能达到什么效果。

为了便于理解,我们可以把整个过程想象成教一个学徒如何写文章

1. 文本生成(Text generation)

这就像让你的学徒开始写文章。一开始,它什么都不懂,只会胡乱地写一些词语。你给它一个开头,比如"从前有座山...",它可能随便接上"...山里有只大象在跳舞。"完全不合逻辑。

这是模型还没有训练好时,它根据一些输入,随机生成的一段文本。它生成的文本质量很差,毫无章法。

2. 文本评估(Text evaluation)

你现在需要一个"老师"来给这个学徒写的文章打分。你拿着学徒写的文章,和一篇标准答案(正确文章)进行对比。这个“老师”会告诉你,学徒写的文章和标准答案之间有多大的差距。这个打分的过程,就是我们本文将提到的交叉熵损失(Cross-Entropy Loss)

这个步骤是计算模型生成的文本与真实文本之间的损失值。模型会计算出它对下一个词的预测概率,并用交叉熵损失来衡量这个预测概率与真实词的“独热编码”概率有多大差距。损失值越大,说明模型预测得越差

3. 训练集和验证集的损失(Training set and validation set losses)

你的学徒现在开始正式学习了。你给他一大堆文章(训练集)让他模仿学习,然后定期拿出一小部分它没看过的文章(验证集)给他做测试。

  • 训练集损失: 衡量学徒在学习过程中,对那些它看过的文章模仿得有多像。
  • 验证集损失: 衡量学徒在面对新文章时,能不能把学到的东西举一反三,而不是只会死记硬背。

如果训练集损失一直下降,但验证集损失不降反升,那就说明学徒只会"死记硬背"了,这在机器学习里叫做过拟合(Overfitting)

4. 大语言模型训练函数(LLM training function)

这就是学徒的"大脑",也是整个学习的核心。它根据"老师"给出的分数(损失值),调整自己的"大脑结构"(模型参数/权重)。如果某篇文章写得不好,它就会"反思"自己为什么写不好,然后调整下一次的写作方式,争取写得更好。这个调整的过程叫做反向传播(Backpropagation)梯度下降(Gradient Descent)

5. 训练模型生成类似人类的文本(Train the model to generate human-like text)

这就是整个训练的目的:通过不断地重复第 1-4 步,让学徒的写作能力越来越强,最终写出来的文章,就像人类写的一样自然、流畅。

6. 文本生成策略(Text generation strategies)

学徒学得差不多了,但有时候会变得特别死板,只会把训练集里的东西原封不动地背出来。为了让它更有创意,更像人,你需要教它一些“写作技巧”。

例如: 有时候,你不要总是选那个最有可能出现的词,可以偶尔选一些稍微不那么确定,但也很合理的词。

这就是像Top-k 采样Top-p(核)采样温度(Temperature)调节等技术。这些方法会让模型在生成文本时,增加一些随机性,避免总是生成重复、机械化的内容,减少过拟合的风险。

7. 权重保存和加载(Weight saving & loading)

学徒经过了长期的学习,终于成才了!现在你需要把它的"大脑"状态(也就是模型参数)保存下来。这样,下次再用的时候,就不用从头开始教了,直接把这个保存好的"大脑"拿出来用就行。

8. 来自 OpenAI 的预训练权重(Pretrained weights from OpenAI)

这就像你不是从一个零基础的学徒开始教,而是直接找一个已经很有经验的"天才学徒"来培养。OpenAI 训练了海量的数据,已经把一个 GPT 模型训练得非常强大了。我们直接拿来用,再结合自己的任务,在它的基础上继续微调。这样不仅省时省力,还能得到一个更好的模型。

总结

GPT 的训练过程就是,让一个初出茅庐的学徒(模型)写文章,找一个老师(损失函数)给它打分,然后根据分数调整它的大脑(参数)。反复这个过程,直到它写出来的文章像人类一样。为了让它更有创意,我们还教它一些写作技巧。最后,我们会把它的"大脑"保存下来,或者直接用一个"天才学徒"的大脑,在上面继续学习。

交叉熵损失

接下来我们回到本文的主题:交叉熵损失(Cross-Entropy Loss)

交叉熵损失是一种衡量模型预测结果与真实结果之间差异的指标。在分类任务中,模型通常会输出一个预测概率分布,而真实标签也可以被看作一个“理想”的概率分布。交叉熵损失的作用就是比较这两个概率分布的相似程度。如果模型的预测概率分布和真实概率分布越接近,交叉熵损失就越小,反之则越大。 我们的目标就是通过训练,不断减小这个损失值,从而让模型学会做出更准确的预测。

是不是一头雾水?哈哈,没关系,下面笔者将从概念、由来、原理和计算四个部分进行展开,尽可能以大白话的方式进行阐述,相信你阅读后回来再看一段定义的时候,会有不一样的理解~

1. 概念:交叉熵损失,就是给"猜词"打分

想象一下,你正在教一个学徒写一句话。你告诉他句子的开头是:"今天天气真...",然后你让他猜下一个词应该是什么。

  • 学徒的预测: 他可能会给出一些预测,比如:

    • "好" (他觉得最可能)
    • "差" (也有一点可能)
    • "棒" (可能性更小)
    • "猫" (几乎不可能)

    这些预测,可以被看作一个概率分布。比如,他可能认为"好"的概率是 80%,"差"的概率是 15%,"棒"的概率是 4%,"猫"的概率是 1%。

  • 正确的答案: 实际上,正确的下一个词是"好"

  • 交叉熵损失的作用: 交叉熵损失就像一个严厉的老师,它只关注学徒对正确答案的预测。它会说:"你对'好'这个词的预测概率是多少?这个概率越大,你这次的表现就越好,你的'惩罚'(损失)就越小。反之,你的表现越差,你的'惩罚'就越大。"

简单来说,交叉熵损失的计算公式可以简化为: \[ 损失值 = -log(模型对正确答案的预测概率) \]

  • 如果学徒对“好”的预测概率是 0.8,那么损失值大约是 \(−log(0.8)≈0.223\)
  • 如果学徒对“好”的预测概率是 0.01(很差),那么损失值大约是 \(−log(0.01)≈4.605\)
  • 如果学徒猜中率是 1.0(完美),那么损失值是 $ −log(1)=0$。

由此可见,交叉熵损失完美地实现了我们的教学目标:预测对了,损失就小;预测错了,损失就大。

2. 由来:从信息论到机器学习的"迁移"

要理解交叉熵损失的原理,我们需要追溯到它的老家:信息论

2.1 熵(Entropy)

信息论中有一个概念叫"熵",它衡量的是一个事件的不确定性。一个越不确定的事件,它的熵就越高,包含的信息量就越大。

  • 比如,我告诉您"太阳从东边升起",这几乎是 100% 确定的事,您没有获得任何新信息,所以它的熵很低。
  • 但如果我告诉您"今天股市大涨",这本身是一个不确定的事件,您就获得了新信息,所以它的熵很高。

2.2 交叉熵(Cross-Entropy)

现在我们有两个概率分布:一个是真实的、完美的概率分布(记为 \(p\)),另一个是我们模型的预测概率分布(记为 \(q\))。

交叉熵衡量的就是,用我们模型的预测分布 \(q\) 来表示真实的分布 \(p\),需要多少额外的"信息量"或者说"代价"。

理论公式: 交叉熵的理论公式是 \(H(p,q)=−∑_ip_ilog(q_i)\)

  • 这里的 \(p_i\) 是真实事件的概率。
  • \(q_i\) 是我们模型预测的概率。

独热编码(One-hot)的简化

在机器学习的分类任务中,我们的真实标签通常是独热编码的,比如正确答案是"猫'',那么真实分布 \(p\) 就是 \[ [0, 1, 0, ...] \] 现在,让我们把独热编码的 \(p\) 代入到上面的公式中: \[ H(p,q)=−(0⋅log(q_1)+1⋅log(q_2)+0⋅log(q_3)+...) \] 你会发现,求和公式里,只有正确类别(猫)对应的 \(p_i\) 是 1,其他都是 \(0\)。所以,整个求和公式就只剩下了一项: \[ H(p,q)=−log(q_{正确类别}) \] 这就是交叉熵损失的最终形式。它之所以这样计算,完全是因为在分类任务中,我们只关心模型对正确答案的预测概率,而信息论中的交叉熵公式在遇到独热编码时,正好简化成了这个形式。

3. 原理:为什么 −log(p) 是一个好的损失函数?

让我们从数学和直觉两个角度来理解,为什么 \(−log(p)\) 是一个完美的损失函数。

3.1 数学角度

梯度: 我们的目标是通过梯度下降法来最小化损失。对于 \(−log(p)\),它的导数是 \(−1/p\)

  • \(p\) 接近 1 时(预测得很准),\(1/p\) 接近 1,损失的梯度就很小。这意味着模型参数调整的幅度不大,因为它已经做得不错了。
  • \(p\) 接近 0 时(预测得很差),\(1/p\) 趋近于无穷大,损失的梯度就变得非常大。这意味着模型参数调整的幅度会非常大,因为它犯了一个严重的错误,需要大力纠正。

这种特性使得模型在犯错时能快速学习,而在预测准确时则能稳定下来,这非常符合我们对训练过程的期望。

3.2 直觉角度

不确定性: 让我们回到信息论。\(−log(p)\) 实际上就是正确事件的信息量。

  • 如果模型预测正确事件的概率 \(p\) 很低,说明模型对正确答案非常不确定,那么这个正确答案的出现就包含了大量信息。交叉熵损失就用这个巨大的信息量来惩罚模型。
  • 如果模型预测正确事件的概率 \(p\) 很高,说明模型很确定答案,那么这个正确答案的出现就包含很少信息。交叉熵损失就用这个很小的信息量来奖励模型。

这种"用信息量来惩罚"的机制,确保了模型会努力去减少它对正确答案的不确定性,从而让它的预测结果越来越接近真实情况。

4. 计算

交叉熵损失计算过程

参考 Build a Large Language Model 一书,交叉熵损失的计算过程大概分成上面所示的 6 个步骤。

步骤 1:Logits(对数几率)

Logits 是模型在 Softmax 层之前的原始输出值,它可以是任意实数。这些值代表了模型对每个类别的"置信度",但还没有归一化为概率。图片中的 [[0.1113, -0.1057, -0.3666, ...]] 就是一个样本的 Logits 输出。

步骤 2:Probabilities(概率)

通过 Softmax 函数将 Logits 转换为概率分布。这个函数的作用是将一组任意实数转换成一个概率分布,使得所有值都在 0 到 1 之间,并且总和为 1。它的公式是 \(q_i=\frac{e^{z_i}}{∑_j^{e^{z_j}}}\), (其中 \(z_i\) 是第 \(i\) 个类别的 Logit)。[[1.8849e-05, 1.5172e-05, 1.1687e-05, ...]] 就是经过 Softmax 转换后的概率分布。

步骤 3:Target probabilities(目标概率)

这一步的核心是从模型的预测中,提取出与真实答案相对应的概率值。在理论上,我们用独热编码(One-Hot Encoding)来表示真实标签,例如 [0, 1, 0, ...]。图片中的 [7.4541e-05, ...] 正是模型根据这个独热编码所指示的正确索引,给出的预测概率。这些值通常很小,因为在训练初期,模型对正确答案的预测能力还很弱。在计算交叉熵时,我们只关心真实类别对应的预测概率。

步骤 4:Log probabilities(对数概率)

这一步是计算每个目标概率值的自然对数,即 \(log(q_i)\)。例如,[-9.5042, -10.3796, -11.3677, ...] 就是对目标概率取自然对数的结果。

步骤 5:Average log probability(平均对数概率)

这一步是计算所有对数概率的平均值。在步骤 4 中,我们已经得到了模型对每个正确答案的预测概率的对数值。这一步就是将这些值加起来,然后除以样本或序列的长度,以得到一个平均值。

步骤 6:Negative average log probability(负平均对数概率)

这是计算最终损失值的步骤。在步骤 5 的基础上,我们对平均对数概率取负号。这是为了将一个衡量模型错误程度的负数,转换成一个衡量模型错误程度的正数。这个操作没有复杂的数学含义,它只是为了让损失值的符号符合我们的直觉和约定。损失值越小代表模型表现越好。在图片中,对 -10.7940 取负号后,得到的值是 10.7940。这个值就是我们最终要最小化的损失(Loss)。在模型训练中,我们通过反向传播和梯度下降来不断减小这个损失值,从而迫使模型提高对正确答案的预测概率。

上面 6 个步骤,可以直接使用 pytorch 的 cross_entropy 计算,一步到位!

1
loss = torch.nn.functional.cross_entropy(logits_flat, targets_flat)

总结一下,整个计算流程可以概括为:

  1. 模型输出原始分数(Logits)。
  2. 通过 Softmax 函数将分数转换为概率分布。
  3. 找出真实类别对应的预测概率。
  4. 对这个概率取负对数,得到损失值。
  5. 在训练时,我们会对所有样本的损失值求平均,然后进行反向传播更新模型参数。

这个计算方式之所以合理,正是因为它完美地结合了信息论和机器学习的目标:通过最小化这个损失值,我们实际上是在最大化模型对正确类别的预测概率,从而让模型的预测分布越来越接近真实的分布。 这是一种非常高效且理论基础坚实的训练方法。