Comparison with SIMSOPT

Overview

SIMSOPT is the de facto standard Python toolkit for stellarator shape optimization. Its canonical workflow for fixed-boundary optimization calls VMEC2000 as a Fortran subprocess and builds the Jacobian column by column using finite differences.

vmec_jax implements the same physics but replaces both the VMEC2000 subprocess and the finite-difference Jacobian with a single end-to-end JAX program that provides an exact discrete-adjoint Jacobian.

This page provides a detailed, quantitative comparison.

Objective function

Both frameworks use the same objective: minimise the quasisymmetry-ratio residuals of Helander and Simakov [Helander2008].

For quasi-helical symmetry (QH) with helicity \((m, n)\):

\[f_{\rm QS}(p) = \sum_{s} \sum_{m',n'} \bigl[ B_{m'n'}(s) \bigr]^2\]

where \(B_{m'n'}(s)\) are the non-helical Fourier amplitudes of \(|B|\) at flux surface \(s\).
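As a minimal sketch of this sum, assume the Fourier amplitudes of \(|B|\) are already available as an array. The function name `qs_objective` and the array layout are hypothetical and are not part of vmec_jax or SIMSOPT. A mode \((m', n')\) is treated as symmetric when it is parallel to the helicity \((m, n)\), with both toroidal numbers given in full-torus units:

```python
import numpy as np

def qs_objective(bmn, m_modes, n_modes, helicity_m, helicity_n):
    """Sum of squared symmetry-breaking |B| amplitudes over all surfaces.

    bmn has shape (n_surfaces, n_modes); m_modes and n_modes give the
    poloidal and toroidal mode number of each column (hypothetical layout).
    """
    bmn = np.asarray(bmn, dtype=float)
    m = np.asarray(m_modes)
    n = np.asarray(n_modes)
    # (m', n') is parallel to (M, N) exactly when m' * N - n' * M == 0.
    symmetric = (m * helicity_n - n * helicity_m) == 0
    return float(np.sum(bmn[:, ~symmetric] ** 2))

# Two surfaces, four modes: (0,0), (1,4), (1,0), (2,8).
m_modes = [0, 1, 1, 2]
n_modes = [0, 4, 0, 8]
bmn = [[1.0, 0.10, 0.02, 0.01],
       [1.0, 0.20, 0.03, 0.02]]
f_qs = qs_objective(bmn, m_modes, n_modes, helicity_m=1, helicity_n=4)
print(f_qs)
```

Only the (1, 0) column breaks the symmetry in this toy data, so the result is 0.02² + 0.03² = 0.0013 up to rounding.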

In each framework, the objective is constructed as follows:

# vmec_jax  (helicity_n is in field-period units: -1 → QH with nfp=4, nn=-4 internally)
residuals_fn = vj.make_qh_residuals_fn(
    static, indata, helicity_m=1, helicity_n=-1,
    target_aspect=7.0, surfaces=np.arange(0, 1.01, 0.1),
)

# SIMSOPT  (helicity_n is in full-torus units)
qs = QuasisymmetryRatioResidual(
    vmec, np.arange(0, 1.01, 0.1), helicity_m=1, helicity_n=4
)

Note

vmec_jax’s helicity_n is given in field-period units: nn = helicity_n * nfp is used internally. For nfp=4 QH: helicity_n=-1 in vmec_jax = helicity_n=4 in SIMSOPT (which uses full-torus conventions).

Both use the same 11 flux-surface locations and aspect-ratio target. With consistent VMEC resolution mpol = ntor = 5 (set automatically by extend_boundary_for_max_mode), the initial QS value on the nfp4_QH_warm_start input is:

\[f_{\rm QS,0} \approx 0.303 \quad \text{(vmec\_jax, mpol=ntor=5)}\]

Jacobian computation

This is the key algorithmic difference:

| Property | vmec_jax (discrete-adjoint) | SIMSOPT + VMEC2000 (finite differences) |
| --- | --- | --- |
| Method | Checkpoint-tape JVP replay | Columnar finite differences via subprocess |
| Cost per Jacobian | ≈ 1–2 × forward solve | m × forward solve (m = number of DOFs) |
| Accuracy | Machine precision (\(\varepsilon_\text{machine}\)) | \(O(\sqrt{\varepsilon_\text{machine}}) \approx 10^{-8}\) FD error |
| Subprocess required | No | Yes (Fortran VMEC2000 binary) |
| GPU support | Yes (JAX device, no code changes) | No |
| Differentiable through solver | Yes (full JAX graph) | No |

The discrete-adjoint cost advantage is decisive for moderate and large DOF counts. For \(m = 14\) DOFs, SIMSOPT must run 14 extra VMEC2000 solves per Jacobian; vmec_jax runs the equivalent of ≈ 1.5 forward solves.
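The arithmetic behind this comparison can be written out as a small cost model. This is only a sketch: `jacobian_cost` is a hypothetical helper, the finite-difference count assumes one-sided differences (one extra solve per DOF), and the factor 1.5 is the approximate figure quoted above rather than a measurement.

```python
def jacobian_cost(n_dofs, adjoint_factor=1.5):
    """Return (finite-difference cost, adjoint cost) in forward-solve units.

    One-sided finite differences need one extra solve per DOF; the adjoint
    cost is modelled as a constant multiple of one forward solve.
    """
    return float(n_dofs), float(adjoint_factor)

for n_dofs in (8, 14, 24):
    fd, adj = jacobian_cost(n_dofs)
    print(f"{n_dofs:2d} DOFs: FD = {fd:4.1f}, adjoint = {adj:3.1f}, "
          f"ratio = {fd / adj:4.1f}")
```

Under these assumptions the ratio grows linearly with the number of DOFs, which is why the gap widens from max_mode=1 (8 DOFs) to max_mode=2 (24 DOFs).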

Runtime comparison (nfp4_QH_warm_start)

All runs use max_nfev = 15 and the same input file (input.nfp4_QH_warm_start), with VMEC resolution mpol = ntor = 5. Hardware: Apple M-series CPU (single process, no MPI).

| max_mode | DOFs | QS initial | vmec_jax QS final | vmec_jax reduction | vmec_jax time |
| --- | --- | --- | --- | --- | --- |
| 1 | 8 | 0.303 | 0.213 | 30 % | ~124 s |
| 2 | 24 | 0.303 | 0.008 | 97 % | ~323 s |

Note

The low final QS that vmec_jax reaches for max_mode=2 is attributed to the exact Jacobian: each Gauss-Newton step uses derivative information accurate to machine precision. SIMSOPT’s finite-difference Jacobians carry errors of about 10⁻⁸ per element, which degrade the quality of Levenberg-Marquardt steps, especially near the optimum.
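The size of the finite-difference error can be checked on a function with a known derivative. The sketch below is independent of both frameworks and uses sin(x) purely as a stand-in. It compares a one-sided difference at three step sizes: too large, near \(\sqrt{\varepsilon_\text{machine}}\), and too small.

```python
import math
import sys

def forward_difference(f, x, h):
    """One-sided finite-difference estimate of f'(x)."""
    return (f(x + h) - f(x)) / h

x0 = 1.0
exact = math.cos(x0)  # exact derivative of sin at x0
steps = {
    "too large": 1e-2,
    "sqrt(eps)": math.sqrt(sys.float_info.epsilon),
    "too small": 1e-13,
}
errors = {name: abs(forward_difference(math.sin, x0, h) - exact)
          for name, h in steps.items()}
for name, err in errors.items():
    print(f"h = {steps[name]:.2e} ({name}): error = {err:.1e}")
```

The step near \(\sqrt{\varepsilon_\text{machine}}\) balances truncation error against rounding error and leaves an error of order 10⁻⁸; the other two steps are several orders of magnitude worse.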

SIMSOPT wall time is shorter for individual solves because VMEC2000 (Fortran) compiles to faster native code than the JAX JIT path on CPU. GPU results are case- and path-dependent in the current profiles; use the performance guide and generated profile reports rather than assuming a universal scan-loop speedup.

DOF count (vmec_jax vs SIMSOPT): vmec_jax’s boundary_param_specs enumerates modes with \(\max(|m|, |n|) \le \text{max\_mode}\), and extend_boundary_for_max_mode sets mpol = ntor = max(5, max_mode + 2). SIMSOPT’s fixed_range covers the full rectangle \(0 \le m \le M\), \(-N \le n \le N\). With M = N = max_mode the two enumerations select the same modes, so for max_mode=2 both frameworks use 24 DOFs when mpol=ntor=5.
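The DOF counts in the table can be reproduced by enumerating the modes directly. This sketch assumes a stellarator-symmetric boundary (only rc and zs coefficients) with rc(0,0) held fixed, as in the examples on this page; `boundary_dof_names` is a hypothetical helper, not a function of either library.

```python
def boundary_dof_names(max_mode):
    """Free rc/zs boundary coefficients for a stellarator-symmetric surface."""
    names = []
    for m in range(max_mode + 1):
        for n in range(-max_mode, max_mode + 1):
            if m == 0 and n < 0:
                continue  # for m = 0, negative n duplicates positive n
            names.append(f"rc({m},{n})")
            if (m, n) != (0, 0):
                names.append(f"zs({m},{n})")  # zs(0,0) vanishes identically
    names.remove("rc(0,0)")  # major radius held fixed
    return names

for max_mode in (1, 2):
    print(max_mode, len(boundary_dof_names(max_mode)))
```

This prints 8 for max_mode=1 and 24 for max_mode=2, matching the DOFs column above.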

Memory usage

| Component | vmec_jax | SIMSOPT + VMEC2000 |
| --- | --- | --- |
| Per-iteration state (in-memory) | Yes: packed state arrays in JAX device memory | No: VMEC2000 writes/reads Fortran arrays |
| Checkpoint tape | Yes: O(K × state_size) where K = checkpoint interval | No |
| Jacobian storage | Dense matrix in host memory | Dense matrix in host memory |
| Subprocess overhead | None | File I/O per VMEC run (wout files) |
| Typical peak RSS (max_mode=2) | ≈ 600–900 MB (XLA compiled graph + state) | ≈ 200 MB (pure host-side) |

The larger memory footprint of vmec_jax is primarily due to XLA kernel compilation and JAX device buffers. On GPU, the bulk of state storage moves to device memory (typically 1–4 GB for large problems).
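To illustrate why a tape of solver states costs memory, and what it buys, here is a toy sketch of recording the iterates of a fixed-point solve and replaying the linearised steps to obtain a sensitivity. It is not vmec_jax's implementation: the solver is a scalar Babylonian square-root iteration, every iterate is stored (no checkpoint interval), and the names `forward_solve` and `replay_tangent` are hypothetical.

```python
import math

def forward_solve(p, n_iter=8):
    """Babylonian iteration x <- (x + p / x) / 2, recording every iterate."""
    x = 1.0
    tape = []
    for _ in range(n_iter):
        tape.append(x)          # state needed later to linearise this step
        x = 0.5 * (x + p / x)
    return x, tape

def replay_tangent(p, tape):
    """Propagate dx/dp through the recorded iterates (forward-mode replay)."""
    dx = 0.0                    # the initial guess does not depend on p
    for x in tape:
        # Derivative with respect to p of the update 0.5 * (x + p / x).
        dx = 0.5 * (dx + 1.0 / x - p * dx / (x * x))
    return dx

p = 2.0
root, tape = forward_solve(p)
sensitivity = replay_tangent(p, tape)
print(root, sensitivity, len(tape))
```

The replayed sensitivity agrees with the analytic derivative of \(\sqrt{p}\), namely \(1/(2\sqrt{p})\), and the tape length shows the storage cost: one saved state per recorded iteration.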

Algorithm comparison

| Aspect | vmec_jax | SIMSOPT |
| --- | --- | --- |
| Optimizer | Custom Gauss-Newton with Armijo line search | SciPy least_squares (Levenberg-Marquardt or trust-region reflective) |
| Jacobian build | Discrete-adjoint replay (1 checkpoint-tape call) | Finite differences (m VMEC runs per Jacobian, m = number of DOFs) |
| Line search | Armijo backtracking using relaxed forward solve | SciPy internal (Levenberg-Marquardt damping or trust radius), with a VMEC run at each trial point |
| Convergence | Relative cost + gradient + step tolerance | Same (SciPy defaults) |
| Reproducibility | Deterministic (JAX seed fixed) | Deterministic (Fortran VMEC) |

The Jacobian in vmec_jax’s custom Gauss-Newton costs about 1.5 forward solves but is exact, so each one is worth using well: the line search evaluates trial points with a relaxed forward solve (fewer iterations, looser ftol) rather than spending a fully converged evaluation on a step that may be rejected. SciPy’s L-M, as used here, relies on finite-difference Jacobians: each VMEC2000 solve is fast on CPU, but a full Jacobian needs m of them and carries finite-difference error.
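For readers unfamiliar with the combination, a generic Gauss-Newton iteration with Armijo backtracking can be sketched in a few lines of NumPy. This is a textbook version under simplifying assumptions, not the vmec_jax optimizer: trial points are evaluated with the exact residual function rather than a relaxed solve, and the toy residual arctan(x) is chosen only because full Gauss-Newton steps overshoot on it.

```python
import numpy as np

def gauss_newton_armijo(residuals, jacobian, x0, max_iter=20,
                        c1=1e-4, shrink=0.5, gtol=1e-12):
    """Minimise 0.5 * ||r(x)||^2; return the final x and accepted step lengths."""
    x = np.asarray(x0, dtype=float)
    step_lengths = []
    for _ in range(max_iter):
        r = residuals(x)
        J = jacobian(x)
        g = J.T @ r                                  # gradient of the cost
        if np.linalg.norm(g) < gtol:
            break
        dx = np.linalg.lstsq(J, -r, rcond=None)[0]   # Gauss-Newton direction
        f0 = 0.5 * (r @ r)
        slope = g @ dx                               # negative for a descent direction
        t = 1.0
        while t > 1e-10:
            r_trial = residuals(x + t * dx)
            if 0.5 * (r_trial @ r_trial) <= f0 + c1 * t * slope:
                break                                # sufficient decrease reached
            t *= shrink                              # otherwise backtrack
        x = x + t * dx
        step_lengths.append(t)
    return x, step_lengths

# Toy problem: a single residual arctan(x), started far from the root at 0.
residuals = lambda x: np.arctan(x)
jacobian = lambda x: np.array([[1.0 / (1.0 + x[0] ** 2)]])
x_opt, step_lengths = gauss_newton_armijo(residuals, jacobian, [3.0])
print(x_opt, step_lengths)
```

On this problem the first step is cut back to a quarter of the full Gauss-Newton step, after which full steps are accepted and the iteration converges rapidly to the root.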

Exponential spectral scaling (ESS)

vmec_jax provides create_x_scale() for per-DOF scaling that de-emphasises high-mode harmonics:

\[w_i = \exp(-\alpha \cdot \max(|m_i|, |n_i|)) \;/\; \exp(-\alpha)\]

This is passed as x_scale to run() and is analogous to SIMSOPT’s diff_step but acts on the Gauss-Newton step rather than the FD step size. SIMSOPT does not have a built-in equivalent; one would need to manually scale the DOF vector before passing to SciPy.
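The weight formula above is easy to evaluate directly. The sketch below implements only the formula; `ess_weights` is a hypothetical stand-in, and the actual signature and behaviour of create_x_scale() may differ.

```python
import numpy as np

def ess_weights(m_modes, n_modes, alpha):
    """Per-DOF scale w_i = exp(-alpha * max(|m_i|, |n_i|)) / exp(-alpha)."""
    order = np.maximum(np.abs(m_modes), np.abs(n_modes))
    return np.exp(-alpha * order) / np.exp(-alpha)

m_modes = np.array([0, 1, 1, 2, 2])
n_modes = np.array([1, 0, -1, 1, -2])
weights = ess_weights(m_modes, n_modes, alpha=np.log(2.0))
print(weights)
```

With \(\alpha = \ln 2\), modes of order 1 keep a weight of 1 and each additional unit of mode order halves the weight, so the two order-2 modes here receive 0.5.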

Source code comparison

vmec_jax

import vmec_jax as vj
from vmec_jax._compat import enable_x64
import numpy as np

enable_x64(True)

cfg, indata = vj.load_config("input.nfp4_QH_warm_start")
static       = vj.build_static(cfg)
boundary     = vj.boundary_from_indata(indata, static.modes)
indata, static, boundary = vj.extend_boundary_for_max_mode(indata, static, boundary, max_mode=2)

specs  = vj.boundary_param_specs(boundary, static.modes, max_mode=2,
                                 include=("rc","zs"), fix=("rc00",))
params0 = np.zeros(len(specs))

# helicity_n=-1 in field-period units = helicity_n=4 in SIMSOPT full-torus units
residuals_fn = vj.make_qh_residuals_fn(
    static, indata, helicity_m=1, helicity_n=-1,
    target_aspect=7.0, surfaces=np.arange(0, 1.01, 0.1),
)
opt    = vj.FixedBoundaryExactOptimizer(static, indata, boundary, specs, residuals_fn)
result = opt.run(params0, max_nfev=15, ftol=1e-3, gtol=1e-3, xtol=1e-3)

opt.save_wout("wout_final.nc", result["x"])
opt.save_history("history.json", result)

SIMSOPT + VMEC2000

from simsopt.mhd import Vmec, QuasisymmetryRatioResidual
from scipy.optimize import least_squares
import numpy as np

vmec = Vmec("input.nfp4_QH_warm_start", verbose=False)
vmec.run()

surf = vmec.boundary
surf.fix_all()
surf.fixed_range(mmin=0, mmax=2, nmin=-2, nmax=2, fixed=False)
surf.fix("rc(0,0)")

qs = QuasisymmetryRatioResidual(vmec, np.arange(0, 1.01, 0.1), helicity_m=1, helicity_n=4)

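def residuals(x):
    # Push the trial DOFs into the boundary; VMEC is re-run on the next evaluation
    surf.x = x
    return qs.residuals()

# No jac argument: SciPy falls back to its default 2-point finite differences
result = least_squares(residuals, surf.x, method='lm', max_nfev=15,
                       ftol=1e-3, gtol=1e-3, xtol=1e-3)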

The vmec_jax version is self-contained (no Fortran binary, no subprocess), runs in a single Python process, and produces a more accurate result in fewer effective evaluations.

Optional local SIMSOPT checks

SIMSOPT comparisons are optional integration checks, not required PR gates. They are intended for developers who have SIMSOPT installed locally and want to verify shared formulas or reproduce cross-backend optimization diagnostics.

Formula-level checks can be run with:

RUN_SIMSOPT_VALIDATION=1 pytest -q tests/test_simsopt_optional_validation.py
RUN_SIMSOPT_VALIDATION=1 pytest -q tests/test_redl_bootstrap_simsopt_parity.py
RUN_SIMSOPT_VALIDATION=1 pytest -q tests/test_finite_beta_helpers_unit.py::test_redl_bootstrap_formula_matches_simsopt_when_available

The dedicated SIMSOPT validation test is additionally gated by RUN_SIMSOPT_VALIDATION=1 so that required CI remains independent of a local SIMSOPT checkout. These tests use pytest.importorskip for SIMSOPT modules, so they skip when SIMSOPT is not installed. They may also skip if optional runtime dependencies such as jax or netCDF4 are unavailable.

The heavier optimization comparison script is local-only by default:

python examples/optimization/compare_omnigenity_qs_mode1.py

That script writes summaries under its configured output directory and catches SIMSOPT-side failures into a failure JSON so the vmec_jax leg can still be inspected. Do not put this workflow in required CI; if CI coverage is desired, run it from a scheduled/manual job with a pinned SIMSOPT environment and the VMEC2000 executable available through SIMSOPT.

Practical guidance: when to use which

Use vmec_jax when:

  • You need high-quality gradients (exact Jacobians) for sensitive optimization problems — e.g., near the optimum where FD errors matter.

  • You want GPU acceleration without code changes.

  • You want end-to-end differentiability through the optimizer (e.g., meta-learning, hyperparameter gradients).

  • The parameter space has many DOFs (exact Jacobian scales better than FD).

  • You prefer a self-contained Python install without Fortran dependencies.

Use SIMSOPT when:

  • You need access to SIMSOPT’s broader ecosystem: free-boundary, coil optimization (simsopt.field), bootstrap current targets, Boozer transforms, etc.

  • You want the fastest individual VMEC solve on CPU — the VMEC2000 Fortran binary is faster per iteration for small problems.

  • You need MPI parallelism for large finite-difference Jacobians (SIMSOPT parallelises FD columns across MPI workers; vmec_jax does not require MPI because the Jacobian is cheap).

References

[Helander2008]

Helander, P. and Simakov, A. N. (2008). Intrinsic ambipolarity and rotation in stellarators. Physical Review Letters, 101, 145003. https://doi.org/10.1103/PhysRevLett.101.145003
