Discrete-adjoint differentiation
Overview
vmec_jax computes exact Jacobians of any differentiable scalar functional (quasisymmetry residuals, aspect ratio, etc.) with respect to boundary shape parameters using a discrete-adjoint technique.
In contrast to the classical finite-difference approach used by SIMSOPT + VMEC2000, discrete-adjoint differentiation:
is exact — derivatives are computed to floating-point precision, not truncated by step-size selection;
requires only one checkpoint replay rather than one full forward solve per parameter;
scales with the number of output quantities, not the number of input parameters, which is ideal when the parameter space is large.
Background: what are we differentiating?
The VMEC iteration produces a converged equilibrium state \(\boldsymbol{q}^* \in \mathbb{R}^N\) (packed array of R, Z, λ Fourier coefficients on the VMEC half/full mesh) by applying the iterative map
\[
\boldsymbol{q}_{k+1} = \Phi(\boldsymbol{q}_k)
\]
from an initial guess \(\boldsymbol{q}_0(p)\) that depends on the boundary parameters \(p \in \mathbb{R}^m\). At convergence, \(\boldsymbol{q}^*(p)\) satisfies the fixed-point equation
\[
\boldsymbol{q}^* = \Phi(\boldsymbol{q}^*)
\]
A scalar objective \(f = \ell(\boldsymbol{q}^*(p), p)\) is then evaluated from the converged state. We want \(\partial f / \partial p_i\).
The challenge: the solver applies \(\Phi\) in a Fortran-style scan loop of roughly 5 000 composite iterations. Naive automatic differentiation through that many unrolled iterations would require storing the full trajectory in memory and would produce an enormous computation graph.
The discrete-adjoint approach
vmec_jax uses a two-pass strategy, analogous to the adjoint method in optimal control:
Forward pass (checkpoint tape)
The forward solve runs normally, but at every N_checkpoint-th iteration
a snapshot of the VMEC state is stored to a compact in-memory checkpoint
tape. Between checkpoints, the state is not stored — it is recomputed on
demand during the backward pass. This trades memory for recomputation.
Forward pass
────────────────────────────────────────────
q₀(p) ──→ q₁ ──→ … ──→ q_{c₁}  ┐ checkpoint
                               └──→ … ──→ q_{c₂}  ┐ checkpoint
                                                  └──→ … ──→ q*(p)
The checkpoints {q_{c₁}, q_{c₂}, ...} are stored.
All other iterates are discarded.
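The checkpointing idea can be sketched in a few lines of JAX. This is a minimal illustration under stated assumptions, not the vmec_jax implementation: `step`, `run_segment`, and `forward_with_tape` are hypothetical names, and the toy contraction stands in for one solver iteration.

```python
import jax
import jax.numpy as jnp

def step(q):
    # Toy contraction standing in for one solver iteration (hypothetical).
    return 0.5 * jnp.cos(q) + 0.1

def run_segment(q, n_inner):
    # Advance n_inner iterations without storing the intermediate states.
    return jax.lax.fori_loop(0, n_inner, lambda _, s: step(s), q)

def forward_with_tape(q0, n_segments, n_inner):
    def outer(q, _):
        # Emit the state at the start of each segment as a checkpoint.
        return run_segment(q, n_inner), q
    q_final, tape = jax.lax.scan(outer, q0, None, length=n_segments)
    return q_final, tape

q0 = jnp.zeros(4)
q_final, tape = forward_with_tape(q0, n_segments=5, n_inner=8)

# Recompute the discarded iterates of segment 2 on demand from its checkpoint.
replayed_end = run_segment(tape[2], 8)
```

Only `n_segments` states are kept; replaying a segment from its checkpoint reproduces the state at the start of the next segment, which is the memory-for-recomputation trade described above.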
Tangent propagation (JVP replay)
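For each boundary parameter \(p_i\), the tangent vector \(\dot{\boldsymbol{q}}_0^{(i)} = \partial \boldsymbol{q}_0 / \partial p_i\) is propagated forward through the tape using Jacobian-vector products (JVPs):
\[
\dot{\boldsymbol{q}}_{k+1}^{(i)} = \left.\frac{\partial \Phi}{\partial \boldsymbol{q}}\right|_{\boldsymbol{q}_k} \dot{\boldsymbol{q}}_k^{(i)}
\]
Because JAX traces the iterative map \(\Phi\) as a JAX program, JVPs are
available via jax.jvp with no extra code. All \(m\) tangents are
propagated simultaneously by mapping the per-tangent jax.jvp of \(\Phi\) over
the tangent batch with jax.vmap, giving a single batched JVP that visits each
checkpoint interval exactly once.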
This gives the full Jacobian column batch \(\partial \boldsymbol{q}^* / \partial p_i\) for all \(i\) in \(O(m)\) JVPs, which is roughly equivalent to 1–2 forward solves regardless of \(m\).
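A small JAX sketch of batched tangent propagation follows. It is an illustration under assumptions rather than the library's code: the toy `step` keeps a boundary block of the state fixed and relaxes an interior block, and all names and sizes (`M`, `N_INT`, `W`, `solve_with_tangents`) are hypothetical.

```python
import jax
import jax.numpy as jnp

M, N_INT = 3, 4  # hypothetical sizes: boundary DOFs and interior unknowns
W = 0.1 * jnp.arange(N_INT * M, dtype=jnp.float32).reshape(N_INT, M)

def step(q):
    # Toy stand-in for one application of Phi; the boundary block is held fixed.
    b, x = q[:M], q[M:]
    return jnp.concatenate([b, 0.5 * jnp.tanh(x) + W @ b])

def initial_state(p):
    return jnp.concatenate([p, jnp.zeros(N_INT)])

def solve_with_tangents(p, n_steps=60):
    q0 = initial_state(p)
    # One tangent per boundary parameter: rows are d q0 / d p_i.
    t0 = jax.jacfwd(initial_state)(p).T

    def body(carry, _):
        q, tangents = carry
        push = lambda t: jax.jvp(step, (q,), (t,))[1]
        # vmap pushes every tangent through the same linearization of step.
        return (step(q), jax.vmap(push)(tangents)), None

    (q_final, tangents), _ = jax.lax.scan(body, (q0, t0), None, length=n_steps)
    return q_final, tangents.T  # columns are d q* / d p_i

p = jnp.array([0.3, -0.2, 0.1])
q_star, dq_dp = solve_with_tangents(p)
```

The state and all tangents advance together in one scan, so the trajectory is visited once regardless of how many tangent rows are carried.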
Objective linearization
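Finally, the Jacobian of the objective with respect to the final state is applied:
\[
\frac{\partial f}{\partial p_i} = \left.\frac{\partial \ell}{\partial \boldsymbol{q}}\right|_{\boldsymbol{q}^*} \frac{\partial \boldsymbol{q}^*}{\partial p_i} + \frac{\partial \ell}{\partial p_i}
\]
using one more jax.jvp call on the residuals function.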
The result is the exact (machine-precision) dense Jacobian matrix \(J \in \mathbb{R}^{n_r \times m}\) where \(n_r\) is the number of residuals and \(m\) is the number of boundary DOFs.
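The linearization step can be illustrated as follows. This is a sketch with made-up data: `residuals`, `q_star`, and `dq_dp` are hypothetical placeholders for the residual function, the converged state, and the propagated tangent columns.

```python
import jax
import jax.numpy as jnp

def residuals(q):
    # Hypothetical residual vector with n_r = 2 entries.
    return jnp.array([jnp.sum(q ** 2) - 1.0, q[0] * q[1]])

q_star = jnp.array([0.6, 0.8, 0.0])   # placeholder converged state, N = 3
dq_dp = jnp.array([[1.0, 0.0],        # placeholder tangent columns, shape (N, m)
                   [0.0, 1.0],
                   [0.5, 0.5]])

# Push each column of dq_dp through the residual function with a JVP.
push = lambda t: jax.jvp(residuals, (q_star,), (t,))[1]
J = jax.vmap(push, in_axes=1, out_axes=1)(dq_dp)   # shape (n_r, m)
```

Each column of `J` is the residual Jacobian applied to one state tangent, so the dense residual Jacobian with respect to the state never has to be formed explicitly.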
Implementation in vmec_jax
The key functions live in vmec_jax/discrete_adjoint.py:
| Function | Role |
|---|---|
| build_residual_checkpoint_tape() | Run forward solve; store checkpoints. Returns a … |
| checkpoint_tape_state_jvp_columns() | Given the tape and a batch of parameter tangents, propagate all tangents through the tape. Returns \(\partial q^* / \partial p\) (columns stacked). |
| | Single-parameter JVP (used internally). |
| | Reverse-mode (VJP) for scalar loss functions; cheaper than the forward-mode columns when \(n_r = 1\) (e.g., single scalar objective). |
The FixedBoundaryExactOptimizer in vmec_jax/optimization.py orchestrates everything:
1. Call build_residual_checkpoint_tape() with the tight forward-solve settings.
2. Propagate boundary tangents via checkpoint_tape_state_jvp_columns().
3. Multiply by the residuals Jacobian to form \(J = \partial r / \partial p\).
4. Solve the Gauss-Newton normal equations \(J^T J\, \Delta p = -J^T r\) via LAPACK dgelsd.
5. Armijo backtracking line search (relaxed forward solve at trial points).
6. Cache-hit detection: if the next call to residual_fun is at the same \(p\) as the last tape build, reuse the tape.
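The Gauss-Newton step and line search in the list above can be sketched on a toy problem. This is an assumption-laden illustration, not the optimizer's code: `residuals` is a hypothetical stand-in for the physics residuals, `jax.jacfwd` replaces the tape-based Jacobian, and `jnp.linalg.lstsq` stands in for the LAPACK least-squares driver.

```python
import jax
import jax.numpy as jnp

def residuals(p):
    # Hypothetical toy residual vector r(p).
    return jnp.array([p[0] ** 2 + p[1] - 1.0,
                      p[0] - p[1] ** 2,
                      0.1 * (p[0] - p[1])])

def gauss_newton_step(p, c1=1e-4, max_backtracks=20):
    r = residuals(p)
    J = jax.jacfwd(residuals)(p)
    # Least-squares solve of J dp = -r; same minimizer as the normal equations.
    dp = jnp.linalg.lstsq(J, -r)[0]
    f0 = 0.5 * jnp.dot(r, r)
    slope = jnp.dot(J.T @ r, dp)  # directional derivative of 0.5*|r|^2 along dp
    alpha = 1.0
    for _ in range(max_backtracks):
        r_trial = residuals(p + alpha * dp)
        # Armijo sufficient-decrease test; halve the step until it passes.
        if 0.5 * jnp.dot(r_trial, r_trial) <= f0 + c1 * alpha * slope:
            break
        alpha *= 0.5
    return p + alpha * dp

p = jnp.array([1.0, 1.0])
for _ in range(10):
    p = gauss_newton_step(p)
```

Solving the least-squares problem directly avoids forming \(J^T J\), which squares the condition number; the step it returns satisfies the same normal equations.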
Dynamic replay bucketing
The tape length \(K\) (number of VMEC iterations to convergence) varies slightly from one Gauss-Newton step to the next. A different \(K\) would trigger XLA recompilation of the replay scan.
vmec_jax pads each tape up to the next multiple of
VMEC_JAX_DYNAMIC_REPLAY_BUCKET so that the same compiled XLA kernel is
reused across nearby steps without padding every run to an overly long replay.
The default is backend-adaptive: 32 iterations on CPU and 128 on
CUDA/ROCm/GPU backends. The larger GPU default reduces replay recompilation
for the accepted-point exact-Jacobian path; CPU profiling still favors the
smaller bucket.
export VMEC_JAX_DYNAMIC_REPLAY_BUCKET=16 # finer bucketing
export VMEC_JAX_DYNAMIC_REPLAY_BUCKET=128 # coarser bucketing
Large buckets can reduce recompiles for some long trajectories, but they can also make each replay substantially more expensive. Treat this variable as a profiling control rather than a recommended user setting.
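The bucketing idea can be sketched as a masked scan whose length is rounded up to a multiple of the bucket size. This is a minimal sketch under assumptions; `bucketed_length`, `replay`, and `padded_replay` are hypothetical names and the masking scheme is one possible way to make padded steps no-ops.

```python
import jax
import jax.numpy as jnp

def bucketed_length(k, bucket):
    # Round an iteration count up to the next multiple of the bucket size.
    return ((k + bucket - 1) // bucket) * bucket

def step(q):
    return 0.5 * jnp.cos(q) + 0.1  # hypothetical toy iteration

@jax.jit
def replay(q0, active):
    # `active` has the padded length; padded entries are False and leave q unchanged.
    def body(q, is_active):
        return jnp.where(is_active, step(q), q), None
    q_final, _ = jax.lax.scan(body, q0, active)
    return q_final

def padded_replay(q0, k, bucket=32):
    n = bucketed_length(k, bucket)
    return replay(q0, jnp.arange(n) < k)

q0 = jnp.zeros(3)
a = padded_replay(q0, 37)  # scan length padded to 64
b = padded_replay(q0, 50)  # also 64, so the input shape matches the previous call
```

Because `jax.jit` keys its cache on input shapes, two tape lengths that fall in the same bucket share one compiled scan, while the masked padding steps still cost one evaluation of `step` each.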
Comparison with other approaches
| Property | vmec_jax (discrete-adjoint) | SIMSOPT + VMEC2000 (FD) | Continuous adjoint (DESC) |
|---|---|---|---|
| Jacobian cost | ≈ 1–2 × forward solve | m × forward solve | 1 × backward solve |
| Accuracy | Machine precision | \(O(\sqrt{\varepsilon_\text{machine}})\) FD error | Machine precision |
| Memory | O(checkpoint_interval × state) | O(1) | O(state) |
| Subprocess dependency | None (pure Python/JAX) | Fortran binary required | None (Python/JAX) |
| Differentiable through solver? | Yes (JAX autodiff) | No | Yes |
| GPU support | Yes | No | Yes |
Continuous adjoint (DESC): DESC [Dudt et al., 2023] builds a continuous PDE adjoint for the MHD equilibrium equations and solves the adjoint problem once with a Newton-Krylov solver, so the cost is one backward solve (the same order as one forward solve). The vmec_jax discrete-adjoint approach replays the iteration tape instead of solving a continuous adjoint equation, and is therefore directly applicable to VMEC's fixed-point iteration without reformulating the equations.
Implicit differentiation (IFT): an alternative is to differentiate the
fixed-point equation implicitly via the implicit function theorem (IFT).
vmec_jax provides solve_fixed_boundary_state_implicit() for
this path. It requires solving a linear system \((I - \partial\Phi/\partial q)\,v = b\)
which is approximated via CG + JVP. The discrete-adjoint tape replay avoids
this linear solve entirely and is the default in
FixedBoundaryExactOptimizer.
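For comparison, the implicit-differentiation route can be sketched on a toy fixed-point map. This is not the code behind solve_fixed_boundary_state_implicit(): the map `phi` is hypothetical, takes the parameters explicitly, and is chosen as a gradient-descent step on a convex energy so that the linear operator is symmetric positive definite, which CG requires.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

A = jnp.array([[2.0, 0.5], [0.5, 1.0]])

def phi(x, p):
    # Hypothetical toy map: one gradient-descent step on a convex energy,
    # so d(phi)/dx is symmetric and I - d(phi)/dx is positive definite.
    return x - 0.2 * (A @ x + x ** 3 - p)

def solve(p, n_steps=200):
    return jax.lax.fori_loop(0, n_steps, lambda _, x: phi(x, p), jnp.zeros(2))

def implicit_tangent(p, p_dot):
    x_star = solve(p)
    # Right-hand side b = (d phi / d p) p_dot, evaluated at the fixed point.
    b = jax.jvp(lambda pp: phi(x_star, pp), (p,), (p_dot,))[1]
    # Matrix-free operator v -> (I - d phi / d x) v, built from a JVP.
    matvec = lambda v: v - jax.jvp(lambda x: phi(x, p), (x_star,), (v,))[1]
    v, _ = cg(matvec, b)
    return v

p = jnp.array([0.3, -0.1])
p_dot = jnp.array([1.0, 0.0])
dx_implicit = implicit_tangent(p, p_dot)
dx_unrolled = jax.jvp(solve, (p,), (p_dot,))[1]  # differentiate through the loop
```

On this toy problem the two tangents agree to solver tolerance. The implicit route needs only the converged state, at the price of a linear solve whose conditioning depends on how close the spectrum of \(\partial\Phi/\partial q\) comes to 1.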
See also
Optimisation with vmec_jax — practical guide to running vmec_jax optimizations
Comparison with SIMSOPT — detailed runtime and accuracy comparison with SIMSOPT
build_residual_checkpoint_tape()
checkpoint_tape_state_jvp_columns()
FixedBoundaryExactOptimizer