Jerry's Blog

Back

形式化建模:SFT 是外部数据分布上的前向 KL 投影#

1. 阅读指南#

1.1 本节具体性质#

为了避免后文定理读起来像“只是在换符号”,先把本文想强调的四个性质直接写出来:

  1. 可识别性性质:任何后训练目标都只能识别其加权到的 states;未被状态权重测度看到的 states,不可能由该目标单独约束。
  2. 三元组决定性:一旦某方法能写成“状态分布 ν\nu + 目标分布 ξs\xi_s + 散度 DD”,那么它在 supp(ν)\operatorname{supp}(\nu) 上的理想最优行为就由这三个对象完全决定,算法名字只影响数值求解路径。
  3. 遗忘的必要条件:若旧能力主要依赖于 supp(ν)\operatorname{supp}(\nu) 之外的状态,那么任何只在 ν\nu 上训练的方法,都无法从目标函数本身推出 retention 保证。
  4. 泛化的必要条件:要把训练误差转成部署误差,必须额外比较训练状态分布与部署状态分布;只看 token loss、reward 大小或“是否有 teacher”本身都不够。

这四条并不是新的定理,而是全文后续结果的阅读指南:后面的所有证明,本质上都在把这四句从直觉改写成精确公式。

2. 形式化建模#

2.1 把自回归语言模型写成有限时域决策过程#

定义 2.1(提示、动作、状态).X\mathcal X 为 prompt 集合,A\mathcal A 为 token 动作集合,包含特殊终止符 EOS。设最大生成步数为 HNH\in\mathbb N

在第 tt 步,状态记为

st=(x,a<t)=(x,a1,,at1),s_t=(x,a_{<t})=(x,a_1,\ldots,a_{t-1}),

其中 xXx\in\mathcal XaiAa_i\in\mathcal A。因此状态就是“prompt + 当前前缀”。

注记 2.2. 为了避免可变长序列引入不必要的技术细节,本文采用“最大长度 HH,遇到 EOS 后进入吸收终止态”的标准处理。这样可以把每条轨迹都看成长度恰为 HH 的序列,而不影响终止前的自回归部分。

定义 2.3(策略). 一个语言模型策略 π\pi 是从状态到动作分布的映射:

π(s)Δ(A),sS.\pi(\cdot\mid s)\in\Delta(\mathcal A),\qquad s\in\mathcal S.

其中 Δ(A)\Delta(\mathcal A) 表示动作单纯形。若用 logits zπ(s,a)z_\pi(s,a) 参数化,则

π(as)=exp(zπ(s,a))bAexp(zπ(s,b)).\pi(a\mid s) = \frac{\exp(z_\pi(s,a))} {\sum_{b\in\mathcal A}\exp(z_\pi(s,b))}.

定义 2.4(轨迹分布). 设 prompt 分布为 ρ(x)\rho(x)。给定策略 π\pi,一条轨迹记为

τ=(x,a1,,aH).\tau=(x,a_1,\ldots,a_H).

其条件概率为

Pπ(τ)=ρ(x)t=1Hπ(atst),st=(x,a<t).P_\pi(\tau) = \rho(x) \prod_{t=1}^{H} \pi(a_t\mid s_t), \qquad s_t=(x,a_{<t}).

有限时域自回归生成可以看成从 prompt 出发、沿 token 前缀树逐步选择动作,直到最大步数或 EOS。

图 1:有限时域决策过程的直观图。左侧是 prompt,右侧是由 token 动作展开的前缀树;橙色路径表示一条具体轨迹,灰色分支表示同一状态下其他可能动作。

定义 2.5(状态-动作边缘、占用测度). 定义策略 π\pi 的状态-动作边缘为

qπ(s,a):=τPπ(τ)t=1H1{st=s, at=a}.q_\pi(s,a) := \sum_\tau P_\pi(\tau) \sum_{t=1}^{H} \mathbf 1\{s_t=s,\ a_t=a\}.

定义对应的状态占用测度

dπ(s):=aAqπ(s,a)=τPπ(τ)t=1H1{st=s}.d_\pi(s) := \sum_{a\in\mathcal A}q_\pi(s,a) = \sum_\tau P_\pi(\tau) \sum_{t=1}^{H} \mathbf 1\{s_t=s\}.

定义归一化占用分布

dˉπ(s):=1Hdπ(s)=1Ht=1HPrπ(st=s).\bar d_\pi(s) := \frac{1}{H}d_\pi(s) = \frac{1}{H} \sum_{t=1}^{H} \Pr_\pi(s_t=s).

注记 2.6. dπd_\pi 不是概率分布,它的总质量是 HHdˉπ\bar d_\pi 才是概率分布。本文凡是讨论总变差(TV)或分布失配时,都优先使用 dˉπ\bar d_\pi

引理 2.7(对轨迹求和与对占用测度求和是等价的). 对任意函数 f:S×ARf:\mathcal S\times\mathcal A\to\mathbb R

EτPπ[t=1Hf(st,at)]=sSaAqπ(s,a)f(s,a).\mathbb E_{\tau\sim P_\pi} \left[ \sum_{t=1}^{H} f(s_t,a_t) \right] = \sum_{s\in\mathcal S} \sum_{a\in\mathcal A} q_\pi(s,a)f(s,a).

g:SRg:\mathcal S\to\mathbb R 只依赖状态,则

EτPπ[t=1Hg(st)]=sSdπ(s)g(s).\mathbb E_{\tau\sim P_\pi} \left[ \sum_{t=1}^{H} g(s_t) \right] = \sum_{s\in\mathcal S} d_\pi(s)g(s).

证明. 直接展开:

EτPπ[t=1Hf(st,at)]=τPπ(τ)t=1Hf(st,at)=τPπ(τ)t=1Hs,a1{st=s, at=a}f(s,a)=s,a(τPπ(τ)t=1H1{st=s, at=a})f(s,a)=s,aqπ(s,a)f(s,a).\begin{aligned} \mathbb E_{\tau\sim P_\pi} \left[ \sum_{t=1}^{H} f(s_t,a_t) \right] &= \sum_\tau P_\pi(\tau) \sum_{t=1}^{H} f(s_t,a_t) \\ &= \sum_\tau P_\pi(\tau) \sum_{t=1}^{H} \sum_{s,a} \mathbf 1\{s_t=s,\ a_t=a\}f(s,a) \\ &= \sum_{s,a} \left( \sum_\tau P_\pi(\tau) \sum_{t=1}^{H} \mathbf 1\{s_t=s,\ a_t=a\} \right) f(s,a) \\ &= \sum_{s,a} q_\pi(s,a)f(s,a). \end{aligned}

第二式令 f(s,a)=g(s)f(s,a)=g(s) 即得。

2.2 外部数据分布与经验数据策略#

定义 2.8(数据轨迹分布).QQ 是一个外部数据分布,例如人工标注答案、teacher 生成答案、混合示范集等。它给出轨迹 τ=(x,a1,,aH)\tau=(x,a_1,\ldots,a_H) 的概率 Q(τ)Q(\tau)

与策略分布完全同理,定义数据的状态-动作边缘

qQ(s,a):=τQ(τ)t=1H1{st=s, at=a},dQ(s):=aqQ(s,a),q_Q(s,a) := \sum_\tau Q(\tau) \sum_{t=1}^{H} \mathbf 1\{s_t=s,\ a_t=a\}, \qquad d_Q(s) := \sum_a q_Q(s,a), dˉQ(s):=1HdQ(s).\bar d_Q(s) := \frac{1}{H} d_Q(s).

定义 2.9(经验数据策略). 对所有满足 dQ(s)>0d_Q(s)>0 的状态,定义经验数据策略

μQ(as):=qQ(s,a)dQ(s).\mu_Q(a\mid s) := \frac{q_Q(s,a)} {d_Q(s)}.

dQ(s)=0d_Q(s)=0,则 μQ(s)\mu_Q(\cdot\mid s) 可任意指定,因为该状态在所有相关目标中权重为 00

注记 2.10. 当每个 (s)(s) 只出现一个标注 token 时,μQ(s)\mu_Q(\cdot\mid s) 就是 one-hot 分布;当有多参考答案、软标签、或 teacher logits 时,μQ(s)\mu_Q(\cdot\mid s) 就是一般分布。因此“硬标签 SFT”和“软标签蒸馏”在数学上只差一个 target 分布是否为 one-hot。

2.3 本文用到的散度与算子#

定义 2.11(交叉熵、前向 KL、反向 KL、总变差). 对定义在 A\mathcal A 上的两个分布 p,qp,q

交叉熵定义为

H(p,q):=aAp(a)logq(a).H(p,q) := - \sum_{a\in\mathcal A} p(a)\log q(a).

前向 KL 定义为

DKL(pq):=aAp(a)logp(a)q(a).D_{\mathrm{KL}}(p\Vert q) := \sum_{a\in\mathcal A} p(a)\log\frac{p(a)}{q(a)}.

反向 KL 在符号上就是交换顺序:

DKL(qp):=aAq(a)logq(a)p(a).D_{\mathrm{KL}}(q\Vert p) := \sum_{a\in\mathcal A} q(a)\log\frac{q(a)}{p(a)}.

总变差定义为

TV(p,q):=12aAp(a)q(a).\mathrm{TV}(p,q) := \frac{1}{2} \sum_{a\in\mathcal A} |p(a)-q(a)|.

注记 2.12. 注意:工程里常说“forward KL / reverse KL”,本质上只是两个参数顺序不同;本文会明确写成 DKL(pq)D_{\mathrm{KL}}(p\Vert q)DKL(qp)D_{\mathrm{KL}}(q\Vert p),避免口头混淆。

2.4 本节具体性质#

命题 2.13(占用测度的归一化与因子分解). 对任意策略 π\pi

sSdπ(s)=H,sSaAqπ(s,a)=H,qπ(s,a)=dπ(s)π(as).\sum_{s\in\mathcal S} d_\pi(s) = H, \qquad \sum_{s\in\mathcal S} \sum_{a\in\mathcal A} q_\pi(s,a) = H, \qquad q_\pi(s,a) = d_\pi(s)\pi(a\mid s).

同理,对任意外部数据分布 QQ

sdQ(s)=H,s,aqQ(s,a)=H,qQ(s,a)=dQ(s)μQ(as).\sum_s d_Q(s) = H, \qquad \sum_{s,a} q_Q(s,a) = H, \qquad q_Q(s,a) = d_Q(s)\mu_Q(a\mid s).

因此 dˉπ\bar d_\pidˉQ\bar d_Q 都是概率分布:

sdˉπ(s)=1,sdˉQ(s)=1.\sum_s \bar d_\pi(s)=1, \qquad \sum_s \bar d_Q(s)=1.

证明. 先证策略情形。由定义

dπ(s)=τPπ(τ)t=1H1{st=s}.d_\pi(s) = \sum_\tau P_\pi(\tau) \sum_{t=1}^{H} \mathbf 1\{s_t=s\}.

ss 求和:

sdπ(s)=sτPπ(τ)t=1H1{st=s}=τPπ(τ)t=1Hs1{st=s}=τPπ(τ)t=1H1=HτPπ(τ)=H.\begin{aligned} \sum_s d_\pi(s) &= \sum_s \sum_\tau P_\pi(\tau) \sum_{t=1}^{H} \mathbf 1\{s_t=s\} \\ &= \sum_\tau P_\pi(\tau) \sum_{t=1}^{H} \sum_s \mathbf 1\{s_t=s\} \\ &= \sum_\tau P_\pi(\tau) \sum_{t=1}^{H} 1 = H \sum_\tau P_\pi(\tau) = H. \end{aligned}

同理,

s,aqπ(s,a)=τPπ(τ)t=1Hs,a1{st=s, at=a}=H.\sum_{s,a} q_\pi(s,a) = \sum_\tau P_\pi(\tau) \sum_{t=1}^{H} \sum_{s,a} \mathbf 1\{s_t=s,\ a_t=a\} = H.

再证因子分解。因为在给定状态 st=ss_t=s 时,下一动作由 π(s)\pi(\cdot\mid s) 采样,

Prπ(st=s, at=a)=Prπ(st=s)π(as).\Pr_\pi(s_t=s,\ a_t=a) = \Pr_\pi(s_t=s)\pi(a\mid s).

对所有可能的时间步求和即可得

qπ(s,a)=dπ(s)π(as).q_\pi(s,a) = d_\pi(s)\pi(a\mid s).

数据分布 QQ 的情形完全同理;只是把 π(as)\pi(a\mid s) 替换成经验条件分布 μQ(as)\mu_Q(a\mid s)

命题 2.14(前缀概率的显式形式与支持传播). 设某个状态写成

s=(x,a<t)=(x,a1,,at1).s=(x,a_{<t})=(x,a_1,\ldots,a_{t-1}).

则该状态在第 tt 步被访问到的概率满足

Prπ(st=s)=ρ(x)i=1t1π(aix,a<i).\Pr_\pi(s_t=s) = \rho(x) \prod_{i=1}^{t-1} \pi(a_i\mid x,a_{<i}).

因此,若 Prπ(st=s)>0\Pr_\pi(s_t=s)>0π(ats)>0\pi(a_t\mid s)>0,则后继状态

s=(x,at)=(x,a1,,at)s'=(x,a_{\le t})=(x,a_1,\ldots,a_t)

也满足

Prπ(st+1=s)>0.\Pr_\pi(s_{t+1}=s')>0.

证明. 按轨迹分布的定义,状态 st=(x,a<t)s_t=(x,a_{<t}) 在第 tt 步出现,当且仅当 prompt 等于 xx 且前 t1t-1 个 token 恰为 a1,,at1a_1,\ldots,a_{t-1}。于是

Prπ(st=s)=ρ(x)i=1t1π(aix,a<i).\Pr_\pi(s_t=s) = \rho(x) \prod_{i=1}^{t-1} \pi(a_i\mid x,a_{<i}).

进一步,

Prπ(st+1=s)=ρ(x)i=1tπ(aix,a<i)=Prπ(st=s)π(ats).\Pr_\pi(s_{t+1}=s') = \rho(x) \prod_{i=1}^{t} \pi(a_i\mid x,a_{<i}) = \Pr_\pi(s_t=s)\pi(a_t\mid s).

若右侧两因子都为正,则结论成立。

注记 2.15. 命题 2.13 与命题 2.14 给出了两个以后会反复用到的“基础性质”:一是所有目标都可以写成对占用测度的加权和;二是状态支持集会沿着正概率前缀向后传播。

3. SFT:外部数据分布上的前向 KL 投影#

3.1 目标函数展开#

定义 3.1(SFT 目标). 给定外部数据分布 QQ,SFT 目标写成

LSFT(π):=EτQ[t=1Hlogπ(atst)].L_{\mathrm{SFT}}(\pi) := \mathbb E_{\tau\sim Q} \left[ - \sum_{t=1}^{H} \log \pi(a_t\mid s_t) \right].

定理 3.2(SFT 的精确分解). SFT 目标可以精确写成

LSFT(π)=sSdQ(s)H(μQ(s),π(s)),L_{\mathrm{SFT}}(\pi) = \sum_{s\in\mathcal S} d_Q(s) H\left( \mu_Q(\cdot\mid s), \pi(\cdot\mid s) \right),

以及

LSFT(π)=CQ+sSdQ(s)DKL(μQ(s)π(s)),L_{\mathrm{SFT}}(\pi) = C_Q + \sum_{s\in\mathcal S} d_Q(s) D_{\mathrm{KL}} \left( \mu_Q(\cdot\mid s) \Vert \pi(\cdot\mid s) \right),

其中常数项

CQ:=sSdQ(s)H(μQ(s))C_Q := \sum_{s\in\mathcal S} d_Q(s) H\left( \mu_Q(\cdot\mid s) \right)

π\pi 无关。

证明. 逐步展开:

LSFT(π)=EτQ[t=1Hlogπ(atst)]=τQ(τ)t=1Hlogπ(atst)=τQ(τ)t=1Hs,a1{st=s, at=a}logπ(as)=s,a(τQ(τ)t=1H1{st=s, at=a})logπ(as)=s,aqQ(s,a)logπ(as).\begin{aligned} L_{\mathrm{SFT}}(\pi) &= \mathbb E_{\tau\sim Q} \left[ - \sum_{t=1}^{H} \log \pi(a_t\mid s_t) \right] \\ &= - \sum_\tau Q(\tau) \sum_{t=1}^{H} \log \pi(a_t\mid s_t) \\ &= - \sum_\tau Q(\tau) \sum_{t=1}^{H} \sum_{s,a} \mathbf 1\{s_t=s,\ a_t=a\} \log \pi(a\mid s) \\ &= - \sum_{s,a} \left( \sum_\tau Q(\tau) \sum_{t=1}^{H} \mathbf 1\{s_t=s,\ a_t=a\} \right) \log \pi(a\mid s) \\ &= - \sum_{s,a} q_Q(s,a)\log \pi(a\mid s). \end{aligned}

利用 qQ(s,a)=dQ(s)μQ(as)q_Q(s,a)=d_Q(s)\mu_Q(a\mid s)

LSFT(π)=sdQ(s)aμQ(as)logπ(as)=sdQ(s)H(μQ,π).\begin{aligned} L_{\mathrm{SFT}}(\pi) &= - \sum_s d_Q(s) \sum_a \mu_Q(a\mid s)\log \pi(a\mid s) \\ &= \sum_s d_Q(s) H(\mu_Q,\pi). \end{aligned}

再利用恒等式

H(μQ,π)=H(μQ)+DKL(μQπ),H(\mu_Q,\pi) = H(\mu_Q) + D_{\mathrm{KL}}(\mu_Q\Vert\pi),

得到

LSFT(π)=sdQ(s)H(μQ(s))+sdQ(s)DKL(μQ(s)π(s)).L_{\mathrm{SFT}}(\pi) = \sum_s d_Q(s) H(\mu_Q(\cdot\mid s)) + \sum_s d_Q(s) D_{\mathrm{KL}} \left( \mu_Q(\cdot\mid s) \Vert \pi(\cdot\mid s) \right).

第一项与 π\pi 无关,即为常数 CQC_Q

推论 3.3(SFT 是对经验数据策略的 off-policy distillation). 若把 μQ(s)\mu_Q(\cdot\mid s) 看成“经验 teacher”,则 SFT 正是在外部状态分布 dQd_Q 上,对 teacher 条件分布做前向 KL 投影。

证明. 由定理 3.2,优化 LSFTL_{\mathrm{SFT}} 等价于最小化

sdQ(s)DKL(μQ(s)π(s)).\sum_s d_Q(s) D_{\mathrm{KL}} \left( \mu_Q(\cdot\mid s) \Vert \pi(\cdot\mid s) \right).

这正是前向 KL distillation,只不过 teacher 不是一个显式模型,而是数据经验分布 μQ\mu_Q

SFT 在数据支持状态上,把模型条件分布投影到经验 target 分布。

图 2:SFT 的 off-policy distillation 视角。左侧是外部数据覆盖到的状态支持,中间是经验 target 分布 μQ(s)\mu_Q(\cdot\mid s),右侧是被前向 KL 拉向 target 的模型分布 π(s)\pi(\cdot\mid s)

3.2 SFT 最优解与其“只约束数据支持集”的性质#

定理 3.4(SFT 的非参数最优解). 若把优化域看成所有随机策略的集合,则

argminπLSFT(π)={π:π(s)=μQ(s)s with dQ(s)>0}.\arg\min_\pi L_{\mathrm{SFT}}(\pi) = \left\{ \pi: \pi(\cdot\mid s)=\mu_Q(\cdot\mid s) \quad \forall s\ \mathrm{with}\ d_Q(s)>0 \right\}.

在所有 dQ(s)=0d_Q(s)=0 的状态上,SFT 目标没有任何约束。

证明. 由定理 3.2,

LSFT(π)=CQ+sdQ(s)DKL(μQπ).L_{\mathrm{SFT}}(\pi) = C_Q + \sum_s d_Q(s) D_{\mathrm{KL}}(\mu_Q\Vert\pi).

每一项 KL 非负,且当且仅当 π(s)=μQ(s)\pi(\cdot\mid s)=\mu_Q(\cdot\mid s) 时取 00。对 dQ(s)=0d_Q(s)=0 的状态,对应项恒为 00,所以不约束这些状态。

命题 3.5(SFT 在数据支持集外完全不敏感). 若两个策略 π,π~\pi,\tilde\pi 满足

π(s)=π~(s)ssupp(dQ),\pi(\cdot\mid s) = \tilde\pi(\cdot\mid s) \qquad \forall s\in\operatorname{supp}(d_Q),

LSFT(π)=LSFT(π~).L_{\mathrm{SFT}}(\pi) = L_{\mathrm{SFT}}(\tilde\pi).

证明. 由定理 3.2,SFT 目标只对所有满足 dQ(s)>0d_Q(s)>0 的状态求和;在其余状态上的策略如何变化都不会进入目标函数。

注记 3.6. 这条命题几乎就是“灾难性遗忘为何可能发生”的最直接数学表述:如果旧能力对应的 states 不在当前 SFT 数据支持集里,SFT 目标对这些旧 states 没有任何显式约束。

3.3 SFT 梯度的完整展开#

固定某个状态 ss,记

pi:=π(ais),qi:=μQ(ais),pi=ezijezj.p_i := \pi(a_i\mid s), \qquad q_i := \mu_Q(a_i\mid s), \qquad p_i = \frac{e^{z_i}} {\sum_j e^{z_j}}.

对应的局部 SFT 损失为

SFT(s)=iqilogpi.\ell_{\mathrm{SFT}}(s) = - \sum_i q_i\log p_i.

引理 3.7(softmax 雅可比). 对 softmax 输出 pi=ezijezjp_i=\frac{e^{z_i}}{\sum_j e^{z_j}},有

pizk=pi(δikpk),\frac{\partial p_i} {\partial z_k} = p_i(\delta_{ik}-p_k),

其中 δik\delta_{ik} 是 Kronecker delta。

证明.Z=jezjZ=\sum_j e^{z_j},则 pi=ezi/Zp_i=e^{z_i}/Z。对 zkz_k 求导:

pizk=δikeziZeziezkZ2=eziZ(δikezkZ)=pi(δikpk).\begin{aligned} \frac{\partial p_i} {\partial z_k} &= \frac{ \delta_{ik}e^{z_i}Z - e^{z_i}e^{z_k} } {Z^2} \\ &= \frac{e^{z_i}}{Z} \left( \delta_{ik} - \frac{e^{z_k}}{Z} \right) \\ &= p_i(\delta_{ik}-p_k). \end{aligned}

定理 3.8(SFT 对 logits 的梯度). 对固定状态 ss,局部损失 SFT(s)\ell_{\mathrm{SFT}}(s) 关于 logit zkz_k 的梯度为

SFT(s)zk=pkqk.\frac{\partial \ell_{\mathrm{SFT}}(s)} {\partial z_k} = p_k-q_k.

因此全局损失满足

LSFTz(s,ak)=dQ(s)(π(aks)μQ(aks)).\frac{\partial L_{\mathrm{SFT}}} {\partial z(s,a_k)} = d_Q(s) \left( \pi(a_k\mid s) - \mu_Q(a_k\mid s) \right).

证明. 直接对

SFT(s)=iqilogpi\ell_{\mathrm{SFT}}(s) = - \sum_i q_i\log p_i

求导:

SFT(s)zk=iqi1pipizk=iqi1pipi(δikpk)=iqiδik+iqipk=qk+pkiqi=pkqk,\begin{aligned} \frac{\partial \ell_{\mathrm{SFT}}(s)} {\partial z_k} &= - \sum_i q_i \frac{1}{p_i} \frac{\partial p_i} {\partial z_k} \\ &= - \sum_i q_i \frac{1}{p_i} p_i(\delta_{ik}-p_k) \\ &= - \sum_i q_i\delta_{ik} + \sum_i q_ip_k \\ &= - q_k + p_k \sum_i q_i \\ &= p_k-q_k, \end{aligned}

因为 iqi=1\sum_i q_i=1。再乘上全局权重 dQ(s)d_Q(s) 即得。

注记 3.9. 这条公式说明 SFT 的更新方向完全由数据状态权重 dQ(s)d_Q(s) 决定。出现次数高的状态梯度大,没出现的状态梯度严格为零。

SFT logits 梯度由模型分布和数据 target 分布的差决定,并被数据状态频次加权。

图 3:SFT logits 梯度的直观图。每个状态内部比较 π(s)\pi(\cdot\mid s)μQ(s)\mu_Q(\cdot\mid s);差值决定 token 概率往上还是往下调,而左侧状态圆点大小表示 dQ(s)d_Q(s) 对梯度强度的加权。

3.4 本节具体性质#

性质 3.A(one-hot 标签是 SFT 的特例). 若某个数据状态 ss 上只有一个标注 token aa^\star,即

μQ(as)=1,μQ(as)=0(aa),\mu_Q(a^\star\mid s)=1, \qquad \mu_Q(a\mid s)=0\quad(a\neq a^\star),

则局部交叉熵退化为普通 hard-label 负对数似然:

H(μQ,π)=logπ(as).H(\mu_Q,\pi) = - \log\pi(a^\star\mid s).

对应的 logit 梯度为

SFT(s)z(s,a)=π(as)1{a=a}.\frac{\partial \ell_{\mathrm{SFT}}(s)} {\partial z(s,a)} = \pi(a\mid s) - \mathbf 1\{a=a^\star\}.

因此 hard-label SFT 不是另一个目标,而是 μQ\mu_Q 为 one-hot 时的特例。

性质 3.B(前向 KL 对数据动作的零概率敏感). 对任意满足 dQ(s)>0d_Q(s)>0 的状态,如果存在动作 aa 使得

μQ(as)>0,π(as)=0,\mu_Q(a\mid s)>0, \qquad \pi(a\mid s)=0,

DKL(μQ(s)π(s))=+.D_{\mathrm{KL}} \left( \mu_Q(\cdot\mid s) \Vert \pi(\cdot\mid s) \right) = +\infty.

这说明前向 KL 会强烈惩罚“数据里出现过的动作被模型分配零概率”。在有限 logits 的 softmax 参数化下,π(as)\pi(a\mid s) 不会真的等于 00,但当它趋近于 00 时,该项会迅速变大。

性质 3.C(SFT 不直接约束 rollout 状态分布失配).h:SRh:\mathcal S\to\mathbb R 是任意有界状态函数,且 h(s)M|h(s)|\le M。则

sdˉπ(s)h(s)sdˉQ(s)h(s)2MTV(dˉπ,dˉQ).\left| \sum_s \bar d_\pi(s)h(s) - \sum_s \bar d_Q(s)h(s) \right| \le 2M\, \mathrm{TV}(\bar d_\pi,\bar d_Q).

SFT 目标直接优化的是外部状态权重 dQ(s)d_Q(s) 下的条件分布匹配,而不是 dˉπ\bar d_\pidˉQ\bar d_Q 的接近程度。因此,当训练数据状态分布和模型 rollout 状态分布相差很大时,即使数据支持集上的 token loss 很低,模型在自己生成出来的新前缀上仍然可能缺少约束。

证明. 由总变差定义,

s(dˉπ(s)dˉQ(s))h(s)sdˉπ(s)dˉQ(s)h(s)Msdˉπ(s)dˉQ(s)=2MTV(dˉπ,dˉQ).\begin{aligned} \left| \sum_s (\bar d_\pi(s)-\bar d_Q(s))h(s) \right| &\le \sum_s |\bar d_\pi(s)-\bar d_Q(s)|\,|h(s)| \\ &\le M \sum_s |\bar d_\pi(s)-\bar d_Q(s)| \\ &= 2M\, \mathrm{TV}(\bar d_\pi,\bar d_Q). \end{aligned}

性质 3.D(logits 梯度在每个状态内质量守恒). 对固定状态 ss,局部梯度满足

kSFT(s)zk=0.\sum_k \frac{\partial \ell_{\mathrm{SFT}}(s)} {\partial z_k} = 0.

证明. 由定理 3.8,

k(pkqk)=kpkkqk=11=0.\sum_k(p_k-q_k) = \sum_k p_k - \sum_k q_k = 1-1 = 0.

这也对应 softmax 的平移不变性:给同一个状态下所有 logits 同时加上常数,不会改变策略分布。

性质 3.E(数据频次只改变梯度权重,不改变局部方向). 若两个状态 s1,s2s_1,s_2 的局部分布差相同,即

π(s1)μQ(s1)=π(s2)μQ(s2),\pi(\cdot\mid s_1)-\mu_Q(\cdot\mid s_1) = \pi(\cdot\mid s_2)-\mu_Q(\cdot\mid s_2),

则它们在全局 SFT 梯度中的比例只由 dQd_Q 决定:

LSFT/z(s1,a)LSFT/z(s2,a)=dQ(s1)dQ(s2)\frac{ \partial L_{\mathrm{SFT}}/\partial z(s_1,a) }{ \partial L_{\mathrm{SFT}}/\partial z(s_2,a) } = \frac{d_Q(s_1)}{d_Q(s_2)}

在分母非零时成立。因此,SFT 的“更重视哪些状态”完全来自数据占用频次或采样权重。

命题 3.10(SFT 的局部 Hessian 是 Fisher 矩阵). 对固定状态 ss,局部 SFT 损失

SFT(s)=iqilogpi\ell_{\mathrm{SFT}}(s) = - \sum_i q_i\log p_i

关于 logits z=(zi)iz=(z_i)_i 的 Hessian 为

2SFT(s)zkz=pk(δkp).\frac{\partial^2 \ell_{\mathrm{SFT}}(s)} {\partial z_k\partial z_\ell} = p_k(\delta_{k\ell}-p_\ell).

矩阵形式写成

z2SFT(s)=Diag(p)pp0.\nabla_z^2\ell_{\mathrm{SFT}}(s) = \operatorname{Diag}(p)-pp^\top \succeq 0.

其零空间正是常数平移方向 span{1}\operatorname{span}\{\mathbf 1\}

证明. 由定理 3.8,

SFT(s)zk=pkqk.\frac{\partial \ell_{\mathrm{SFT}}(s)} {\partial z_k} = p_k-q_k.

再对 zz_\ell 求导即可:

2SFT(s)zkz=pkz=pk(δkp),\frac{\partial^2 \ell_{\mathrm{SFT}}(s)} {\partial z_k\partial z_\ell} = \frac{\partial p_k} {\partial z_\ell} = p_k(\delta_{k\ell}-p_\ell),

其中最后一步用了引理 3.7。把所有分量拼成矩阵就是

Diag(p)pp.\operatorname{Diag}(p)-pp^\top.

对任意向量 vv

v(Diag(p)pp)v=ipivi2(ipivi)2=Varap[va]0,\begin{aligned} v^\top(\operatorname{Diag}(p)-pp^\top)v &= \sum_i p_i v_i^2 - \left( \sum_i p_i v_i \right)^2 \\ &= \operatorname{Var}_{a\sim p}[v_a] \ge 0, \end{aligned}

故该矩阵半正定。等号成立当且仅当 vav_app 的支持集上为常数;在有限 logits 的 softmax 参数化下,pp 对所有动作都有正质量,因此零空间是常数平移方向。

推论 3.11(SFT 的局部凸性与概率空间唯一最优解). 固定一个状态 ss。SFT 局部损失对 logits 是凸的;若改在概率单纯形上看,则其唯一最优解为

p=q=μQ(s).p^\star = q = \mu_Q(\cdot\mid s).

证明. 凸性由命题 3.10 的 Hessian 半正定得到。又因为

SFT(s)=H(q)+DKL(qp),\ell_{\mathrm{SFT}}(s) = H(q) + D_{\mathrm{KL}}(q\Vert p),

最小值在且仅在 DKL(qp)=0D_{\mathrm{KL}}(q\Vert p)=0 时取得,即 p=qp=q

命题 3.12(软标签与标签平滑的线性性). 设两个目标分布为 q(1),q(2)q^{(1)},q^{(2)},以及混合目标

q(λ):=(1λ)q(1)+λq(2),λ[0,1].q^{(\lambda)} := (1-\lambda)q^{(1)} + \lambda q^{(2)}, \qquad \lambda\in[0,1].

则对任意预测分布 pp

H(q(λ),p)=(1λ)H(q(1),p)+λH(q(2),p).H(q^{(\lambda)},p) = (1-\lambda)H(q^{(1)},p) + \lambda H(q^{(2)},p).

因此,软标签蒸馏、多个参考答案平均、标签平滑等,都只是 forward-KL 目标在线性空间里的不同取值。

证明. 直接展开即可:

H(q(λ),p)=a((1λ)q(1)(a)+λq(2)(a))logp(a)=(1λ)(aq(1)(a)logp(a))+λ(aq(2)(a)logp(a)).\begin{aligned} H(q^{(\lambda)},p) &= - \sum_a \left( (1-\lambda)q^{(1)}(a) + \lambda q^{(2)}(a) \right) \log p(a) \\ &= (1-\lambda) \left( - \sum_a q^{(1)}(a)\log p(a) \right) + \lambda \left( - \sum_a q^{(2)}(a)\log p(a) \right). \end{aligned}

注记 3.13. 本节最值得记住的三个具体性质是:SFT 对 logits 的二阶结构是 Fisher 矩阵;SFT 在概率空间上唯一追向数据分布本身;而只要目标分布做凸组合,SFT 损失就按同样系数线性组合。

附录 A:几个最小代码验证#

A.1 验证 SFT 分解与 logits 梯度#

下面这个 PyTorch 例子同时验证两件事:

  1. SFT soft-label 交叉熵等于常数熵项加前向 KL。
  2. logits 梯度等于 dQ(s)(πμQ)d_Q(s)(\pi-\mu_Q)
import torch

torch.manual_seed(0)

num_states = 3
num_actions = 5

logits = torch.randn(num_states, num_actions, requires_grad=True)
mu_q = torch.softmax(torch.randn(num_states, num_actions), dim=-1)
d_q = torch.tensor([4.0, 2.0, 0.0])

pi = torch.softmax(logits, dim=-1)

cross_entropy = -(mu_q * torch.log(pi)).sum(dim=-1)
loss = (d_q * cross_entropy).sum()

entropy = -(mu_q * torch.log(mu_q)).sum(dim=-1)
forward_kl = (mu_q * (torch.log(mu_q) - torch.log(pi))).sum(dim=-1)
decomposed = (d_q * (entropy + forward_kl)).sum()

print(torch.allclose(loss, decomposed, atol=1e-6))

loss.backward()
expected_grad = d_q[:, None] * (pi.detach() - mu_q)

print(torch.allclose(logits.grad, expected_grad, atol=1e-6))
print(logits.grad)
python

其中 d_q[2] = 0,所以第三个状态即使有 logits,也不会产生梯度。

A.2 hard-label SFT 是 one-hot soft-label 的特例#

import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch = 4
vocab = 6

logits = torch.randn(batch, vocab)
targets = torch.tensor([2, 3, 3, 0])

hard_label_loss = F.cross_entropy(logits, targets, reduction="none")

one_hot_targets = F.one_hot(targets, num_classes=vocab).float()
soft_label_loss = -(one_hot_targets * F.log_softmax(logits, dim=-1)).sum(dim=-1)

print(hard_label_loss)
print(soft_label_loss)
print(torch.allclose(hard_label_loss, soft_label_loss, atol=1e-6))
python

这段代码对应性质 3.A:普通 token-level cross entropy 只是 μQ\mu_Q 取 one-hot 时的特殊情形。

A.3 从轨迹样本统计 qQ,dQ,μQq_Q,d_Q,\mu_Q#

下面的例子把若干条固定长度轨迹统计成状态-动作边缘和经验数据策略。

from collections import Counter, defaultdict

trajectories = [
    ("x1", ("A", "B", "EOS")),
    ("x1", ("A", "C", "EOS")),
    ("x2", ("D", "EOS", "EOS")),
]

q_q = Counter()
d_q = Counter()

for prompt, actions in trajectories:
    prefix = ()
    for action in actions:
        state = (prompt, prefix)
        q_q[(state, action)] += 1
        d_q[state] += 1
        prefix = prefix + (action,)

mu_q = defaultdict(dict)
for (state, action), count in q_q.items():
    mu_q[state][action] = count / d_q[state]

print("d_Q:")
for state, count in d_q.items():
    print(state, count)

print("\nmu_Q:")
for state, dist in mu_q.items():
    print(state, dist)
python

输出里,状态 ("x1", ("A",)) 会同时看到动作 "B""C",因此它的 μQ(s)\mu_Q(\cdot\mid s) 不是 one-hot,而是一个经验条件分布。

形式化建模:SFT 是外部数据分布上的前向 KL 投影
https://jerry609.github.io/blog/sft-forward-kl-formal-modeling
Author Jerry
Published at June 6, 2026
Comment seems to stuck. Try to refresh?✨