How Do We Improve a Policy?

Suppose we have a policy and roll it out, recording the states, actions, and rewards encountered. Some actions worked out, increasing our return, and others didn't. To improve our policy, we want to make the "good" actions more likely in the future and the "bad" actions less likely. But how do we turn that intuition into a gradient we can backpropagate through the policy?

The Policy Gradient Theorem gives us the answer. For each action in our rollout, we multiply two things together: the direction in parameter space that makes the action more likely, and a score for how good the action was. Summing over all timesteps gives us the gradient of our expected return. Here is the equation:

\[\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t | s_t) \cdot A_t \right]\]

The first term, \(\nabla_\theta \log \pi_\theta(a_t | s_t)\), points in the direction that increases the probability of action \(a_t\). The second term, \(A_t\), is the advantage - how much better this action was compared to what the policy would normally do. Positive advantage means "do more of this", negative means "do less".
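To make the two terms concrete, here is a minimal sketch for a single timestep, assuming a hypothetical 3-action softmax policy parameterized directly by its logits (for a softmax, \(\nabla_{\text{logits}} \log \pi(a|s) = \text{onehot}(a) - \pi(\cdot|s)\), so no autodiff is needed):

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()

# Hypothetical toy policy: 3 actions, parameterized directly by logits.
logits = np.array([0.1, 0.5, -0.2])
probs = softmax(logits)

action = 1        # the action taken in the rollout
advantage = 2.0   # score for how good that action was

# grad of log pi(a|s) w.r.t. the logits, scaled by the advantage
onehot = np.eye(3)[action]
grad = (onehot - probs) * advantage

# One gradient-ascent step: a positive advantage makes the action more likely.
logits_new = logits + 0.1 * grad
```

With a positive advantage the step pushes probability towards the chosen action; with a negative one it pushes probability away.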

But how do we actually compute the advantage? By definition, it is the difference between the value of taking a specific action and the average value across all actions:

\[A(s_t, a_t) = Q(s_t, a_t) - V(s_t)\]

Here \(Q(s_t, a_t)\) is the expected return from taking action \(a_t\) in state \(s_t\), and \(V(s_t) = \mathbb{E}_{a \sim \pi}[Q(s_t, a)]\) is the expected return averaged over all actions the policy might take. We have a learned estimate of \(V\), but we need to estimate \(Q\) from the rewards we actually observed in our rollout. How we do that estimation is where things get interesting.
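A tiny numerical example of the definition, with made-up Q-values and action probabilities for one state:

```python
import numpy as np

# Hypothetical state with 3 actions: Q-values and the policy's probabilities.
q = np.array([1.0, 2.0, 4.0])
pi = np.array([0.2, 0.5, 0.3])

v = (pi * q).sum()   # V(s) = E_{a~pi}[Q(s, a)] = 2.4
advantages = q - v   # A(s, a) = Q(s, a) - V(s) = [-1.4, -0.4, 1.6]
```

Note that the policy-weighted average advantage is zero by construction: an action can only be "good" relative to what the policy would do on average.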

The Estimation Problem: Bias vs Variance

Computing the advantage requires comparing what rewards we actually received in the rollout (\(\sum_{k=0}^{\infty} \gamma^k r_{t+k}\)) against what rewards we expected to receive (\(V(s_t)\)). The quality of our advantage estimate depends entirely on how we estimate this expected value, and there are two classical approaches:

  • Monte Carlo uses only the rewards collected in the rollout as the Q-value target. This makes it correct on average (i.e., unbiased), but because each rollout's rewards are noisy (from the randomness in the policy and environment), different rollouts give very different returns, resulting in high variance.
  • Temporal Difference (TD) Error replaces the future rewards from step \(t+1\) onward with an estimate of their expected sum, \(\gamma V(s_{t+1})\). This compresses all that noise across rollouts into a stable, deterministic output, giving low variance. But if the learned \(V\) is inaccurate or systematically shifted, the estimate will be biased.
| Method | Formula | Bias | Variance |
| --- | --- | --- | --- |
| Monte Carlo | \(A_t = \sum_{k=0}^{\infty} \gamma^k r_{t+k} - V(s_t)\) | Unbiased | High |
| TD Error | \(A_t = \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) | Biased | Low |
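The variance gap is easy to see in simulation. The sketch below assumes, for simplicity, a state whose true value is 0 and i.i.d. unit-variance rewards, so both estimators are centred on the same target and only their spread differs:

```python
import numpy as np

rng = np.random.default_rng(0)
gamma = 0.99
true_v = 0.0  # assume a perfect value function: V(s) = 0 everywhere

mc_estimates = []
td_estimates = []
for _ in range(1000):
    # One 50-step rollout with noisy rewards (mean 0, std 1)
    rewards = rng.normal(0.0, 1.0, size=50)
    discounts = gamma ** np.arange(50)

    # Monte Carlo advantage: discounted sum of all sampled rewards
    mc_estimates.append((discounts * rewards).sum() - true_v)
    # 1-step TD error: only the first reward is sampled
    td_estimates.append(rewards[0] + gamma * true_v - true_v)
```

Across the 1000 rollouts, the Monte Carlo estimates have a variance around 30 (every noisy reward in the rollout contributes), while the TD errors have a variance around 1 (only one noisy reward contributes).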

GAE: Balancing the Tradeoff

Generalized Advantage Estimation (GAE), introduced by Schulman et al. (2016), resolves the trade-off between the variance and bias of Monte Carlo and the TD error. It is parameterized by \(\lambda \in [0, 1]\), which controls how much we lean towards TD (low variance, higher bias) or Monte Carlo (unbiased, higher variance):

GAE(\(\lambda\)) spans a spectrum between the two estimators:

  • \(\lambda = 0\): \(A_t = \delta_t\) - pure TD error (low variance, high bias).
  • \(\lambda = 1\): \(A_t = \sum_k \gamma^k r_{t+k} - V(s_t)\) - Monte Carlo-like (high variance, unbiased).
The key insight of GAE is that instead of using just the immediate 1-step TD error (biased but stable) or the full Monte Carlo return (unbiased but noisy), it takes a weighted average of all future TD errors. The parameter \(\lambda\) controls how quickly we discount future errors: with \(\lambda = 0\) we use only the immediate error, while with \(\lambda = 1\) the TD errors telescope and we recover exactly the Monte Carlo estimate.

GAE Formula

GAE takes an exponentially-weighted sum of TD errors \(\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\), with weights decaying by \(\gamma\lambda\) at each step:

\[A_t^{\text{GAE}(\gamma, \lambda)} = \sum_{k=0}^{\infty} (\gamma \lambda)^k \delta_{t+k}\]

In practice, we efficiently compute this backward across the trajectory, starting from the last timestep (where \(A_T = \delta_T\)) and working backwards:

\[A_t = \delta_t + \gamma \lambda \cdot A_{t+1}\]

Each advantage incorporates the TD error at that step plus a decayed version of the next step's advantage. This is exactly the backward computation shown in the visualisation below.
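A quick numerical check, on a made-up trajectory, that the backward recursion produces the same numbers as the exponentially-weighted sum it implements:

```python
import numpy as np

gamma, lam = 0.99, 0.95
rng = np.random.default_rng(1)
T = 8
rewards = rng.normal(size=T)
values = rng.normal(size=T + 1)  # includes the bootstrap value V(s_T)

# TD errors: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
deltas = rewards + gamma * values[1:] - values[:-1]

# Direct definition: A_t = sum_k (gamma * lam)^k * delta_{t+k}
direct = np.array([
    sum((gamma * lam) ** k * deltas[t + k] for k in range(T - t))
    for t in range(T)
])

# Backward recursion: A_t = delta_t + gamma * lam * A_{t+1}
adv = np.zeros(T)
gae = 0.0
for t in reversed(range(T)):
    gae = deltas[t] + gamma * lam * gae
    adv[t] = gae
```

The recursion turns an \(O(T^2)\) sum into a single \(O(T)\) pass.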

Interactive Rollout Visualisation

Watch the backward computation unfold step by step. Adjust \(\gamma\) and \(\lambda\) to see how they affect the advantage estimates.


Implementation

Here's how GAE is typically implemented in practice. The key insight is computing advantages in reverse order. We also need to handle episode boundaries correctly, since a rollout might span multiple episodes.

Episodes can end in two ways: termination (a true terminal state, e.g. death or reaching the goal) or truncation (a time limit was hit). These require different treatment:

  • On termination, there is no next state, so the bootstrap value \(V(s_{t+1})\) should be zero.
  • On truncation, the agent could have continued, so we still bootstrap with \(V(s_{t+1})\).

In both cases, we reset the backward pass at the episode boundary so that advantages from one episode don't bleed into another.

import numpy as np

def compute_gae(rewards, values, terminated, truncated, gamma, lambda_):
    # rewards, terminated, truncated: length-T arrays (boolean flags per step)
    # values: length T+1 (includes the bootstrap value for the final state;
    # on truncation this should be the value of the final observation)
    T = len(rewards)
    advantages = np.zeros(T)
    gae = 0.0

    # Terminated: no future value. Truncated: bootstrap as normal.
    next_values = np.where(terminated, 0.0, values[1:])
    deltas = rewards + gamma * next_values - values[:-1]

    # Reset at any episode boundary (terminated or truncated)
    episode_over = terminated | truncated

    for t in reversed(range(T)):
        # Zero the carried-over advantage when an episode ends at step t
        gae = deltas[t] + gamma * lambda_ * (~episode_over[t]) * gae
        advantages[t] = gae

    return advantages
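As a sanity check of the boundary handling, here is the function applied to a hypothetical four-step rollout whose first episode terminates at step 1 (the function body is repeated so the snippet runs standalone):

```python
import numpy as np

def compute_gae(rewards, values, terminated, truncated, gamma, lambda_):
    # Same function as above, repeated so this snippet is self-contained.
    T = len(rewards)
    advantages = np.zeros(T)
    gae = 0.0
    next_values = np.where(terminated, 0.0, values[1:])
    deltas = rewards + gamma * next_values - values[:-1]
    episode_over = terminated | truncated
    for t in reversed(range(T)):
        gae = deltas[t] + gamma * lambda_ * (~episode_over[t]) * gae
        advantages[t] = gae
    return advantages

# Made-up rollout: constant rewards and values, episode terminates at step 1.
rewards    = np.array([1.0, 1.0, 1.0, 1.0])
values     = np.array([0.5, 0.5, 0.5, 0.5, 0.5])  # length T+1
terminated = np.array([False, True, False, False])
truncated  = np.array([False, False, False, False])

adv = compute_gae(rewards, values, terminated, truncated,
                  gamma=0.99, lambda_=0.95)
```

At the terminal step the bootstrap is zeroed, so \(A_1 = r_1 - V(s_1) = 0.5\), and \(A_0 = \delta_0 + \gamma\lambda A_1\) uses only steps 0 and 1: nothing from the second episode bleeds across the boundary.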

Practical Considerations

Choosing \(\lambda\)

In practice, \(\lambda\) values between 0.9 and 0.99 work well for most problems. The PPO paper uses \(\lambda = 0.95\) as a default. Higher values are better when your value function is inaccurate; lower values help when you have a good value function and want to reduce variance.

Interaction with \(\gamma\)

The effective discount for the advantage weights is \(\gamma\lambda\). Even with \(\lambda = 1\), using \(\gamma < 1\) still provides some exponential decay. Common choices are \(\gamma = 0.99\) with \(\lambda = 0.95\), giving an effective decay of \(0.99 \times 0.95 = 0.9405\).
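A rough way to build intuition for \(\gamma\lambda\) is the effective horizon \(1/(1 - \gamma\lambda)\), the timescale over which TD errors still contribute meaningfully to the advantage (a standard geometric-decay heuristic, not from the GAE paper):

```python
# Effective decay and rough horizon for some common (gamma, lambda) choices.
for gamma, lam in [(0.99, 0.95), (0.99, 1.0), (0.999, 0.95)]:
    decay = gamma * lam
    horizon = 1.0 / (1.0 - decay)
    print(f"gamma={gamma}, lambda={lam}: "
          f"decay={decay:.4f}, horizon ~ {horizon:.1f} steps")
```

With the common \(\gamma = 0.99\), \(\lambda = 0.95\), TD errors are averaged over roughly 17 steps; dropping \(\lambda\) shortens that window and leans harder on the value function.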

Why GAE Works

GAE succeeds because it lets you control how much you trust your value function versus the raw sampled returns. A lower \(\lambda\) leans on \(V\) more (stable but potentially biased); a higher \(\lambda\) relies on actual rewards (unbiased but noisy). Tuning \(\lambda\) finds the right balance for your specific problem.

References

Schulman, J., Moritz, P., Levine, S., Jordan, M., & Abbeel, P. (2016). High-Dimensional Continuous Control Using Generalized Advantage Estimation. ICLR 2016.