形式化建模:RL 是当前策略分布上的 reward-tilting 与 KL 正则投影#
本文沿用前文的有限时域自回归建模记号:状态 s=(x,a<t) 表示“prompt + 当前前缀”,动作 a 表示 token,dπ(s) 表示策略 π 的状态占用测度。这里重点讨论 RL 目标在这些记号下到底约束了什么分布。
4. RL:当前策略分布上的 reward-tilting 与 KL 正则投影#
RL 的形式很多。为了得到可严格证明的分布关系,本文采用两种互补的表述:
- 轨迹层面的 KL 正则化 reward 最大化;
- 状态层面的 KL 正则化局部策略改进。
这两种写法都能得到精确的“reward-tilted target distribution”结论。PPO、GRPO 等工程算法可视为对这类局部 / 全局正则目标的数值近似,而不是完全不同的数学对象。
4.1 策略梯度恒等式的逐步推导#
先回顾最基础的策略梯度。设轨迹总 reward 为
R(τ)=t=1∑Hr(st,at).
定义目标
J(θ):=Eτ∼Pπθ[R(τ)].
定理 4.1(策略梯度公式). 有
∇θJ(θ)=Eτ∼Pπθ[R(τ)t=1∑H∇θlogπθ(at∣st)].
进一步,若 Qπθ(s,a) 表示在状态 s 采取动作 a 的回报期望,则
∇θJ(θ)=s∑dπθ(s)a∑πθ(a∣s)Qπθ(s,a)∇θlogπθ(a∣s).
若再减去任意 baseline b(s),则
∇θJ(θ)=s∑dπθ(s)a∑πθ(a∣s)Aπθ(s,a)∇θlogπθ(a∣s),
其中
Aπθ(s,a)=Qπθ(s,a)−b(s).
证明. 第一步,用 score-function trick:
∇θJ(θ)=∇θτ∑Pπθ(τ)R(τ)=τ∑∇θPπθ(τ)R(τ)=τ∑Pπθ(τ)∇θlogPπθ(τ)R(τ).
而
Pπθ(τ)=ρ(x)t=1∏Hπθ(at∣st),
于是
logPπθ(τ)=logρ(x)+t=1∑Hlogπθ(at∣st).
因为 ρ 不依赖 θ,
∇θlogPπθ(τ)=t=1∑H∇θlogπθ(at∣st).
代回即得
∇θJ(θ)=Eτ∼Pπθ[R(τ)t=1∑H∇θlogπθ(at∣st)].
第二步,把 R(τ) 替换为从当前步往后的 return Gt,得到
∇θJ(θ)=E[t=1∑HGt∇θlogπθ(at∣st)].
对 (st,at) 条件化:
∇θJ(θ)=t=1∑Hs,a∑πθPr(st=s,at=a)E[Gt∣st=s,at=a]∇θlogπθ(a∣s)=t=1∑Hs,a∑πθPr(st=s)πθ(a∣s)Qπθ(s,a)∇θlogπθ(a∣s).
对 t 求和即得
∇θJ(θ)=s∑dπθ(s)a∑πθ(a∣s)Qπθ(s,a)∇θlogπθ(a∣s).
第三步,说明 baseline 不改变梯度:
a∑πθ(a∣s)b(s)∇θlogπθ(a∣s)=b(s)a∑πθ(a∣s)πθ(a∣s)∇θπθ(a∣s)=b(s)a∑∇θπθ(a∣s)=b(s)∇θa∑πθ(a∣s)=b(s)∇θ1=0.
因此可以把 Qπθ(s,a) 替换为 Aπθ(s,a)=Qπθ(s,a)−b(s)。
注记 4.2. 这里最重要的一点不是公式本身,而是梯度中的状态权重是 dπθ,也就是当前策略真正访问到的状态分布。这正是 RL 的 on-policy 性质。
4.2 轨迹层面的 KL 正则化 RL:reward-tilted Gibbs 分布#
定义 4.3(参考条件轨迹分布). 给定参考策略 πref,定义在每个 prompt x 下的参考条件轨迹分布
Pref(y∣x)=t=1∏Hπref(at∣st),
其中 y=(a1,…,aH)。
定义 4.4(轨迹层面的 KL 正则化 RL 目标). 设 R(x,y) 是每个 prompt 和完整回答的序列级 reward,β>0 为正则强度。定义
Jtraj(P):=x∈X∑ρ(x)y∈Y∑P(y∣x)R(x,y)−βx∈X∑ρ(x)DKL(P(⋅∣x)∥Pref(⋅∣x)),
其中优化变量是条件分布 P(⋅∣x)∈Δ(Y)。
定理 4.5(轨迹层面的 Gibbs 最优解). 对每个 prompt x,最优条件分布为
P⋆(y∣x)=Zβ(x)Pref(y∣x)exp(R(x,y)/β),
其中
Zβ(x):=y′∑Pref(y′∣x)exp(R(x,y′)/β).
并且有精确恒等式
Jtraj(P)=βx∑ρ(x)logZβ(x)−βx∑ρ(x)DKL(P(⋅∣x)∥P⋆(⋅∣x)).
证明. 对每个固定的 x,问题可分解为
P(⋅∣x)∈Δ(Y)max{y∑P(y∣x)R(x,y)−βy∑P(y∣x)logPref(y∣x)P(y∣x)}.
加入归一化约束 ∑yP(y∣x)=1 的 Lagrange 乘子 λx:
Lx=y∑P(y∣x)R(x,y)−βy∑P(y∣x)logPref(y∣x)P(y∣x)+λx(y∑P(y∣x)−1).
对每个 y 求偏导并令其为零:
0=∂P(y∣x)∂Lx=R(x,y)−β(logPref(y∣x)P(y∣x)+1)+λx.
整理得
logPref(y∣x)P(y∣x)=βR(x,y)+λx−β.
指数化后
P(y∣x)=Pref(y∣x)exp(βR(x,y))⋅exp(βλx−β).
后面的项与 y 无关,因此是归一化常数,记为 1/Zβ(x),得到
P⋆(y∣x)=Zβ(x)Pref(y∣x)eR(x,y)/β.
下面证明第二个恒等式。由 P⋆ 的定义,
logP⋆(y∣x)=logPref(y∣x)+βR(x,y)−logZβ(x).
因此
DKL(P(⋅∣x)∥P⋆(⋅∣x))=y∑P(y∣x)logP⋆(y∣x)P(y∣x)=y∑P(y∣x)[logP(y∣x)−logPref(y∣x)−βR(x,y)+logZβ(x)]=DKL(P(⋅∣x)∥Pref(⋅∣x))−β1y∑P(y∣x)R(x,y)+logZβ(x).
移项后得到
y∑P(y∣x)R(x,y)−βDKL(P(⋅∣x)∥Pref(⋅∣x))=βlogZβ(x)−βDKL(P(⋅∣x)∥P⋆(⋅∣x)).
最后乘上 ρ(x) 并对 x 求和即得结论。
推论 4.6(RL 在轨迹层面是对 reward-tilted target 的反向 KL 投影). 定理 4.5 说明:带 KL 正则的轨迹级 RL,不是“凭空把概率推到高 reward”,而是把参考轨迹分布 Pref 经过 eR/β 做 Gibbs 变换,再对这个新分布做反向 KL 投影。
4.3 状态层面的 KL 正则局部策略改进:Boltzmann teacher#
轨迹公式很干净,但工程里更常见的是当前策略 πk 的局部改进步。这时我们用 advantage 写局部目标。
定义 4.7(KL 正则的局部策略改进目标). 给定当前策略 πk 及其 advantage 函数 Ak(s,a),定义
Ik(π):=s∈S∑dπk(s)[a∈A∑π(a∣s)Ak(s,a)−βDKL(π(⋅∣s)∥πk(⋅∣s))].
定义 4.8(Boltzmann 改进算子). 对每个状态 s,定义
Bβ[πk,Ak](a∣s):=∑b∈Aπk(b∣s)exp(Ak(s,b)/β)πk(a∣s)exp(Ak(s,a)/β).
为了简记,后文记
πB,k(⋅∣s):=Bβ[πk,Ak](⋅∣s).
定理 4.9(KL 正则局部 RL 等价于对 Boltzmann teacher 的反向 KL 投影). 对任意固定状态 s,定义
Zk(s):=b∈A∑πk(b∣s)exp(Ak(s,b)/β).
则有精确恒等式
a∑π(a∣s)Ak(s,a)−βDKL(π(⋅∣s)∥πk(⋅∣s))=βlogZk(s)−βDKL(π(⋅∣s)∥πB,k(⋅∣s)).
因此
Ik(π)=s∑dπk(s)βlogZk(s)−βs∑dπk(s)DKL(π(⋅∣s)∥πB,k(⋅∣s)).
故其最优解满足
argπmaxIk(π)={π:π(⋅∣s)=πB,k(⋅∣s),∀s∈supp(dπk)}.
证明. 固定状态 s,先计算
DKL(π(⋅∣s)∥πB,k(⋅∣s))=a∑π(a∣s)logπB,k(a∣s)π(a∣s)=a∑π(a∣s)logπk(a∣s)eAk(s,a)/β/Zk(s)π(a∣s)=a∑π(a∣s)logπk(a∣s)π(a∣s)−β1a∑π(a∣s)Ak(s,a)+logZk(s)=DKL(π(⋅∣s)∥πk(⋅∣s))−β1a∑π(a∣s)Ak(s,a)+logZk(s).
移项可得
a∑π(a∣s)Ak(s,a)−βDKL(π∥πk)=βlogZk(s)−βDKL(π∥πB,k).
对所有状态乘上 dπk(s) 再求和即得第二式。
最后,由 KL 非负性知,每个状态的最优值在且仅在 π(⋅∣s)=πB,k(⋅∣s) 时取得。
推论 4.10(带 KL 正则的 RL 是 on-policy 的 reward-shaped self-distillation). 局部 RL 改进可以严格理解为:在当前策略自己的状态分布 dπk 上,把策略投影到一个由 πk 与 Ak 共同诱导出的 Boltzmann teacher πB,k。
注记 4.11. 这条结论非常关键。它表明“RL 与蒸馏完全无关”并不准确。在 KL 正则视角下,RL 的确可以写成一种 teacher matching,只是这个 teacher 不是外部模型,而是由当前策略和 reward / advantage 共同定义出来的目标分布。
4.4 本节具体性质#
命题 4.12(Boltzmann teacher 的赔率比性质). 对任意固定状态 s 和任意两个动作 a,b,若 πk(a∣s),πk(b∣s)>0,则
πB,k(b∣s)πB,k(a∣s)=πk(b∣s)πk(a∣s)exp(βAk(s,a)−Ak(s,b)).
也即 Boltzmann 改进在 log-odds 上等于“原 log-odds + advantage 差 / 温度”。
证明. 由定义
πB,k(a∣s)=Zk(s)πk(a∣s)eAk(s,a)/β,πB,k(b∣s)=Zk(s)πk(b∣s)eAk(s,b)/β.
两式相除即可:
πB,k(b∣s)πB,k(a∣s)=πk(b∣s)πk(a∣s)e(Ak(s,a)−Ak(s,b))/β.
命题 4.13(Boltzmann teacher 的温度极限). 固定状态 s,记优势最大动作集合
Ms:=arga∈AmaxAk(s,a).
则有
β→∞limπB,k(⋅∣s)=πk(⋅∣s),
以及对任意动作 a,
β→0+limπB,k(a∣s)=∑b∈Msπk(b∣s)πk(a∣s)1{a∈Ms}.
证明. 当 β→∞ 时,对每个动作都有
eAk(s,a)/β→1,
故
πB,k(a∣s)=∑bπk(b∣s)eAk(s,b)/βπk(a∣s)eAk(s,a)/β→∑bπk(b∣s)πk(a∣s)=πk(a∣s).
当 β→0+ 时,设
Amax(s)=bmaxAk(s,b).
把分子分母同时除以 eAmax(s)/β:
πB,k(a∣s)=∑bπk(b∣s)e(Ak(s,b)−Amax(s))/βπk(a∣s)e(Ak(s,a)−Amax(s))/β.
若 a∈/Ms,则指数项趋于 0;若 a∈Ms,则指数项趋于 1。于是极限正是题述结果。
推论 4.14(精确改进差距等于加权反向 KL). 在定理 4.9 的条件下,
Ik(πB,k)−Ik(π)=βs∈S∑dπk(s)DKL(π(⋅∣s)∥πB,k(⋅∣s)).
证明. 把定理 4.9 中的恒等式分别代入 π=πB,k 与一般的 π。由于
DKL(πB,k∥πB,k)=0,
相减即可得到上述等式。
注记 4.15. 这三条具体性质说明:RL 的 teacher 并不是抽象地“偏向高 reward”,而是以可计算的赔率比、温度极限、以及精确 improvement gap 形式出现。
附录 A:最小代码验证#
A.1 验证 Boltzmann teacher 与改进差距#
下面的例子验证定理 4.9 和推论 4.14:KL 正则局部 RL 目标等于对 Boltzmann teacher 的反向 KL 投影,且改进差距正好是加权反向 KL。
import torch
torch.manual_seed(0)
num_states = 3
num_actions = 5
beta = 0.7
pi_k = torch.softmax(torch.randn(num_states, num_actions), dim=-1)
advantage = torch.randn(num_states, num_actions)
d_pi_k = torch.tensor([3.0, 2.0, 0.5])
candidate = torch.softmax(torch.randn(num_states, num_actions), dim=-1)
weights = pi_k * torch.exp(advantage / beta)
z_k = weights.sum(dim=-1)
pi_b = weights / z_k[:, None]
def kl(p, q):
return (p * (torch.log(p) - torch.log(q))).sum(dim=-1)
def local_objective(pi):
reward_term = (pi * advantage).sum(dim=-1)
regularizer = beta * kl(pi, pi_k)
return (d_pi_k * (reward_term - regularizer)).sum()
lhs = local_objective(candidate)
rhs = (d_pi_k * (beta * torch.log(z_k) - beta * kl(candidate, pi_b))).sum()
teacher_value = local_objective(pi_b)
gap = teacher_value - lhs
expected_gap = beta * (d_pi_k * kl(candidate, pi_b)).sum()
print(torch.allclose(lhs, rhs, atol=1e-6))
print(torch.allclose(gap, expected_gap, atol=1e-6))
print(pi_b)
python
这里的 pi_b 就是由当前策略 pi_k 和 advantage 共同诱导出的 Boltzmann teacher。