Discrete-adjoint differentiation

Overview

vmec_jax computes exact derivatives of any differentiable objective (quasisymmetry residuals, aspect ratio, etc.) with respect to boundary shape parameters using a discrete-adjoint technique.

In contrast to the classical finite-difference approach used by SIMSOPT + VMEC2000, discrete-adjoint differentiation:

  • is exact — derivatives are computed to floating-point precision, not truncated by step-size selection;

  • requires only one checkpoint replay rather than one full forward solve per parameter;

  • scales, in reverse (VJP) mode, with the number of output quantities rather than the number of input parameters, which is ideal when the parameter space is large.
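As a minimal illustration of the accuracy point (a toy function, not vmec_jax code), the sketch below compares a forward finite difference with JAX automatic differentiation on \(\sin x\). The function f and the step sizes are illustrative assumptions.

```python
import math

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # use double precision


def f(x):
    # Toy objective standing in for an expensive equilibrium functional.
    return jnp.sin(x)


x0 = 1.0
exact = math.cos(x0)

# Automatic differentiation: no step size to choose.
ad_error = abs(float(jax.grad(f)(x0)) - exact)


def fd_error(h):
    # Forward finite difference: truncation error grows with h,
    # round-off error grows as h shrinks.
    return abs(float((f(x0 + h) - f(x0)) / h) - exact)


fd_errors = {h: fd_error(h) for h in (1e-2, 1e-4, 1e-8, 1e-12)}
```

Printing fd_errors next to ad_error shows the finite-difference error first shrinking and then growing again as h decreases, while the AD error stays at round-off level.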

Background: what are we differentiating?

The VMEC iteration produces a converged equilibrium state \(\boldsymbol{q}^* \in \mathbb{R}^N\) (packed array of R, Z, λ Fourier coefficients on the VMEC half/full mesh) by applying the iterative map

\[\boldsymbol{q}_{k+1} = \Phi(\boldsymbol{q}_k,\, p)\]

from an initial guess \(\boldsymbol{q}_0(p)\) that depends on the boundary parameters \(p \in \mathbb{R}^m\). At convergence, \(\boldsymbol{q}^*(p)\) satisfies the fixed-point equation

\[\boldsymbol{q}^*(p) = \Phi(\boldsymbol{q}^*(p),\, p).\]

A scalar objective \(f = \ell(\boldsymbol{q}^*(p), p)\) is then evaluated from the converged state. We want \(\partial f / \partial p_i\).
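To make the notation concrete, here is a small sketch of the same structure on a toy problem. The contractive map phi, the initial guess q0, the objective ell, and the matrix B are all hypothetical stand-ins chosen for illustration; they are not part of vmec_jax.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # use double precision

# Hypothetical coupling between m = 2 parameters and N = 3 state entries.
B = jnp.array([[1.0, 0.0], [0.5, -0.3], [0.2, 0.7]])


def phi(q, p):
    # Toy iterative map; a contraction in q with rate at most 0.5.
    return 0.5 * jnp.tanh(q) + B @ p


def q0(p):
    # Initial guess that depends on the parameters.
    return B @ p


def solve(p, n_iter=80):
    # Repeated application of the map until (toy) convergence.
    q = q0(p)
    for _ in range(n_iter):
        q = phi(q, p)
    return q


def ell(q, p):
    # Scalar objective evaluated on the converged state.
    return jnp.sum(q**2) + jnp.sum(p)


p = jnp.array([0.3, -0.2])
q_star = solve(p)
fixed_point_residual = jnp.linalg.norm(q_star - phi(q_star, p))
f = ell(q_star, p)
```

The quantity fixed_point_residual measures how well the fixed-point equation is satisfied; the derivative of f with respect to p is what the rest of this page is about.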

The challenge: the solver applies \(\Phi\) inside a Fortran-style scan loop of ~5 000 composite iterations. Naive reverse-mode automatic differentiation through 5 000 unrolled iterations would require storing the full trajectory in memory and would produce an enormous computation graph.

The discrete-adjoint approach

vmec_jax uses a two-pass strategy, analogous to the adjoint method in optimal control:

Forward pass (checkpoint tape)

The forward solve runs normally, but at every N_checkpoint-th iteration a snapshot of the VMEC state is stored to a compact in-memory checkpoint tape. Between checkpoints, the state is not stored; it is recomputed on demand during the backward pass. This trades extra recomputation for a smaller memory footprint.

Forward pass
────────────────────────────────────────────
q₀(p) ──→ q₁ ──→ … ──→ q_{c₁} ┐ checkpoint
                                └──→ … ──→ q_{c₂} ┐ checkpoint
                                                    └──→ … ──→ q*(p)

The checkpoints {q_{c₁}, q_{c₂}, ...} are stored.
All other iterates are discarded.
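The checkpointing idea can be sketched on a toy map as follows. This is not the vmec_jax tape format: the dictionary-based tape, the function names, and the choice to also store the initial state under key 0 are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # use double precision

B = jnp.array([[1.0, 0.0], [0.5, -0.3], [0.2, 0.7]])


def phi(q, p):
    # Hypothetical contractive map standing in for one VMEC iteration.
    return 0.5 * jnp.tanh(q) + B @ p


def forward_with_checkpoints(q_init, p, n_iter, every):
    # Store the state only at iterations that are multiples of `every`.
    tape = {0: q_init}
    q = q_init
    for k in range(1, n_iter + 1):
        q = phi(q, p)
        if k % every == 0:
            tape[k] = q
    return q, tape


def recompute_iterate(tape, p, k, every):
    # Rebuild iterate k from the closest earlier checkpoint.
    base = (k // every) * every
    q = tape[base]
    for _ in range(k - base):
        q = phi(q, p)
    return q


p = jnp.array([0.3, -0.2])
q_final, tape = forward_with_checkpoints(B @ p, p, n_iter=50, every=10)
q_37 = recompute_iterate(tape, p, 37, every=10)
```

Here 50 iterations leave only six stored states, and iterate 37 is recovered by replaying seven steps from the checkpoint at iteration 30.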

Tangent propagation (JVP replay)

For each boundary parameter \(p_i\), the tangent vector \(\partial \boldsymbol{q}_0 / \partial p_i\) is propagated forward through the tape using Jacobian-vector products (JVPs):

\[\frac{\partial \boldsymbol{q}_{k+1}}{\partial p_i} = \frac{\partial \Phi}{\partial \boldsymbol{q}_k} \cdot \frac{\partial \boldsymbol{q}_k}{\partial p_i} + \frac{\partial \Phi}{\partial p_i}\]

Because JAX traces the iterative map \(\Phi\) as a JAX program, JVPs are available via jax.jvp with no extra code. All \(m\) tangents are propagated simultaneously by wrapping the jax.jvp call in jax.vmap, giving a single batched JVP that visits each checkpoint interval exactly once.

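This gives the full batch of Jacobian columns \(\partial \boldsymbol{q}^* / \partial p_i\) for all \(i\). The arithmetic amounts to \(O(m)\) JVPs per iteration, but because they are batched into a single replay the wall-clock cost is roughly equivalent to 1–2 forward solves, provided the batch of \(m\) tangents vectorizes well on the hardware.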
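The tangent recursion can be sketched on a toy map with jax.jvp and jax.vmap. The map phi, the initial guess q0, and the helper names are hypothetical; the real replay runs over the checkpoint tape rather than a plain Python loop. The result is checked against the implicit-function-theorem formula \((I - \partial\Phi/\partial q)^{-1}\,\partial\Phi/\partial p\) evaluated at the converged state.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # use double precision

B = jnp.array([[1.0, 0.0], [0.5, -0.3], [0.2, 0.7]])


def phi(q, p):
    # Hypothetical contractive map standing in for one VMEC iteration.
    return 0.5 * jnp.tanh(q) + B @ p


def q0(p):
    return B @ p


def step_with_tangent(q, dq, p, dp):
    # Returns (phi(q, p), dPhi/dq @ dq + dPhi/dp @ dp).
    return jax.jvp(phi, (q, p), (dq, dp))


# Batch over tangents only; the primal state is shared by all columns.
batched_step = jax.vmap(
    step_with_tangent, in_axes=(None, 0, None, 0), out_axes=(None, 0)
)


def state_jacobian(p, n_iter=80):
    dP = jnp.eye(p.shape[0])  # one unit tangent per parameter
    q = q0(p)
    # Tangents of the initial guess, one row per parameter.
    dQ = jax.vmap(lambda dp: jax.jvp(q0, (p,), (dp,))[1])(dP)
    for _ in range(n_iter):
        q, dQ = batched_step(q, dQ, p, dP)
    return q, dQ.T  # Jacobian with shape (N, m)


p = jnp.array([0.3, -0.2])
q_star, dq_dp = state_jacobian(p)

# Reference from the implicit function theorem at the fixed point.
Jq = jax.jacobian(phi, argnums=0)(q_star, p)
Jp = jax.jacobian(phi, argnums=1)(q_star, p)
reference = jnp.linalg.solve(jnp.eye(3) - Jq, Jp)
```

For this contractive toy map the propagated tangents converge to the same matrix as the reference solve, which is a convenient sanity check when experimenting with replay code.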

Objective linearization

Finally, the Jacobian of the objective with respect to the final state is applied:

\[\frac{\partial f}{\partial p_i} = \frac{\partial \ell}{\partial \boldsymbol{q}^*} \cdot \frac{\partial \boldsymbol{q}^*}{\partial p_i} + \frac{\partial \ell}{\partial p_i}\]

using one more jax.jvp call on the objective function. For a vector of residuals \(r\), the same formula is applied with \(r\) in place of \(\ell\).

The result is the exact (machine-precision) dense Jacobian matrix \(J \in \mathbb{R}^{n_r \times m}\) where \(n_r\) is the number of residuals and \(m\) is the number of boundary DOFs.
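The chain-rule step can be sketched as follows. The residuals function, the state q_star, and the state Jacobian dq_dp are placeholder values chosen for illustration (in practice dq_dp would come from the tangent replay); none of these names are vmec_jax API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # use double precision


def residuals(q, p):
    # Hypothetical residual vector with n_r = 3 entries.
    return jnp.array([jnp.sum(q**2) - 1.0, q[0] * p[1], jnp.sin(q[2]) + p[0]])


q_star = jnp.array([0.4, -0.1, 0.7])  # placeholder converged state (N = 3)
p = jnp.array([0.3, -0.2])            # placeholder parameters (m = 2)
dq_dp = jnp.array([[1.0, 0.2], [0.5, -0.3], [0.1, 0.9]])  # placeholder (N, m)


def jacobian_column(dq, dp):
    # One JVP gives dr/dq @ dq + dr/dp @ dp for a single parameter direction.
    return jax.jvp(residuals, (q_star, p), (dq, dp))[1]


# Map over the columns of dq_dp and the matching unit parameter tangents.
J = jax.vmap(jacobian_column, in_axes=(1, 0), out_axes=1)(dq_dp, jnp.eye(2))

# Dense reference built from explicit partial Jacobians.
J_reference = (
    jax.jacobian(residuals, argnums=0)(q_star, p) @ dq_dp
    + jax.jacobian(residuals, argnums=1)(q_star, p)
)
```

The JVP form never builds \(\partial r / \partial \boldsymbol{q}^*\) explicitly, which matters when the state dimension is large.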

Implementation in vmec_jax

The key functions live in vmec_jax/discrete_adjoint.py:

===================================  ====
Function                             Role
===================================  ====
build_residual_checkpoint_tape()     Run forward solve; store checkpoints. Returns a ResidualCheckpointTape.
checkpoint_tape_state_jvp_columns()  Given the tape and a batch of parameter tangents, propagate all tangents through the tape. Returns \(\partial q^* / \partial p\) (columns stacked).
checkpoint_tape_param_jvp()          Single-parameter JVP (used internally).
checkpoint_tape_state_vjp()          Reverse-mode (VJP) for scalar loss functions; cheaper than the forward-mode columns when \(n_r = 1\) (e.g., single scalar objective).
===================================  ====
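For the reverse-mode path, a generic adjoint sweep over a toy iteration looks roughly like the sketch below. It does not use or reproduce checkpoint_tape_state_vjp(); the names are hypothetical, and for brevity the whole trajectory is kept in a list instead of being recomputed from checkpoints.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # use double precision

B = jnp.array([[1.0, 0.0], [0.5, -0.3], [0.2, 0.7]])


def phi(q, p):
    # Hypothetical contractive map standing in for one VMEC iteration.
    return 0.5 * jnp.tanh(q) + B @ p


def q0(p):
    return B @ p


def ell(q, p):
    # Scalar loss on the final state.
    return jnp.sum(q**2) + jnp.sum(p)


def grad_by_reverse_replay(p, n_iter=40):
    # Forward pass (a real tape would keep checkpoints only).
    qs = [q0(p)]
    for _ in range(n_iter):
        qs.append(phi(qs[-1], p))
    # Seed the adjoint with the loss gradients at the final state.
    lam, g = jax.grad(ell, argnums=(0, 1))(qs[-1], p)
    # Walk the iterations backwards, accumulating parameter sensitivities.
    for q_k in reversed(qs[:-1]):
        _, pullback = jax.vjp(phi, q_k, p)
        lam, g_p = pullback(lam)
        g = g + g_p
    # Account for the dependence of the initial guess on p.
    _, pullback0 = jax.vjp(q0, p)
    return g + pullback0(lam)[0]


def loss_unrolled(p, n_iter=40):
    q = q0(p)
    for _ in range(n_iter):
        q = phi(q, p)
    return ell(q, p)


p = jnp.array([0.3, -0.2])
g_replay = grad_by_reverse_replay(p)
g_reference = jax.grad(loss_unrolled)(p)
```

One backward sweep yields the gradient with respect to all parameters at once, which is why the VJP path is attractive for a single scalar objective.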

The FixedBoundaryExactOptimizer in vmec_jax/optimization.py orchestrates everything:

  1. Call build_residual_checkpoint_tape() with the tight forward-solve settings.

  2. Propagate boundary tangents via checkpoint_tape_state_jvp_columns().

  3. Multiply by the residuals Jacobian to form \(J = \partial r / \partial p\).

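  4. Solve the Gauss-Newton least-squares problem \(\min_{\Delta p} \lVert J\,\Delta p + r \rVert_2\) (equivalent to the normal equations \(J^T J\, \Delta p = -J^T r\)) via LAPACK dgelsd.

  5. Run an Armijo backtracking line search, using a relaxed forward solve at trial points.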

  6. Cache-hit detection: if the next call to residual_fun is at the same \(p\) as the last tape build, reuse the tape.
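Steps 4 and 5 can be sketched on a small curve-fitting problem. This is a generic Gauss-Newton loop, not the FixedBoundaryExactOptimizer code: the residual and jacobian functions are analytic stand-ins for the tape-based quantities, the Armijo constant is an assumed value, and np.linalg.lstsq is used as the least-squares solver (NumPy documents it as backed by a LAPACK gelsd-type driver).

```python
import numpy as np

# Hypothetical zero-residual fitting problem: y = p[1] * exp(p[0] * x).
x = np.array([0.0, 0.5, 1.0])
p_true = np.array([0.5, 2.0])
y = p_true[1] * np.exp(p_true[0] * x)


def residual(p):
    return p[1] * np.exp(p[0] * x) - y


def jacobian(p):
    # Analytic Jacobian; stands in for the exact tape-based Jacobian.
    e = np.exp(p[0] * x)
    return np.stack([p[1] * x * e, e], axis=1)


def gauss_newton(p, max_iter=20, c=1e-4, tol=1e-24):
    costs = []
    for _ in range(max_iter):
        r, J = residual(p), jacobian(p)
        cost = 0.5 * r @ r
        costs.append(cost)
        if cost < tol:
            break
        # Least-squares solve of J dp = -r (avoids forming J^T J).
        dp = np.linalg.lstsq(J, -r, rcond=None)[0]
        slope = (J.T @ r) @ dp  # directional derivative of the cost
        t = 1.0
        while t > 1e-10:
            r_trial = residual(p + t * dp)
            # Armijo sufficient-decrease test.
            if 0.5 * r_trial @ r_trial <= cost + c * t * slope:
                break
            t *= 0.5
        else:
            break  # line search failed; stop iterating
        p = p + t * dp
    return p, costs


p_opt, costs = gauss_newton(np.array([0.4, 1.8]))
```

Solving the least-squares problem directly rather than the normal equations avoids squaring the condition number of \(J\).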

Dynamic replay bucketing

The tape length \(K\) (number of VMEC iterations to convergence) varies slightly from one Gauss-Newton step to the next. A different \(K\) would trigger XLA recompilation of the replay scan.

vmec_jax pads short tapes up to the next multiple of VMEC_JAX_DYNAMIC_REPLAY_BUCKET, so that the same compiled XLA kernel is reused across nearby steps without padding every run to an overly long replay. The default is backend-adaptive: 32 iterations on CPU and 128 on CUDA/ROCm/GPU backends. The larger GPU default reduces replay recompilation for the accepted-point exact-Jacobian path; CPU profiling still favors the smaller bucket.

export VMEC_JAX_DYNAMIC_REPLAY_BUCKET=16    # finer bucketing
export VMEC_JAX_DYNAMIC_REPLAY_BUCKET=128   # coarser bucketing

Large buckets can reduce recompiles for some long trajectories, but they can also make each replay substantially more expensive. Treat this variable as a profiling control rather than a recommended user setting.
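The effect of bucketing on the compiled replay length can be sketched with a few lines of arithmetic. The helper names and the way the environment variable is read are assumptions for illustration; only the variable name and the CPU default of 32 come from the text above.

```python
import os


def replay_bucket(default=32):
    # Assumed lookup: fall back to the CPU default quoted above.
    return int(os.environ.get("VMEC_JAX_DYNAMIC_REPLAY_BUCKET", default))


def padded_length(n_iterations, bucket):
    # Round the tape length up to the next multiple of the bucket size.
    return -(-n_iterations // bucket) * bucket


# Nearby tape lengths share one padded length, hence one compiled kernel.
lengths_cpu = [padded_length(k, 32) for k in (4961, 4975, 4990)]
# A coarser bucket pads further but merges a wider range of lengths.
lengths_gpu = [padded_length(k, 128) for k in (4961, 4975, 5000)]
```

With a bucket of 32, tapes of 4961, 4975, and 4990 iterations all replay as 4992 steps; with a bucket of 128, lengths up to 4992 share that value, while a 5000-iteration tape pads to 5120.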

Comparison with other approaches

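==============================  ==============================  ==================================================  =========================
Property                        vmec_jax (discrete-adjoint)     SIMSOPT + VMEC2000 (FD)                             Continuous adjoint (DESC)
==============================  ==============================  ==================================================  =========================
Jacobian cost                   ≈ 1–2 × forward solve           m × forward solve                                   1 × backward solve
Accuracy                        Machine precision               \(O(\sqrt{\varepsilon_\text{machine}})\) FD error   Machine precision
Memory                          O(checkpoint_interval × state)  O(1)                                                O(state)
Subprocess dependency           None (pure Python/JAX)          Fortran binary required                             None (Python/JAX)
Differentiable through solver?  Yes (JAX autodiff)              No                                                  Yes
GPU support                     Yes                             No                                                  Yes
==============================  ==============================  ==================================================  =========================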

Continuous adjoint (DESC): DESC [Dudt et al., 2023] builds a continuous PDE adjoint for the MHD equilibrium equations, solving the adjoint problem exactly once using a Newton-Krylov solver. The cost is one backward solve (same order as one forward solve). vmec_jax’s discrete-adjoint approach replays the iteration tape instead of solving a continuous adjoint equation, and is therefore directly applicable to VMEC’s fixed-point iteration without reformulating the equations.

Implicit differentiation (IFT): an alternative is to differentiate the fixed-point equation implicitly via the implicit function theorem (IFT). vmec_jax provides solve_fixed_boundary_state_implicit() for this path. It requires solving a linear system \((I - \partial\Phi/\partial q)\,v = b\), whose solution is approximated with conjugate-gradient (CG) iterations driven by JVPs. The discrete-adjoint tape replay avoids this linear solve entirely and is the default in FixedBoundaryExactOptimizer.
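A matrix-free version of this linear solve can be sketched on a toy map. It does not reproduce solve_fixed_boundary_state_implicit(); all names are hypothetical. Differentiating \(q^* = \Phi(q^*, p)\) gives \((I - \partial\Phi/\partial q)\,\partial q^*/\partial p_i = \partial\Phi/\partial p_i\), so the right-hand side used here is one column of \(\partial\Phi/\partial p\). CG is valid in this sketch only because the toy Jacobian is symmetric, which makes \(I - \partial\Phi/\partial q\) symmetric positive definite; that property is an assumption of the example, not a statement about VMEC.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)  # use double precision

B = jnp.array([[1.0, 0.0], [0.5, -0.3], [0.2, 0.7]])


def phi(q, p):
    # Hypothetical contractive map; dPhi/dq is diagonal, hence symmetric.
    return 0.5 * jnp.tanh(q) + B @ p


p = jnp.array([0.3, -0.2])
q_star = B @ p
for _ in range(80):  # converge the toy fixed point
    q_star = phi(q_star, p)


def matvec(v):
    # Action of (I - dPhi/dq) on v, evaluated with a single JVP.
    return v - jax.jvp(lambda q: phi(q, p), (q_star,), (v,))[1]


# Right-hand side: dPhi/dp applied to the first unit parameter direction.
e0 = jnp.array([1.0, 0.0])
b = jax.jvp(lambda pp: phi(q_star, pp), (p,), (e0,))[1]

v, _ = cg(matvec, b, tol=1e-12)

# Dense reference for the same column of dq*/dp.
Jq = jax.jacobian(phi, argnums=0)(q_star, p)
Jp = jax.jacobian(phi, argnums=1)(q_star, p)
v_reference = jnp.linalg.solve(jnp.eye(3) - Jq, Jp[:, 0])
```

For a non-symmetric iteration Jacobian, a general Krylov method such as GMRES would be the usual substitute for CG.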

See also

  • Optimisation with vmec_jax — practical guide to running vmec_jax optimizations

  • Comparison with SIMSOPT — detailed runtime and accuracy comparison with SIMSOPT

  • build_residual_checkpoint_tape()

  • checkpoint_tape_state_jvp_columns()

  • FixedBoundaryExactOptimizer