Why Clip
在前两篇中,我们推导了:
- REINFORCE
- Baseline / Advantage
它们共享一个核心假设:
数据必须来自当前策略(on-policy)
这一篇我们换一个角度:
能不能复用旧数据?
On-policy 的限制
回顾之前的形式:
\[\nabla_{\theta}\mathcal{L}(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta}} \sum_{t=0}^{T-1} \left[ \hat{A}_t\ \nabla_{\theta} \log \pi_{\theta}(a_t \mid s_t) \right]\]其中:
\[\hat{A}_t = R(\tau_{\ge t}) - b(s_t)\]问题是:
- 数据必须来自当前 policy
- 一旦更新,数据立刻“过期”
- ❌ 每次更新都要重新采样
- ❌ 样本利用率低
- ❌ 在真实环境中成本很高
能不能用旧数据?
假设我们有一批轨迹:
\[\tau \sim \pi_{\text{old}}\]但我们想优化的是:
\[\pi_\theta\]用 importance sampling 改写
我们可以写:
\[\mathbb{E}_{\tau \sim \pi_\theta} = \mathbb{E}_{\tau \sim \pi_{\text{old}}} \left[ \frac{\pi_\theta(\tau)}{\pi_{\text{old}}(\tau)} \right]\]由于 trajectory 概率可以分解:
\[\frac{\pi_\theta(\tau)}{\pi_{\text{old}}(\tau)} = \prod_{t=0}^{T-1} \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\text{old}}(a_t \mid s_t)}\]为了降低方差,我们通常使用 逐步(per-step)近似:
定义:
\[r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\text{old}}(a_t \mid s_t)}\]得到新的目标
\[\mathcal{L}(\theta) = \mathbb{E}_{\tau \sim \pi_{\text{old}}} \sum_{t=0}^{T-1} \left[ r_t(\theta)\ \hat{A}_t \right]\]但是 ratio 会爆炸,方差太大了
- 如果新策略更偏向某个 action \(r_t > 1 \\)
- 如果不偏向 \(r_t < 1 \\)
这会导致:
- ❌ 梯度极不稳定
- ❌ 方差巨大
- ❌ policy collapse
PPO
不要完全相信这个 ratio,把它限制住
公式
\[\mathcal{L}^{\text{CLIP}}(\theta) = \mathbb{E}_{\tau \sim \pi_{\text{old}}} \sum_{t=0}^{T-1} \left[ \min\left( r_t(\theta)\hat{A}_t,\ \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t \right) \right]\]我们分两种情况看。
情况 1:(\hat{A}_t > 0)(好动作)
我们希望:
- 增大该动作概率
- 即 ( r_t > 1 )
但:
- 如果 ( r_t \gg 1 ),更新会过于激进
- PPO 允许增加,但最多到 (1 + \epsilon)
- 超过之后不再继续增大
情况 2:(\hat{A}_t < 0)(坏动作)
我们希望:
- 降低该动作概率
- 即 ( r_t < 1 )
但:
- 如果 ( r_t \ll 1 ),下降过猛
- PPO 允许下降,但最低到 (1 - \epsilon)
- 再低就停止惩罚
直觉总结
PPO 限制了每一步 policy 更新的幅度 PPO 不是在修正梯度
而是在控制“更新步长”
虽然我们使用了旧数据:
- 只使用最近一批数据
- 只更新有限次(multiple epochs)
所以:
PPO 更像是一个 带有 off-policy 形式的 on-policy 方法
实际使用的目标函数
工程上我们使用:
\[\mathcal{L}(\theta) = \mathcal{L}^{\text{CLIP}}(\theta) - c_1 \mathcal{L}^{\text{value}} + c_2 \mathcal{H}(\pi_\theta)\]其中:
- value loss:拟合 (V(s_t))
- entropy:鼓励探索
代码(PyTorch)
import torch
def ppo_loss(old_log_prob, new_log_prob, values, returns, eps, mask):
"""
old_log_prob: [batch_size, seq_len]
旧策略下,已采样动作的 log π_old(a_t | s_t)
new_log_prob: [batch_size, seq_len]
新策略下,已采样动作的 log π_θ(a_t | s_t)
values: [batch_size, seq_len]
critic 输出的 V_φ(s_t)
returns: [batch_size, seq_len]
由采样轨迹计算得到的 G_t
eps: float
PPO clipping 系数,通常取 0.1 / 0.2
mask: [batch_size, seq_len]
有效位置 mask
"""
advantage = returns - values
advantage_detached = advantage.detach()
ratio = torch.exp(new_log_prob - old_log_prob)
surr1 = ratio * advantage_detached
surr2 = torch.clamp(ratio, 1 - eps, 1 + eps) * advantage_detached
actor_loss = -(torch.min(surr1, surr2) * mask).sum() / mask.sum()
critic_loss = (((values - returns.detach()) ** 2) * mask).sum() / mask.sum()
return actor_loss, critic_loss