Comparison with SIMSOPT¶
Overview¶
SIMSOPT is the de facto standard Python toolkit for stellarator shape optimization. Its canonical workflow for fixed-boundary optimization calls VMEC2000 as a Fortran subprocess and builds the Jacobian column by column using finite differences.
vmec_jax implements the same physics but replaces both the VMEC2000 subprocess and the finite-difference Jacobian with a single end-to-end JAX program that provides an exact discrete-adjoint Jacobian.
This page provides a detailed, quantitative comparison.
Objective function¶
Both frameworks use the same objective: minimise the quasisymmetry-ratio residuals of Helander and Simakov [Helander2008].
For quasi-helical symmetry (QH) with helicity \((m, n)\):
where \(B_{m'n'}(s)\) are the non-helical Fourier amplitudes of \(|B|\) at flux surface \(s\).
In code:
# vmec_jax (helicity_n is in field-period units: -1 → QH with nfp=4, nn=-4 internally)
residuals_fn = vj.make_qh_residuals_fn(
static, indata, helicity_m=1, helicity_n=-1,
target_aspect=7.0, surfaces=np.arange(0, 1.01, 0.1),
)
# SIMSOPT (helicity_n is in full-torus units)
qs = QuasisymmetryRatioResidual(
vmec, np.arange(0, 1.01, 0.1), helicity_m=1, helicity_n=4
)
Note
vmec_jax’s helicity_n is given in field-period units: nn = helicity_n * nfp
is used internally. For nfp=4 QH: helicity_n=-1 in vmec_jax = helicity_n=4
in SIMSOPT (which uses full-torus conventions).
Both use the same 11 flux-surface locations and aspect-ratio target.
With consistent VMEC resolution mpol = ntor = 5 (set automatically by
extend_boundary_for_max_mode), the initial QS value on the
nfp4_QH_warm_start input is 0.303 (the QS initial column of the runtime comparison).
Jacobian computation¶
This is the key algorithmic difference:
| Property | vmec_jax (discrete-adjoint) | SIMSOPT + VMEC2000 (finite differences) |
|---|---|---|
| Method | Checkpoint-tape JVP replay | Columnar finite differences via subprocess |
| Cost per Jacobian | ≈ 1–2 × forward solve | m × forward solve (m = number of DOFs) |
| Accuracy | Machine precision (\(\varepsilon_\text{machine}\)) | \(O(\sqrt{\varepsilon_\text{machine}}) \approx 10^{-8}\) FD error |
| Subprocess required | No | Yes (Fortran VMEC2000 binary) |
| GPU support | Yes (JAX device, no code changes) | No |
| Differentiable through solver | Yes (full JAX graph) | No |
The discrete-adjoint cost advantage is decisive for moderate and large DOF counts. For \(m = 14\) DOFs, SIMSOPT must run 14 extra VMEC2000 solves per Jacobian; vmec_jax runs the equivalent of ≈ 1.5 forward solves.
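The cost and accuracy rows of the table can be illustrated on a toy problem. The sketch below is not vmec_jax or SIMSOPT code: residuals is a made-up two-parameter function standing in for a solver-backed objective, and exact_jacobian is its hand-derived Jacobian standing in for an adjoint result. It counts how many extra function evaluations a forward-difference Jacobian needs (one per DOF) and measures the resulting error against the exact Jacobian.

```python
import numpy as np

def residuals(x):
    """Toy residual vector (hypothetical stand-in for a solver-backed objective)."""
    return np.array([np.sin(x[0]) * x[1], np.exp(x[1]) - x[0] ** 2, x[0] * x[1]])

def exact_jacobian(x):
    """Hand-derived Jacobian of `residuals` (stand-in for an exact adjoint result)."""
    return np.array([
        [np.cos(x[0]) * x[1], np.sin(x[0])],
        [-2.0 * x[0], np.exp(x[1])],
        [x[1], x[0]],
    ])

def fd_jacobian(fun, x, counter):
    """Forward-difference Jacobian: one extra evaluation per column (per DOF)."""
    h = np.sqrt(np.finfo(float).eps)  # step size of order sqrt(machine epsilon)
    f0 = fun(x)
    counter["calls"] += 1  # baseline evaluation
    jac = np.empty((f0.size, x.size))
    for j in range(x.size):
        step = h * max(1.0, abs(x[j]))
        x_pert = x.copy()
        x_pert[j] += step
        jac[:, j] = (fun(x_pert) - f0) / step
        counter["calls"] += 1  # one perturbed evaluation for column j
    return jac

x0 = np.array([0.7, 0.3])
counter = {"calls": 0}
jac_fd = fd_jacobian(residuals, x0, counter)
extra_solves = counter["calls"] - 1
max_err = float(np.max(np.abs(jac_fd - exact_jacobian(x0))))
print(f"extra evaluations: {extra_solves}, max FD error: {max_err:.2e}")
```

In a real comparison each call to residuals would be a full equilibrium solve, which is why the number of extra evaluations grows linearly with the DOF count.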
Runtime comparison (nfp4_QH_warm_start)¶
All runs use max_nfev = 15 and the same input file (input.nfp4_QH_warm_start),
VMEC resolution mpol = ntor = 5.
Hardware: Apple M-series CPU (single process, no MPI).
| max_mode | DOFs | QS initial | vmec_jax QS final | vmec_jax reduction | vmec_jax time |
|---|---|---|---|---|---|
| 1 | 8 | 0.303 | 0.213 | 30 % | ~124 s |
| 2 | 24 | 0.303 | 0.008 | 97 % | ~323 s |
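The reduction column follows directly from the initial and final QS values. As a quick arithmetic check, using only the numbers from the table (the helper name reduction_percent is introduced here for illustration):

```python
def reduction_percent(initial, final):
    """Relative decrease of the objective, in percent."""
    return 100.0 * (initial - final) / initial

# (QS initial, QS final) per max_mode, copied from the table above.
rows = {1: (0.303, 0.213), 2: (0.303, 0.008)}
reductions = {mode: reduction_percent(*values) for mode, values in rows.items()}

for mode, pct in reductions.items():
    print(f"max_mode={mode}: {pct:.1f} % reduction")
```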
Note
vmec_jax achieves much lower final QS for max_mode=2 because exact Jacobians provide far more descent information per Gauss-Newton step than finite differences. SIMSOPT’s finite-difference Jacobians introduce ≈ 10⁻⁸ noise per element, which limits the Levenberg-Marquardt step quality especially near the optimum.
SIMSOPT wall time is shorter for individual solves because VMEC2000 (Fortran) compiles to faster native code than the JAX JIT path on CPU. GPU results are case- and path-dependent in the current profiles; use the performance guide and generated profile reports rather than assuming a universal scan-loop speedup.
DOF count (vmec_jax vs SIMSOPT): vmec_jax’s boundary_param_specs
enumerates modes with \(\max(|m|, |n|) \le \text{max\_mode}\) and
extend_boundary_for_max_mode sets mpol = ntor = max(5, max_mode + 2);
SIMSOPT’s fixed_range covers the full rectangle
\(0 \le m \le M\), \(-N \le n \le N\).
For max_mode=2 both frameworks use 24 DOFs when mpol=ntor=5.
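The DOF counts of 8 and 24 can be reproduced by enumerating modes by hand. The sketch below is an independent count, not a call into either library. It assumes a stellarator-symmetric boundary with rc and zs coefficients, the usual VMEC convention that m = 0 modes carry only n ≥ 0, that zs(0,0) is identically zero, and that rc(0,0) is held fixed.

```python
def count_boundary_dofs(max_mode, fix_rc00=True):
    """Count free rc/zs boundary coefficients with max(|m|, |n|) <= max_mode."""
    n_rc = 0
    n_zs = 0
    for m in range(0, max_mode + 1):
        # Assumed convention: for m = 0 only n >= 0 is independent.
        n_min = 0 if m == 0 else -max_mode
        for n in range(n_min, max_mode + 1):
            n_rc += 1
            if not (m == 0 and n == 0):
                n_zs += 1  # zs(0,0) vanishes identically, so it is not a DOF
    if fix_rc00:
        n_rc -= 1  # rc(0,0) sets the major radius and is held fixed
    return n_rc + n_zs

for max_mode in (1, 2):
    print(f"max_mode={max_mode}: {count_boundary_dofs(max_mode)} DOFs")
```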
Memory usage¶
| Component | vmec_jax | SIMSOPT + VMEC2000 |
|---|---|---|
| Per-iteration state (in-memory) | Yes (packed state arrays in JAX device memory) | No (VMEC2000 writes/reads Fortran arrays) |
| Checkpoint tape | Yes (O(K × state_size), where K = checkpoint interval) | No |
| Jacobian storage | Dense matrix in host memory | Dense matrix in host memory |
| Subprocess overhead | None | File I/O per VMEC run (wout files) |
| Typical peak RSS (max_mode=2) | ≈ 600–900 MB (XLA compiled graph + state) | ≈ 200 MB (pure host-side) |
The larger memory footprint of vmec_jax is primarily due to XLA kernel compilation and JAX device buffers. On GPU, the bulk of state storage moves to device memory (typically 1–4 GB for large problems).
Algorithm comparison¶
| Aspect | vmec_jax | SIMSOPT |
|---|---|---|
| Optimizer | Custom Gauss-Newton with Armijo line search | SciPy least_squares |
| Jacobian build | Discrete-adjoint replay (1 checkpoint-tape call) | Finite differences (m VMEC runs per Jacobian) |
| Line search | Armijo backtracking using relaxed forward solve | SciPy internal (Levenberg-Marquardt damping or trust radius), with a VMEC run at each trial point |
| Convergence | Relative cost + gradient + step tolerance | Same (SciPy defaults) |
| Reproducibility | Deterministic (JAX seed fixed) | Deterministic (Fortran VMEC) |
The key design choice in vmec_jax’s custom Gauss-Newton is that each Jacobian is comparatively expensive (≈ 1.5× a forward solve) but highly informative, so the line search uses a relaxed forward solve (fewer iterations, looser ftol) rather than spending a fully converged evaluation on every trial point. SciPy’s L-M path instead relies on finite-difference Jacobians: each VMEC2000 solve is fast, but a Jacobian needs m of them and carries finite-difference noise.
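The general pattern of a Gauss-Newton step followed by Armijo backtracking can be sketched on a toy least-squares problem. This is a generic textbook illustration, not the vmec_jax optimizer: the residuals and jacobian functions are a made-up Rosenbrock-style example, the constants c1 and shrink are conventional choices, and the trial points are evaluated exactly rather than with a relaxed solve.

```python
import numpy as np

def residuals(x):
    """Toy residuals with a unique zero at (1, 1)."""
    return np.array([10.0 * (x[1] - x[0] ** 2), 1.0 - x[0]])

def jacobian(x):
    """Exact Jacobian of the toy residuals."""
    return np.array([[-20.0 * x[0], 10.0], [-1.0, 0.0]])

def gauss_newton_armijo(x, max_iter=200, c1=1e-4, shrink=0.5, gtol=1e-10):
    for _ in range(max_iter):
        r = residuals(x)
        jac = jacobian(x)
        grad = jac.T @ r  # gradient of the cost 0.5 * ||r||^2
        if np.linalg.norm(grad) < gtol:
            break
        # Gauss-Newton direction: least-squares solution of jac @ step = -r.
        step, *_ = np.linalg.lstsq(jac, -r, rcond=None)
        cost = 0.5 * (r @ r)
        slope = grad @ step  # directional derivative, negative for a descent step
        alpha = 1.0
        for _ in range(30):
            r_trial = residuals(x + alpha * step)
            # Armijo sufficient-decrease condition.
            if 0.5 * (r_trial @ r_trial) <= cost + c1 * alpha * slope:
                break
            alpha *= shrink  # backtrack
        x = x + alpha * step
    return x

x_opt = gauss_newton_armijo(np.array([-1.2, 1.0]))
print("solution:", x_opt)
```

Each backtracking trial costs one residual evaluation and no Jacobian, which is the part of the iteration where a cheaper, relaxed solve can be substituted.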
Exponential spectral scaling (ESS)¶
vmec_jax provides create_x_scale() for per-DOF scaling that
de-emphasises high-mode harmonics. Its output is passed as x_scale to run().
It is analogous to SIMSOPT’s diff_step, but it acts on the Gauss-Newton
step rather than on the FD step size. SIMSOPT does not have a built-in
equivalent; one would need to scale the DOF vector manually before passing
it to SciPy.
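The exact formula used by create_x_scale() is not reproduced on this page. As an illustration of the idea only, the sketch below assumes a scale factor of the form exp(-alpha * max(|m|, |n|)) for the DOF attached to mode (m, n). The function name exponential_spectral_scale and the decay rate alpha are hypothetical and are not part of the vmec_jax API.

```python
import math

def exponential_spectral_scale(modes, alpha=0.5):
    """Hypothetical per-DOF scale: exp(-alpha * max(|m|, |n|)) for each (m, n)."""
    return [math.exp(-alpha * max(abs(m), abs(n))) for m, n in modes]

# Example mode list (m, n); higher mode numbers receive smaller scale factors.
modes = [(0, 1), (1, 0), (1, 1), (2, -2), (2, 1)]
scales = exponential_spectral_scale(modes)

for (m, n), scale in zip(modes, scales):
    print(f"mode ({m:d}, {n:+d}): scale = {scale:.3f}")
```

Under this assumed form, all modes in the same max(|m|, |n|) shell share one scale, so the optimizer takes proportionally smaller steps in the higher shells.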
Source code comparison¶
vmec_jax¶
import numpy as np
import vmec_jax as vj
from vmec_jax._compat import enable_x64
# Enable 64-bit floating point in JAX.
enable_x64(True)
# Load the input and build the boundary, extended to cover max_mode=2.
cfg, indata = vj.load_config("input.nfp4_QH_warm_start")
static = vj.build_static(cfg)
boundary = vj.boundary_from_indata(indata, static.modes)
indata, static, boundary = vj.extend_boundary_for_max_mode(
    indata, static, boundary, max_mode=2
)
# Free the rc/zs boundary modes up to max_mode=2, keeping rc00 fixed.
specs = vj.boundary_param_specs(
    boundary, static.modes, max_mode=2, include=("rc", "zs"), fix=("rc00",)
)
params0 = np.zeros(len(specs))
# helicity_n=-1 in field-period units = helicity_n=4 in SIMSOPT full-torus units
residuals_fn = vj.make_qh_residuals_fn(
    static, indata, helicity_m=1, helicity_n=-1,
    target_aspect=7.0, surfaces=np.arange(0, 1.01, 0.1),
)
# Run the exact-Jacobian optimizer, then save the final equilibrium and history.
opt = vj.FixedBoundaryExactOptimizer(static, indata, boundary, specs, residuals_fn)
result = opt.run(params0, max_nfev=15, ftol=1e-3, gtol=1e-3, xtol=1e-3)
opt.save_wout("wout_final.nc", result["x"])
opt.save_history("history.json", result)
SIMSOPT + VMEC2000¶
import numpy as np
from scipy.optimize import least_squares
from simsopt.mhd import Vmec, QuasisymmetryRatioResidual
# Run VMEC2000 once on the warm-start input.
vmec = Vmec("input.nfp4_QH_warm_start", verbose=False)
vmec.run()
# Free boundary modes with 0 <= m <= 2 and -2 <= n <= 2, keeping rc(0,0) fixed.
surf = vmec.boundary
surf.fix_all()
surf.fixed_range(mmin=0, mmax=2, nmin=-2, nmax=2, fixed=False)
surf.fix("rc(0,0)")
qs = QuasisymmetryRatioResidual(vmec, np.arange(0, 1.01, 0.1), helicity_m=1, helicity_n=4)
def residuals(x):
    # Set the free boundary DOFs, then evaluate the QS residuals.
    surf.x = x
    return qs.residuals()
result = least_squares(residuals, surf.x, method="lm", max_nfev=15,
                       ftol=1e-3, gtol=1e-3, xtol=1e-3)
The vmec_jax version is self-contained (no Fortran binary, no subprocess), runs in a single Python process, and produces a more accurate result in fewer effective evaluations.
Optional local SIMSOPT checks¶
SIMSOPT comparisons are optional integration checks, not required PR gates. They are intended for developers who have SIMSOPT installed locally and want to verify shared formulas or reproduce cross-backend optimization diagnostics.
Formula-level checks can be run with:
RUN_SIMSOPT_VALIDATION=1 pytest -q tests/test_simsopt_optional_validation.py
RUN_SIMSOPT_VALIDATION=1 pytest -q tests/test_redl_bootstrap_simsopt_parity.py
RUN_SIMSOPT_VALIDATION=1 pytest -q tests/test_finite_beta_helpers_unit.py::test_redl_bootstrap_formula_matches_simsopt_when_available
The dedicated SIMSOPT validation test is additionally gated by
RUN_SIMSOPT_VALIDATION=1 so that required CI remains independent of a local
SIMSOPT checkout. These tests use pytest.importorskip for SIMSOPT modules,
so they skip when SIMSOPT is not installed. They may also skip if optional
runtime dependencies such as jax or netCDF4 are unavailable.
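The two-level gating described above (an environment flag plus an import check) can be mimicked with the standard library alone. The sketch below is not taken from the test suite: the helper simsopt_checks_enabled is hypothetical, and it uses importlib.util.find_spec as a stand-in for what pytest.importorskip does inside the real tests.

```python
import importlib.util
import os

def simsopt_checks_enabled(environ=None):
    """Return (run, reason) for the optional SIMSOPT checks."""
    environ = os.environ if environ is None else environ
    # First gate: the explicit opt-in flag.
    if environ.get("RUN_SIMSOPT_VALIDATION") != "1":
        return False, "RUN_SIMSOPT_VALIDATION=1 not set"
    # Second gate: SIMSOPT must be importable in this environment.
    if importlib.util.find_spec("simsopt") is None:
        return False, "simsopt is not installed"
    return True, "enabled"

run, reason = simsopt_checks_enabled({})
print(run, reason)
```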
The heavier optimization comparison script is local-only by default:
python examples/optimization/compare_omnigenity_qs_mode1.py
That script writes summaries under its configured output directory and catches SIMSOPT-side failures into a failure JSON so the vmec_jax leg can still be inspected. Do not put this workflow in required CI; if CI coverage is desired, run it from a scheduled/manual job with a pinned SIMSOPT environment and the VMEC2000 executable available through SIMSOPT.
Practical guidance: when to use which¶
Use vmec_jax when:
You need high-quality gradients (exact Jacobians) for sensitive optimization problems — e.g., near the optimum where FD errors matter.
You want GPU acceleration without code changes.
You want end-to-end differentiability through the optimizer (e.g., meta-learning, hyperparameter gradients).
The parameter space has many DOFs (exact Jacobian scales better than FD).
You prefer a self-contained Python install without Fortran dependencies.
Use SIMSOPT when:
You need access to SIMSOPT’s broader ecosystem: free-boundary, coil optimization (simsopt.field), bootstrap current targets, Boozer transforms, etc.
You want the fastest individual VMEC solve on CPU: the VMEC2000 Fortran binary is faster per iteration for small problems.
You need MPI parallelism for large finite-difference Jacobians (SIMSOPT parallelises FD columns across MPI workers; vmec_jax does not require MPI because the Jacobian is cheap).
References¶
Helander, P. and Simakov, A. N. (2008). Intrinsic ambipolarity and rotation in stellarators. Physical Review Letters, 101, 145003. https://doi.org/10.1103/PhysRevLett.101.145003
See also
Discrete-adjoint differentiation — full mathematical description of the adjoint method
Optimisation with vmec_jax — practical API guide and QH/QA examples