在前两篇中,我们推导了:

  • REINFORCE
  • Baseline / Advantage

它们共享一个核心假设:

数据必须来自当前策略(on-policy)

这一篇我们换一个角度:

能不能复用旧数据?


On-policy 的限制

回顾之前的形式:

\[\nabla_{\theta}\mathcal{L}(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta}} \sum_{t=0}^{T-1} \left[ \hat{A}_t\ \nabla_{\theta} \log \pi_{\theta}(a_t \mid s_t) \right]\]

其中:

\[\hat{A}_t = R(\tau_{\ge t}) - b(s_t)\]

问题是:

  • 数据必须来自当前 policy
  • 一旦更新,数据立刻“过期”
  • ❌ 每次更新都要重新采样
  • ❌ 样本利用率低
  • ❌ 在真实环境中成本很高

能不能用旧数据?

假设我们有一批轨迹:

\[\tau \sim \pi_{\text{old}}\]

但我们想优化的是:

\[\pi_\theta\]

用 importance sampling 改写

我们可以写:

\[\mathbb{E}_{\tau \sim \pi_\theta} = \mathbb{E}_{\tau \sim \pi_{\text{old}}} \left[ \frac{\pi_\theta(\tau)}{\pi_{\text{old}}(\tau)} \right]\]

由于 trajectory 概率可以分解:

\[\frac{\pi_\theta(\tau)}{\pi_{\text{old}}(\tau)} = \prod_{t=0}^{T-1} \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\text{old}}(a_t \mid s_t)}\]

为了降低方差,我们通常使用 逐步(per-step)近似

定义:

\[r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\text{old}}(a_t \mid s_t)}\]

得到新的目标

\[\mathcal{L}(\theta) = \mathbb{E}_{\tau \sim \pi_{\text{old}}} \sum_{t=0}^{T-1} \left[ r_t(\theta)\ \hat{A}_t \right]\]

但是 ratio 会爆炸,方差太大了

  • 如果新策略更偏向某个 action \(r_t > 1 \\)
  • 如果不偏向 \(r_t < 1 \\)

这会导致:

  • ❌ 梯度极不稳定
  • ❌ 方差巨大
  • ❌ policy collapse

PPO

不要完全相信这个 ratio,把它限制住

公式

\[\mathcal{L}^{\text{CLIP}}(\theta) = \mathbb{E}_{\tau \sim \pi_{\text{old}}} \sum_{t=0}^{T-1} \left[ \min\left( r_t(\theta)\hat{A}_t,\ \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t \right) \right]\]

我们分两种情况看。


情况 1:(\hat{A}_t > 0)(好动作)

我们希望:

  • 增大该动作概率
  • 即 ( r_t > 1 )

但:

  • 如果 ( r_t \gg 1 ),更新会过于激进
  • PPO 允许增加,但最多到 (1 + \epsilon)
  • 超过之后不再继续增大

情况 2:(\hat{A}_t < 0)(坏动作)

我们希望:

  • 降低该动作概率
  • 即 ( r_t < 1 )

但:

  • 如果 ( r_t \ll 1 ),下降过猛
  • PPO 允许下降,但最低到 (1 - \epsilon)
  • 再低就停止惩罚

直觉总结

PPO 限制了每一步 policy 更新的幅度 PPO 不是在修正梯度
而是在控制“更新步长”

虽然我们使用了旧数据:

  • 只使用最近一批数据
  • 只更新有限次(multiple epochs)

所以:

PPO 更像是一个 带有 off-policy 形式的 on-policy 方法


实际使用的目标函数

工程上我们使用:

\[\mathcal{L}(\theta) = \mathcal{L}^{\text{CLIP}}(\theta) - c_1 \mathcal{L}^{\text{value}} + c_2 \mathcal{H}(\pi_\theta)\]

其中:

  • value loss:拟合 (V(s_t))
  • entropy:鼓励探索

代码(PyTorch)

import torch


def ppo_loss(old_log_prob, new_log_prob, values, returns, eps, mask):
    """
    old_log_prob: [batch_size, seq_len]
        旧策略下,已采样动作的 log π_old(a_t | s_t)

    new_log_prob: [batch_size, seq_len]
        新策略下,已采样动作的 log π_θ(a_t | s_t)

    values: [batch_size, seq_len]
        critic 输出的 V_φ(s_t)

    returns: [batch_size, seq_len]
        由采样轨迹计算得到的 G_t

    eps: float
        PPO clipping 系数,通常取 0.1 / 0.2

    mask: [batch_size, seq_len]
        有效位置 mask
    """
    advantage = returns - values
    advantage_detached = advantage.detach()

    ratio = torch.exp(new_log_prob - old_log_prob)

    surr1 = ratio * advantage_detached
    surr2 = torch.clamp(ratio, 1 - eps, 1 + eps) * advantage_detached

    actor_loss = -(torch.min(surr1, surr2) * mask).sum() / mask.sum()
    critic_loss = (((values - returns.detach()) ** 2) * mask).sum() / mask.sum()

    return actor_loss, critic_loss