Performance notes
This page describes the performance characteristics of vmec_jax, the
algorithmic and implementation choices that determine them, and practical
tuning advice.
Overview: cold vs warm runtime
vmec_jax uses XLA JIT compilation (via JAX). The first call in a
process compiles the iteration kernels; subsequent calls reuse the compiled
code:
- Cold run: includes Python/JAX startup and XLA compilation (a one-time cost per process). Depending on input size and whether the persistent cache is warm, typical fixed-boundary CPU cases take from a few seconds to tens of seconds.
- Warm run: the steady-state solve time once the compiled kernels are in memory. This is the fair comparison against VMEC2000.
Persistent XLA compilation caching is enabled automatically for
accelerator-selected runs. CPU cache use is opt-in because XLA:CPU persistent
cache entries are native AOT executables and can emit host-feature mismatch
errors on some JAX versions. Compiled kernels are stored under
~/.cache/vmec_jax/jax_cache/<machine-fingerprint> unless the user sets
VMEC_JAX_COMPILATION_CACHE_DIR or JAX's own JAX_COMPILATION_CACHE_DIR.
The fingerprint encodes host CPU details plus Python/JAX/jaxlib versions so that
native XLA:CPU AOT executables compiled for an incompatible runtime are never reused.
Set VMEC_JAX_COMPILATION_CACHE=1 to opt in for CPU runs, or
VMEC_JAX_COMPILATION_CACHE=0 to disable the persistent cache.
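The cache variables are read when vmec_jax (and JAX) are imported, so they must be set first. The stdlib sketch below shows that ordering and the idea behind a per-host fingerprint. It is not vmec_jax code: host_cache_dir and configure_cpu_cache are hypothetical helpers, and the hash only imitates the fingerprint described above (vmec_jax computes its own when VMEC_JAX_COMPILATION_CACHE_DIR is unset).

```python
import hashlib
import os
import platform
import sys
from importlib import metadata
from typing import Optional


def _dist_version(name: str) -> str:
    """Installed version of a distribution, or "none" if it is not installed."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "none"


def host_cache_dir(base: str = "~/.cache/vmec_jax/jax_cache") -> str:
    """Hypothetical per-host cache path: hash CPU details plus runtime versions.

    A change in CPU, Python, jax, or jaxlib yields a different directory, so
    native AOT executables are never shared across incompatible runtimes.
    """
    parts = [
        platform.machine(),
        platform.processor(),
        sys.version.split()[0],
        _dist_version("jax"),
        _dist_version("jaxlib"),
    ]
    digest = hashlib.sha256("|".join(parts).encode()).hexdigest()[:16]
    return os.path.join(os.path.expanduser(base), digest)


def configure_cpu_cache(enable: bool = True, cache_dir: Optional[str] = None) -> None:
    """Opt in to (or disable) the persistent cache; call before importing vmec_jax."""
    os.environ["VMEC_JAX_COMPILATION_CACHE"] = "1" if enable else "0"
    if enable and cache_dir is not None:
        os.environ["VMEC_JAX_COMPILATION_CACHE_DIR"] = cache_dir


# Example: opt in on CPU with a per-host directory under a custom root.
# configure_cpu_cache(True, cache_dir=host_cache_dir("/shared/vmec_jax_cache"))
# import vmec_jax  # import only after the environment is configured
```

Leaving cache_dir unset keeps vmec_jax's default fingerprinted location.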
VMEC2000 is a precompiled Fortran binary with no JIT overhead, so it has no
separate cold/warm distinction. When benchmarking, compare vmec_jax warm
runtime against VMEC2000 runtime.
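To keep the two numbers apart in your own benchmarks, time the first call on its own and report a median over later calls. The stdlib sketch below does that; time_cold_and_warm is a hypothetical helper and fn stands for whatever solve entry point you benchmark. Its first-call figure covers compilation only, not interpreter start-up or imports; for a true cold-process number, time the whole subprocess as the matrix wrapper does. Because JAX dispatches asynchronously, pass a blocking sync such as jax.block_until_ready when fn returns JAX arrays, otherwise the timer can stop before the device work finishes.

```python
import statistics
import time
from typing import Any, Callable, Tuple


def time_cold_and_warm(
    fn: Callable[..., Any],
    *args: Any,
    repeats: int = 3,
    sync: Callable[[Any], Any] = lambda out: out,
) -> Tuple[float, float]:
    """Return (first_call_s, warm_median_s) for fn(*args); repeats must be >= 1.

    The first call absorbs one-time work such as JIT compilation; later calls
    reuse compiled code. `sync` must block until results are ready.
    """
    start = time.perf_counter()
    sync(fn(*args))
    first_call = time.perf_counter() - start

    warm = []
    for _ in range(repeats):
        start = time.perf_counter()
        sync(fn(*args))
        warm.append(time.perf_counter() - start)
    return first_call, statistics.median(warm)
```

Report the warm median next to the VMEC2000 runtime; the first-call figure belongs with the cold numbers.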
2026-05-25 rerun snapshot
The current post-performance-lane audit was rerun under
outputs/rerun_20260525_123334. This snapshot is intentionally a
measurement checkpoint, not a new broad speedup claim.
Parity gates:

| Matrix | Cases | Failures | Coverage |
|---|---|---|---|
|  | 6 | 0 | fixed/free boundary, axisymmetry/non-axisymmetry, LASYM false/true |
|  | 1 | 0 | QH warm-start fixed-boundary stage trace against VMEC2000 |
Representative profiler timings:
Wrapper time is the wall time of the complete subprocess launched by
gpu_cpu_performance_matrix.py. Profile time is the measured body of
the diagnostic report. For short fixed-boundary cases, the difference is
mostly Python/JAX startup, import, and first-use compilation overhead.
| Case | Backend | Wrapper time | Profile time | Main exposed cost |
|---|---|---|---|---|
| QH warm-start fixed-boundary | CPU |  |  | cold process/import/JIT dominates this short case |
| QH warm-start fixed-boundary | GPU |  |  | accelerator launch/JIT overhead dominates this short case |
| QH finite-beta fixed-boundary | CPU |  |  | finite-beta solve body |
| QH finite-beta fixed-boundary | GPU |  |  | finite-beta solve body plus launch overhead |
| LASYM pressure fixed-boundary | CPU |  |  | cold process/JIT dominates |
| LASYM pressure fixed-boundary | GPU |  |  | GPU is faster for this short LASYM profile |
| QH mode-2 exact dense Jacobian | CPU |  |  | replay/tangent construction |
| QH mode-2 exact dense Jacobian | GPU |  |  | JVP-only tape plus replay/tangent construction |
| QI Boozer/residual isolation | CPU |  |  | VMEC solve |
| QI Boozer/residual isolation | GPU |  |  | VMEC solve |
Low-hanging finding from the rerun: the GPU exact-callback profile was already
using the JVP-only exact tape path, but the matrix summary reported
jvp_tape=False because it echoed only the requested CLI flag. The profiler
and report comparator now record requested and effective JVP/basepoint-carry
metadata separately and infer effective behavior from the profile timers. This
was a reporting fix rather than a solver-path change.
The exposed remaining bottlenecks are still:
- exact dense-Jacobian accepted-point replay/tangent construction,
- first-call Boozer/QI setup on GPU,
- cold-process import/JIT cost for short CLI solves.
Key performance decisions
The following design choices explain the current performance profile. They reduce warm runtime and are especially useful inside differentiable optimization workflows, but the present single-solve CPU matrix should still be read as a mixed result rather than a broad VMEC2000 speedup claim:
- 1. Single-grid default for fixed-boundary
The default CLI path skips staged multigrid schedules and goes directly to the final grid. VMEC2000 uses an NS_ARRAY continuation schedule; for many cases the intermediate stages add overhead without improving convergence. Pass --parity to force the VMEC2000 continuation schedule.
- 2. Backend-aware fixed-boundary default
Public CLI runs default to the VMEC-control non-scan loop on CPU because the auto-selected accelerated CPU policy can lose the scan win once VMEC-style progress printing and staged finish behavior are included. Quiet API/profiler CPU runs may select scan automatically for high-work decks, while low-mode CPU cases stay on the lower-latency host loop. Public GPU/CUDA/ROCm fixed-boundary runs default to the jax.lax.scan loop because the May 2026 profiles on office show materially better warmed raw throughput on accelerator backends. Each scan chunk compiles to a single XLA program and returns only scalar convergence diagnostics to the host. Pass use_scan=False or profiler --no-use-scan to force the VMEC-control loop.
- 3. Dynamic replay bucketing
VMEC_JAX_DYNAMIC_REPLAY_BUCKET pads nearby exact-adjoint tape lengths so compiled replay kernels can be reused across Jacobian calls with slightly different iteration counts. The default is backend-adaptive: 32 on CPU and 128 on CUDA/ROCm/GPU backends. Override it only as a profiling knob; larger buckets can still be slower if they pad a workload too aggressively.
- 4. Preconditioner caching
The 1-D preconditioner is JIT-compiled once and reused; its compiled caches are managed through clear_preconditioner_jit_caches. Explicit cache clearing between Gauss-Newton Jacobian evaluations (post_jacobian_callback) releases accumulated caches without paying the cost of recompilation on the next call.
- 5. Strict-update + no backtracking
The VMEC iteration algorithm uses strict_update=True, backtracking=False to match the VMEC2000 step-accept path. With backtracking enabled, the adaptive step size collapses to machine epsilon on stiff 3D geometries (e.g. QH stellarators) and the solve does not converge.
- 6. FFT-based spectral synthesis
vmec_jax uses FFT-based real-space synthesis (vmec_tomnsp.py) for the R/Z/λ transforms, replacing the matrix–vector approach of VMEC2000. This reduces transform cost from O(N_modes × N_real) to O(N_real log N_real).
- 7. Relaxed trial residuals in optimisation line search
The Gauss-Newton line search uses a relaxed forward solve (fewer iterations, looser ftol) for trial step evaluations. The tight solve is only run for accepted steps and Jacobian builds. This roughly halves the per-iteration wall time in optimisation loops.
- 8. Discrete-adjoint Jacobian (solver-consistent, not finite differences)
For optimisation, the Jacobian is computed via discrete-adjoint replay (build_residual_checkpoint_tape_direct + checkpoint_tape_state_jvp_columns). This gives a Jacobian consistent with the converged VMEC iteration tape without finite differences. The cost scales with the number of boundary degrees of freedom because every parameter tangent column must be propagated through the tape.
- 9. Cached quasisymmetry angular grids
QS optimisation callbacks reuse the same angular quadrature grid and trigonometric tables across residual and Jacobian calls. The public quasisymmetry_ratio_residual_from_wout API still builds these arrays on demand, but make_qs_residuals_fn/make_qh_residuals_fn precompute them once for the fixed optimisation problem. This reduces cold residual/Jacobian trace overhead without changing the residual values.
- 10. Minimal history in optimisation loops
The light_history=True flag suppresses the full per-step diagnostic record during optimisation solves, reducing host/device traffic and memory pressure.
- 11. On-disk XLA kernel cache
The persistent XLA compilation cache is enabled by default for repeated cold-process accelerator runs when the requested JAX platform is GPU/CUDA/ROCm/TPU or a public call explicitly requests solver_device="gpu". Merely exposing a GPU through CUDA_VISIBLE_DEVICES no longer enables the cache during import, so CPU runs on GPU workstations do not accidentally load CPU AOT cache entries. CPU cache use remains opt-in to avoid XLA:CPU AOT host-feature mismatch warnings. Set VMEC_JAX_COMPILATION_CACHE=1 to enable it for CPU runs, VMEC_JAX_COMPILATION_CACHE=0 to disable it, or VMEC_JAX_COMPILATION_CACHE_DIR to choose the cache location.
- 12. GPU demand allocation
Before importing JAX, vmec_jax sets XLA_PYTHON_CLIENT_PREALLOCATE=false by default unless the user has already configured the allocator. This keeps GPU memory available for worker/profiling processes and improved the exact-Jacobian replay profile on office. Set XLA_PYTHON_CLIENT_PREALLOCATE=true or VMEC_JAX_GPU_PREALLOCATE=1 before import to keep JAX’s preallocation default.
- 13. Exact tape path by default
Accepted-point exact optimizer callbacks use the discrete-adjoint tape path by default on both CPU and GPU. The scan exact path is still available through VMEC_JAX_OPT_EXACT_PATH=scan for targeted profiling and parity studies, but it is not the GPU default.
- 14. Backend-aware scan policy for trial residuals
Relaxed trial residual solves in optimization loops use a backend-aware policy: CPU and current GPU/CUDA/ROCm exact-optimization runs stay on the VMEC-control non-scan path because high-mode trial-solve profiles showed the scan path paying too much cold compile/dispatch overhead. TPU keeps the scan default. Set VMEC_JAX_OPT_TRIAL_SCAN=1 or 0 for targeted diagnostics.
- 15. Fused accelerator update step
The exact optimizer’s strict fixed-boundary accepted-point solve uses a cached JIT helper for the velocity/state update on non-CPU backends. This removes many small eager GPU dispatches per VMEC iteration while leaving the CPU host-update path unchanged. Set VMEC_JAX_JIT_STRICT_UPDATE=0 only for diagnostics.
- 16. Residual-derived accepted-point history
For standard QS residual factories, accepted-point Jacobian callbacks reconstruct aspect-ratio and QS history metrics directly from the cached residual vector. This avoids a second accepted-point solve after every Jacobian callback. The conservative state-backed path is still used for custom residuals and for histories that request an explicit iota_fn.
- 17. Device-resident exact-tape final state
The direct discrete-adjoint tape keeps the accepted-point final packed state as a JAX array until a public NumPy output boundary is reached. This avoids an unnecessary device-to-host transfer during exact Jacobian callbacks and is especially relevant for GPU replay profiling.
- 18. Async initial-state projection
The exact optimizer no longer blocks immediately after the JIT initial-state projection. The projected state is consumed by subsequent JAX solver work, so synchronizing there only adds a host/device barrier on accelerator runs. Set VMEC_JAX_OPT_SYNC_INITIAL_STATE=1 for diagnostics that need to attribute initial-state projection time independently.
- 19. Experimental accelerator host-update assembly
The production non-scan GPU path is still limited by host/device scalar synchronization in VMEC-style time-step control. For lower-level diagnostics, VMEC_JAX_HOST_UPDATE_ON_ACCELERATOR=1 permits an explicitly requested host-update assembly path on GPU/accelerator backends. The public driver diagnostic override VMEC_JAX_HOST_UPDATE_ASSEMBLY=1 remains CPU/non-scan only. This path is not a production default because it moves more intermediate data through the host and its benefit is case-dependent.
- 20. GPU JVP-only exact tapes with basepoint carries
VMEC_JAX_OPT_JVP_ONLY_EXACT_TAPE=1 and the diagnostics flag --jvp-only-exact-tape omit reverse-mode base-carry storage from accepted exact tapes when only JVP columns are needed. As of the 2026-05-23 office GPU profile, accelerator exact callbacks enable this path by default and also keep basepoint carries (VMEC_JAX_JVP_ONLY_EXACT_TAPE_BASEPOINT_CARRIES) enabled automatically, unless the user sets either environment variable explicitly. CPU keeps the established full-tape default.

The production policy is deliberately narrower than the diagnostic knobs: JVP-only without basepoint carries is slower on GPU, and the forced scan exact path is still a profiling/parity tool rather than a default.
QH mode-2 accepted Jacobian profile, office, 2026-05-22:

| Device | JVP-only | Total | Tape build | Replay | RSS growth |
|---|---|---|---|---|---|
| CPU | off | 31.373 s | 6.635 s | 11.246 s | 1934.6 MiB |
| CPU | on | 28.199 s | 3.432 s | 18.724 s | 2760.3 MiB |
| GPU | off | 20.356 s | 8.395 s | 5.543 s | 922.4 MiB |
| GPU | on | 56.347 s | 8.043 s | 41.794 s | 1378.2 MiB |
| GPU | on + basepoint carries | 24.249 s | 11.069 s | 6.461 s | 922.2 MiB |
A 2026-05-23 bounded QH mode-2 exact-Jacobian profile on office showed the production GPU lane clearly: the full tape measured 44.1 s total, JVP-only without carries measured 44.7 s, and JVP-only with basepoint carries measured 15.1 s. CPU still uses the full tape by default.
- 21. NumPy multigrid interpolation for ordinary stage transfers
Non-differentiated NS_ARRAY stage transfers now use the host NumPy path for radial VMEC interpolation, while traced/autodiff calls stay on the JAX implementation. This avoids one-time XLA dispatch during ordinary multigrid startup without changing the interpolation convention or the final residual.
- 22. NumPy initialization fast path for ordinary CPU solves
Public CPU performance-mode fixed-boundary runs now build non-traced initial guesses under the same NumPy compatibility shim used by the CPU force loop. Differentiated/traced calls and explicit JAX/GPU paths still use the JAX implementation. The pure-NumPy VMEC axis recompute scan is also vectorized. The CPU update path is now workload-sensitive: low-mode non-scan solves keep the host NumPy update path, while larger spectral/radial stages use the fused strict-update JIT path. On the local finite-beta QH multigrid profile this reduced solve-only wall time from 21.6 s to 11.9 s while keeping the converged residual at 5.6e-13. Set VMEC_JAX_HOST_UPDATE_CPU_WORK_LIMIT to tune the threshold.
- 23. CPU host solves use an adaptive NumPy preconditioner apply
The non-scan CPU host-update path uses the pure-NumPy R/Z preconditioner apply for short solves (default threshold VMEC_JAX_NUMPY_PRECOND_MAX_ITER=240) and for moderate/high spectral mode counts (default threshold VMEC_JAX_NUMPY_PRECOND_MIN_MODES=16, with mode count mpol * (ntor + 1)). This avoids cold JAX preconditioner compilation in optimization trial/probe solves and also helps long CPU host solves when the R/Z preconditioner apply is bandwidth/dispatch dominated. Low-mode long solves keep the compiled preconditioner because it still amortizes better. On repeated local finite-beta QH multigrid profiles this reduced in-process wall time from the previous 15.61 s baseline to 13.62–14.29 s without a material regression in the small QH warm-start profile (1.64–1.67 s). Set VMEC_JAX_NUMPY_PRECOND_MAX_ITER=0 and VMEC_JAX_NUMPY_PRECOND_MIN_MODES=0 to disable this policy.
- 24. Fixed-boundary startup avoids free-boundary imports
Ordinary fixed-boundary CLI/API runs no longer import or validate the free-boundary/NESTOR mgrid path. Free-boundary inputs still validate and prepare mgrids whenever LFREEB is active. This reduced the import vmec_jax.cli startup time on the development Mac from about 0.61 s to about 0.36 s after the related public-API lazy import pass.
- 25. Accelerator scan preflight uses a one-step cached runner
VMEC2000-style scan solves sometimes need a one-step preflight after an initial axis reset. On accelerator backends, running that step through the host/non-JIT path can dominate short LASYM profiles even when the scan kernel itself is fast. Quiet GPU/CUDA/ROCm/TPU scan runs now use the same cached one-step scan runner for preflight. CPU and verbose/printing paths keep the conservative host preflight. Set VMEC_JAX_SCAN_JIT_PREFLIGHT=0 or 1 to force either behavior during diagnostics.
- 26. Multigrid profiles expose per-stage wall time
profile_fixed_boundary.py now records multigrid stage wall time and per-stage solver-loop timing in diagnostics. Chunked accelerated stages are aggregated before reporting so stage wall time and VMEC-loop timing are comparable, and scan stages report their scan_total_s timing through the same generic per-stage solve column as non-scan stages. On the local finite-beta QH case, the two CPU stages measured about 5.1 s and 20.5 s wall time, while the aggregated final-stage solver-loop timing was about 20.45 s. A VMEC_JAX_JIT_PRECOMPILE=0 probe was slower (29.2 s total), so the remaining cold-CPU target is compiled transform/preconditioner work inside the VMEC iteration loop plus shape-stable reuse, not WOUT output or driver-side staging overhead.
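Item 6 (FFT-based spectral synthesis) rests on a standard identity: evaluating a truncated Fourier series on a uniform angular grid is an inverse DFT, so the dense mode-by-grid matrix product can be replaced by one FFT. The 1-D NumPy sketch below illustrates only that identity; synth_direct and synth_fft are hypothetical names, and vmec_tomnsp.py works on VMEC's 2-D (theta, zeta) grids with its own mode conventions, which are not shown here.

```python
import numpy as np


def synth_direct(a: np.ndarray, b: np.ndarray, n_theta: int) -> np.ndarray:
    """f(theta_j) = sum_m a_m cos(m theta_j) + b_m sin(m theta_j) via dense matrices.

    Cost is O(n_modes * n_theta): one basis entry per (grid point, mode) pair.
    """
    theta = 2.0 * np.pi * np.arange(n_theta) / n_theta
    m = np.arange(a.size)
    phase = np.outer(theta, m)  # shape (n_theta, n_modes)
    return np.cos(phase) @ a + np.sin(phase) @ b


def synth_fft(a: np.ndarray, b: np.ndarray, n_theta: int) -> np.ndarray:
    """Same series via one inverse FFT, O(n_theta log n_theta).

    Uses a cos(m t) + b sin(m t) = Re[(a - i b) exp(i m t)]; needs n_modes <= n_theta.
    """
    if a.size > n_theta:
        raise ValueError("need n_modes <= n_theta")
    spectrum = np.zeros(n_theta, dtype=complex)
    spectrum[: a.size] = a - 1j * b
    # numpy's ifft divides by n_theta, so multiply it back out.
    return (n_theta * np.fft.ifft(spectrum)).real
```

Both functions return the same grid values to rounding error; only the cost model differs.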
May 2026 policy validation snapshot
The detailed tables below combine the latest
outputs/rerun_20260525_123334 matrix with earlier May 2026 targeted probes.
Use the rerun snapshot above as the current end-to-end timing checkpoint; older
rows are retained as profiling provenance when they isolate a particular code
path. The main conclusion is unchanged: fixed-boundary GPU scan throughput can
be competitive on warmed/profiled stages, but full subprocess wall time and
exact-optimizer callbacks still have large startup, tape/replay, and tangent
construction costs.
| Case | Default warm | Accelerated warm | Warm speedup | Final residual |
|---|---|---|---|---|
|  |  |  |  |  |
|  |  |  |  |  |
The same finite-beta deck also passed the optional converged VMEC2000 external
parity test against ~/bin/xvmec2000 on the same checkout
(1 passed in 382 s), so the CPU policy change did not trade accuracy
for speed on this promoted reference.
| Case | Wall time | Iterations | Dominant exposed phase | Final residual |
|---|---|---|---|---|
|  |  |  | scan dispatch |  |
|  |  |  | scan dispatch |  |
The finite-beta GPU run uses scan for the explicit-stage monitor chunks and
precomputed tridiagonal coefficients. This replaced the previous non-scan
stage-monitor profile (30.21 s wall time, dominated by 12.81 s in the
preconditioner) with a single cached scan runner hit and no scan-cache misses.
The profiled solve body is approximately VMEC2000-class for the promoted
finite-beta deck (~/bin/xvmec2000 measured 3.32 s on the same input in
the local external parity run), but the current subprocess wrapper still
measured 14.97 s because it includes startup/import/JIT. Force-kernel
tuning alone is no longer the first raw fixed-boundary GPU bottleneck for this
case.
2026-06-02 fixed-boundary and exact-callback checkpoint
After the fixed-boundary scan timing reporter was promoted, the same two
fixed-boundary references were rerun from clean timing subprocesses. The
--use-input-niter report stem now records input_niter instead of the
default diagnostic iters20 label, so these artifacts are easier to compare
against VMEC2000 runs that use the input deck budgets.
| Case | Backend | vmec_jax scan body | VMEC2000 total computational time | Notes |
|---|---|---|---|---|
|  | CPU |  |  | 450 iterations; warm scan body is VMEC2000-class |
|  | GPU |  |  | small warm solve remains GPU dispatch dominated |
|  | CPU |  |  | 2459 iterations; close to VMEC2000, but still slower |
|  | GPU |  |  | 2508 iterations; GPU dispatch/preconditioner launch overhead dominates |
The same checkpoint also reran the optimization-critical QH exact dense
Jacobian path on the office RTX A4000. The current default basepoint
dynamic replay remains faster than VMEC_JAX_DYNAMIC_REPLAY_MODE=whole_scan:
mode-2 warm callbacks measured about 3.47 s with basepoint replay versus
3.76 s with whole-scan replay, and mode-3 warm callbacks measured about
3.60 s versus 3.80 s. The remaining warm GPU cost is therefore not a
mode-selection issue. The repeated hotspots are accepted tape-build solve
dispatch, preconditioner apply, update-state dispatch, and projected
replay/residual-projection dispatch.
| Policy | Backend | Wall time | Scan total | Preflight | Device run | Final residual |
|---|---|---|---|---|---|---|
| host preflight | CPU |  |  |  |  |  |
| host preflight |  |  |  |  |  |  |
| cached one-step preflight |  |  |  |  |  |  |
This profile is intentionally a short, non-converged raw throughput probe. It
shows why the accelerator preflight path matters: the actual GPU scan kernel
was already fast, but the old host preflight dominated the short LASYM run.
The patched one-step preflight keeps the same residual and cuts the GPU wall
time by about 3.9x for this diagnostic. The promoted converged QH decks
remain stable under the same policy: input.nfp4_QH_warm_start measured
0.689 s with final residual 1.11e-13, and
input.nfp4_QH_finite_beta measured 3.45 s with final residual
5.61e-13 on office.
| Backend | Wall time | Tape build | Tape replay | Initial tangents | Residual tangents |
|---|---|---|---|---|---|
| CPU |  |  |  |  |  |
|  |  |  |  |  |  |
This callback has the same Jacobian shape and norm on CPU/GPU
(44353 x 24, norm 5.705e2), so the remaining gap is performance rather
than a different solve. The latest GPU run used the effective JVP-only exact
tape path with basepoint carries. The next optimizer-performance work should
target accepted-point tape build/replay and tangent construction; forcing scan
for this mode-2 dense callback is not the production fix. Projected replay is
now enabled for non-LASYM GPU callbacks with at least 24 columns because later
bounded profiles show a small mode-2 win, but the remaining bottleneck is still
accepted-point tape/replay dispatch rather than residual-projection kernel math.
Older same-point cache probes still show that repeated identical callbacks can
return from cache in milliseconds, so the production cost is at new accepted
optimizer points rather than redundant repeated evaluations.
| Backend | Wall time | Tape build | Tape replay | Initial tangents | Residual tangents |
|---|---|---|---|---|---|
| CPU |  |  |  |  |  |
|  |  |  |  |  |  |
The perturbed profile uses --perturb-scale 1e-4 to mimic a new accepted
optimizer point. It has the same Jacobian shape and norm on CPU/GPU
(44353 x 24, norm 4.976e2). An earlier explicit
--solver-device gpu matrix took 53.18 s because the optimizer wrapped
every callback in a redundant jax.default_device context even though the
process backend was already GPU. The current policy treats explicit requests
for the already-active backend as a no-op, removing that overhead while still
preserving explicit CPU/GPU routing when the requested backend differs from the
active process backend.
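The no-op rule for an already-active backend amounts to returning an empty context instead of a routing context. The sketch below is hypothetical: maybe_device_context and its make_ctx factory are illustrative names, not vmec_jax functions. In a JAX program, make_ctx could be something like lambda: jax.default_device(jax.devices(requested)[0]).

```python
import contextlib
from typing import Callable, ContextManager, Optional


def maybe_device_context(
    requested: Optional[str],
    active: str,
    make_ctx: Callable[[], ContextManager],
) -> ContextManager:
    """Return a device-routing context only when it would change anything.

    requested: backend asked for by the caller ("cpu", "gpu", ...) or None.
    active: backend the process is already running on.
    make_ctx: builds the real routing context (hypothetical factory).
    """
    if requested is None or requested.lower() == active.lower():
        # Already on the requested backend: skip the wrapper on every callback.
        return contextlib.nullcontext()
    return make_ctx()
```

Because the check is cheap and happens per callback, explicit CPU/GPU routing is still honored when the requested and active backends differ.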
| Case | Backend | Wall time | Tape build | Replay path | Replay time | Initial tangents | Residual tangents |
|---|---|---|---|---|---|---|---|
| QH | CPU |  |  | standard |  |  |  |
| QH |  |  |  | projected |  |  |  |
| QA | CPU |  |  | standard |  |  |  |
| QA |  |  |  | projected |  |  |  |
The mode-3 callbacks have 48 boundary parameters. GPU replay now uses
8-column chunks for this size and enables projected replay for
stellarator-symmetric callbacks. The 2026-05-30 office rerun supersedes
the older 24-column policy for these shapes.
| Backend / policy | Wall time | Tape build | Replay time | Initial tangents | Residual tangents |
|---|---|---|---|---|---|
| CPU standard |  |  |  |  |  |
|  |  |  |  |  |  |
LASYM doubles the active boundary families, so the mode-2 callback also reaches
48 columns. It should not use the symmetric mode-3 projected-replay heuristic:
the measured LASYM projected path took 51.1 s total. Chunking remains the
right default for this tested LASYM matrix, and the current GPU chunk policy is
the same 8-column policy used for stellarator-symmetric mode-2/mode-3
callbacks. Explicit environment overrides are still available for profiling,
but the production automatic policy is now:
- GPU callbacks with at least 24 columns use 8-column replay chunks.
- Projected replay is only auto-enabled for stellarator-symmetric callbacks with at least 24 columns.
- LASYM uses the conservative chunked standard replay path unless explicitly overridden.
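The automatic policy above, together with the backend-adaptive replay bucket from item 3 of the key-decisions list, can be restated as small pure functions when reading profiles. This is a hypothetical sketch, not vmec_jax code: the names, the behavior on CPU and below 24 columns, and the padding rule (round up to a bucket multiple) are assumptions; the 24-column threshold, 8-column chunks, and 32/128 buckets come from this page.

```python
from dataclasses import dataclass
from typing import Optional

# Thresholds quoted on this page; everything else here is assumed.
GPU_BACKENDS = ("gpu", "cuda", "rocm")
GPU_CHUNK_MIN_COLUMNS = 24
GPU_CHUNK_SIZE = 8
REPLAY_BUCKET_CPU = 32
REPLAY_BUCKET_GPU = 128


@dataclass(frozen=True)
class ReplayPolicy:
    chunk_columns: Optional[int]  # None: propagate all tangent columns at once
    projected: bool  # projected replay instead of standard replay


def _is_gpu(backend: str) -> bool:
    return backend.lower() in GPU_BACKENDS


def replay_policy(backend: str, n_columns: int, lasym: bool) -> ReplayPolicy:
    """Hypothetical restatement of the automatic GPU replay policy."""
    if not _is_gpu(backend) or n_columns < GPU_CHUNK_MIN_COLUMNS:
        # Assumption: CPU and small GPU callbacks use unchunked standard replay.
        return ReplayPolicy(chunk_columns=None, projected=False)
    # Large GPU callbacks are chunked; only stellarator-symmetric ones project.
    return ReplayPolicy(chunk_columns=GPU_CHUNK_SIZE, projected=not lasym)


def bucketed_tape_length(n_steps: int, backend: str) -> int:
    """Round a tape length up to a bucket multiple (assumed padding rule).

    Nearby iteration counts then share one padded shape, so one compiled
    replay kernel can serve them. Non-GPU backends (including TPU) fall back
    to the CPU bucket here, which is an assumption.
    """
    bucket = REPLAY_BUCKET_GPU if _is_gpu(backend) else REPLAY_BUCKET_CPU
    return -(-n_steps // bucket) * bucket  # ceiling division, then scale
```

For a 24-column non-LASYM callback on GPU, replay_policy returns 8-column chunks with projected replay, matching the rule stated earlier for non-LASYM GPU callbacks with at least 24 columns.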
CPU/GPU profiling playbook
Use the diagnostics scripts from the repository root. Always record whether a
timing is cold or warm, the selected JAX backend, JAX/JAXLIB versions, input
deck, max_mode or iteration budget, and whether the command is measuring
raw solver throughput or optimization callback overhead.
For sweep and matrix outputs, do not infer the runtime backend from an output
directory or --backend-label alone. Use the JSON/CSV provenance fields
(jax_backend, jax_device_kind, solver_device, jax_platforms)
and the profiler runtime block when comparing CPU/GPU results.
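A small guard can make this provenance rule mechanical before any CPU/GPU comparison. The sketch is hypothetical: provenance and load_and_check are illustrative helpers, and the assumption that each report stores the four fields either at the top level or inside a runtime mapping is an assumption about the layout, not a documented schema.

```python
import json
from typing import Any, Dict

PROVENANCE_KEYS = ("jax_backend", "jax_device_kind", "solver_device", "jax_platforms")


def provenance(report: Dict[str, Any]) -> Dict[str, Any]:
    """Collect provenance fields from the top level or an assumed 'runtime' block."""
    runtime = report.get("runtime") or {}
    return {key: report.get(key, runtime.get(key)) for key in PROVENANCE_KEYS}


def load_and_check(path: str, expected_backend: str) -> Dict[str, Any]:
    """Load a JSON report and reject it if the recorded backend is not the expected one."""
    with open(path, encoding="utf-8") as fh:
        report = json.load(fh)
    prov = provenance(report)
    actual = prov["jax_backend"]
    if actual is None or str(actual).lower() != expected_backend.lower():
        raise ValueError(f"{path}: expected backend {expected_backend!r}, report says {actual!r}")
    return prov
```

Checking the recorded backend, rather than trusting a directory name or --backend-label, catches runs that silently fell back to CPU.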
For optimization work, keep three measurements separate:
- raw fixed-boundary throughput, measured with profile_fixed_boundary.py;
- accepted-point exact optimizer callbacks, measured with profile_exact_optimizer.py or gpu_cpu_performance_matrix.py;
- QI Boozer/residual cost, measured with profile_qi_boozer_gpu.py before attributing a slow QI run to VMEC replay or SciPy trust-region behavior.
For CPU-only timings, force a CPU process with JAX_PLATFORMS=cpu or an
explicit --solver-device cpu. For GPU timings, use
JAX_PLATFORM_NAME=gpu plus --solver-device gpu. On NVIDIA-only JAX
installs, JAX_PLATFORMS=cuda is also valid; avoid JAX_PLATFORMS=gpu on
mixed CUDA/ROCm installations because some JAX versions try to initialize both
backends.
Raw fixed-boundary throughput:
```bash
# CPU: force the VMEC-control (non-scan) loop.
JAX_PLATFORMS=cpu JAX_ENABLE_X64=1 PYTHONPATH=. python tools/diagnostics/profile_fixed_boundary.py \
    --input examples/data/input.nfp4_QH_warm_start \
    --iters 20 \
    --simple-profile \
    --no-multigrid \
    --no-auto-cli-policy \
    --solver-mode accelerated \
    --no-use-scan \
    --solver-device cpu \
    --json-out /tmp/vmec_jax_qh20_raw_cpu.json

# GPU: no scan flag, so the production scan policy applies.
JAX_PLATFORM_NAME=gpu JAX_ENABLE_X64=1 PYTHONPATH=. python tools/diagnostics/profile_fixed_boundary.py \
    --input examples/data/input.nfp4_QH_warm_start \
    --iters 20 \
    --simple-profile \
    --no-multigrid \
    --no-auto-cli-policy \
    --solver-mode accelerated \
    --solver-device gpu \
    --json-out /tmp/vmec_jax_qh20_raw_gpu.json
```
The profiler follows production scan policy when neither --use-scan nor
--no-use-scan is supplied. Use --use-scan only to force scan on a CPU
run or diagnostic comparison, and use --no-use-scan to force the
Python/VMEC-control loop on GPU.
Short exact-optimizer runs:
```bash
# CPU
JAX_PLATFORMS=cpu PYTHONPATH=. python tools/diagnostics/profile_exact_optimizer.py \
    --problem qh --max-mode 3 --max-nfev 2 --run-repeats 3 \
    --inner-max-iter 120 --trial-max-iter 120 --solver-device cpu \
    --json-out /tmp/vmec_jax_qh_m3_cpu.json

# GPU
JAX_PLATFORM_NAME=gpu PYTHONPATH=. python tools/diagnostics/profile_exact_optimizer.py \
    --problem qh --max-mode 3 --max-nfev 2 --run-repeats 3 \
    --inner-max-iter 120 --trial-max-iter 120 --solver-device gpu \
    --vmec-timing --json-out /tmp/vmec_jax_qh_m3_gpu.json
```
Accepted-point Jacobian callback cost at realistic new optimizer points:
```bash
# CPU
JAX_PLATFORMS=cpu PYTHONPATH=. python tools/diagnostics/profile_exact_optimizer.py \
    --problem qh --max-mode 2 --callback jacobian --repeats 3 \
    --perturb-scale 1e-4 --inner-max-iter 80 --trial-max-iter 40 \
    --solver-device cpu --vmec-timing --json-out /tmp/qh_m2_cpu_jacobian.json

# GPU
JAX_PLATFORM_NAME=gpu PYTHONPATH=. python tools/diagnostics/profile_exact_optimizer.py \
    --problem qh --max-mode 2 --callback jacobian --repeats 3 \
    --perturb-scale 1e-4 --inner-max-iter 80 --trial-max-iter 40 \
    --solver-device gpu --vmec-timing --json-out /tmp/qh_m2_gpu_jacobian.json
```
Compare the JSON reports before launching a full sweep or a long GPU run:
```bash
PYTHONPATH=. python tools/diagnostics/compare_profile_reports.py \
    /tmp/qh_m2_cpu_jacobian.json /tmp/qh_m2_gpu_jacobian.json \
    --label cpu --label gpu \
    --json-out /tmp/qh_m2_cpu_gpu_comparison.json
```
The comparison summary exposes exact optimizer phases separately:
exact_tape_build_s, exact_tape_build_unattributed_s,
initial_tangents_s, initial_projection_s, residual_tangents_s,
trial_solve_s, and exact_solve_s. Use these fields before changing
solver kernels: in recent GPU runs the dominant cost has been accepted-point
tape/replay and tangent construction, not VMEC force assembly.
The same tool also prints an “Exact optimizer patch targets” section that
ignores enclosing timers such as jacobian_total and
exact_solve_with_tape_total. Treat that row as the next concrete profiling
target when a total timer is largest; it points at leaf-like accepted-point
work such as *_tape_replay, exact_tape_build_unattributed,
*_initial_vjp, or *_residual_tangents.
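The patch-target rule (drop enclosing timers, then take the largest remaining one) can be sketched as follows. This is a hypothetical restatement, not the logic inside compare_profile_reports.py: it assumes every enclosing timer carries a _total suffix, as jacobian_total and exact_solve_with_tape_total do, and that the timers arrive as a flat name-to-seconds mapping.

```python
from typing import Mapping, Optional, Tuple


def patch_target(timers: Mapping[str, float]) -> Optional[Tuple[str, float]]:
    """Return (name, seconds) of the largest leaf-like timer, or None.

    Assumption: enclosing timers are exactly those whose names end in "_total".
    """
    leaves = {name: secs for name, secs in timers.items() if not name.endswith("_total")}
    if not leaves:
        return None
    name = max(leaves, key=leaves.__getitem__)
    return name, leaves[name]
```

Excluding the totals matters because an enclosing timer is always at least as large as any timer it contains, so it would otherwise win every time.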
QI Boozer/residual isolation:
```bash
# CPU
JAX_PLATFORMS=cpu PYTHONPATH=. python tools/diagnostics/profile_qi_boozer_gpu.py \
    --solver-device cpu --repeat 2 --jit-booz \
    --output results/diagnostics/qi_boozer_cpu.json

# GPU
JAX_PLATFORM_NAME=gpu PYTHONPATH=. python tools/diagnostics/profile_qi_boozer_gpu.py \
    --solver-device gpu --repeat 2 --jit-booz \
    --output results/diagnostics/qi_boozer_gpu.json
```
The user-facing QI optimization helpers default to the jitted Boozer path
(QuasiIsodynamicOptions(jit_booz=True)). Keep the non-jitted profiler mode
available for parity isolation; use --jit-booz for production-like QI
timings.
May 2026 QI Boozer/GPU split
On the bundled input.nfp2_QI low-resolution diagnostic
(mpol=ntor=mboz=nboz=3, one Boozer surface, repeat=3), --jit-booz
improved the QI/Boozer phase on both backends:
- CPU: total profiler time 20.17 s without Boozer JIT and 18.73 s with Boozer JIT; the first QI/Boozer call improved from 8.38 s to 6.78 s.
- office GPU: total profiler time 27.39 s without Boozer JIT and 20.88 s with Boozer JIT; the first QI/Boozer call improved from 11.01 s to 4.98 s.
The remaining gap in this small case comes from the VMEC solve, not from the QI
residual itself. A raw 80-iteration fixed-boundary input.nfp2_QI profile
with the accelerated single-grid path measured 0.65 s on CPU and 3.11 s
on GPU; forcing the scan path measured 0.55 s on CPU and 3.35 s on GPU.
The next GPU work should therefore target fixed-boundary force/update kernel
launch and tape structure, while keeping Boozer JIT enabled for QI production
runs.
A bounded QH warm-start forced non-scan profile on 2026-05-23
(input.nfp4_QH_warm_start, 20 iterations, --no-warmup) measured
9.70 s total on local CPU and 7.48 s total on office GPU. The GPU
was faster end-to-end in that cold profile, but the VMEC phase timers still
showed slower accelerator micro-kernels: compute_forces was 1.03 s on
GPU versus 0.094 s on CPU, and update was 0.133 s on GPU versus
0.008 s on CPU. This is retained as historical forced-path evidence. The
current policy is narrower: ordinary GPU fixed-boundary solves use the scan path
by default, while exact accepted-point GPU Jacobians use the tape path with
JVP-only/basepoint carries enabled by default.
A follow-up LASYM=true finite-beta fixed-boundary profile on office
(input.basic_non_stellsym_pressure, 20 iterations, forced non-scan
diagnostic path, no warmup) measured 21.38 s profile time on CPU and
26.15 s on GPU. GPU force assembly was faster (1.29 s versus 2.31 s), but GPU
preconditioner time was slower (2.48 s versus 1.09 s), so the next raw
GPU solve target for this lane is preconditioner/update launch structure rather
than the force kernel alone. Forcing the scan path was slower on both devices
in the same case: 35.76 s CPU and 44.96 s GPU profile time.
The next office pass isolated the preconditioner piece: enabling
precomputed Thomas coefficients for the same historical raw LASYM GPU profile reduced
wrapper time from 30.50 s to 11.33 s and detailed preconditioner time
from 2.66 s to 0.74 s. Short VMEC2000 trace parity was then checked for
the LASYM finite-pressure case and the QH warm-start case. The public
production policy now enables this path automatically for non-scan,
performance-mode accelerator LASYM solves, and for higher-mode accelerator
non-LASYM solves where the mode count is large enough to amortize the
coefficient setup. Small non-LASYM decks keep the legacy default: the
2026-05-23 QH warm-start profile regressed from 13.4 s to 14.4 s when
precomputation was forced, while the 50-mode finite-beta QH profile improved
from 52.6 s to 49.4 s with the same iteration count and final residual.
Set VMEC_JAX_TRIDI_PRECOMPUTE=0 to disable this narrow accelerator default
during diagnostics, or =1 to force it in lower-level experiments.
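Precomputing Thomas coefficients helps because the forward-elimination factors depend only on the tridiagonal matrix, not on the right-hand side. When the same system is applied many times, the factorization is done once and each apply becomes two divide-free sweeps; the up-front factorization is also why small decks that apply it too few times can regress. The pure-Python sketch below shows that split. factor_tridiagonal and solve_factored are hypothetical names, not vmec_jax functions, and the real preconditioner's data layout and batching are not shown.

```python
from typing import List, Sequence, Tuple


def factor_tridiagonal(
    lower: Sequence[float], diag: Sequence[float], upper: Sequence[float]
) -> Tuple[List[float], List[float]]:
    """Precompute Thomas-algorithm coefficients once per matrix.

    Row i reads lower[i] * x[i-1] + diag[i] * x[i] + upper[i] * x[i+1];
    lower[0] and upper[-1] are ignored. Returns (c_prime, inv_denom).
    No pivoting: assumes a well-conditioned (e.g. diagonally dominant) matrix.
    """
    n = len(diag)
    c_prime = [0.0] * n
    inv_denom = [0.0] * n
    for i in range(n):
        denom = diag[i] - (lower[i] * c_prime[i - 1] if i > 0 else 0.0)
        inv_denom[i] = 1.0 / denom
        c_prime[i] = upper[i] * inv_denom[i] if i < n - 1 else 0.0
    return c_prime, inv_denom


def solve_factored(
    lower: Sequence[float],
    c_prime: Sequence[float],
    inv_denom: Sequence[float],
    rhs: Sequence[float],
) -> List[float]:
    """Apply the factored system to a new right-hand side: two O(n) sweeps."""
    n = len(rhs)
    d_prime = [0.0] * n
    for i in range(n):  # forward sweep reuses the stored reciprocals
        prev = lower[i] * d_prime[i - 1] if i > 0 else 0.0
        d_prime[i] = (rhs[i] - prev) * inv_denom[i]
    x = [0.0] * n
    for i in reversed(range(n)):  # back substitution
        x[i] = d_prime[i] - (c_prime[i] * x[i + 1] if i < n - 1 else 0.0)
    return x
```

The factorization runs once per matrix; every later right-hand side only pays for solve_factored.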
The same commit was profiled with a one-callback QH mode-2 exact Jacobian
matrix on office (inner_max_iter=40, trial_max_iter=20,
--sync-replay-timing). The GPU profile was faster end-to-end than CPU
(15.45 s versus 29.88 s), with GPU accepted exact solve 7.20 s,
tape build 5.45 s, replay 6.16 s, and initial tangent construction
2.09 s. CPU spent 11.02 s in the exact solve, 7.16 s in tape build,
6.72 s in replay, 8.19 s in initial tangents, and 3.95 s in
residual tangents. This keeps the current GPU optimization target focused on
cold accepted-point tape/replay/tangent construction; forcing scan for GPU is
not justified by these profiles.
A matched run with --jvp-only-exact-tape --jvp-only-basepoint-carries was
nearly neutral: CPU profile time was 29.97 s and GPU profile time was
14.95 s. That is a small GPU win for this one callback, but not enough to
change the default without larger mode and full optimizer-trajectory coverage.
A 2026-05-23 GPU sidecar pass on office at 4d61eab used one QH mode-2
cold Jacobian callback (inner_max_iter=20, trial_max_iter=20,
--sync-replay-timing) plus matched trial and raw-LASYM probes. Full-tape
exact took 13.29 s on CPU and 36.81 s on GPU; the GPU regression was
the accepted replay dispatch bucket (21.10 s dispatch, 0.00 s ready),
not tangent construction (GPU initial/residual tangents 2.04 s/2.46 s).
The same GPU callback with
--jvp-only-exact-tape --jvp-only-basepoint-carries dropped to 14.48 s
and replay to 3.62 s, but the matched CPU JVP-only callback regressed to
27.86 s with 8.96 s in initial tangents. Treat this as a GPU-only
candidate that needs mode-3 and full-trajectory validation before changing
defaults.
The same sidecar confirmed the trial-scan policy. Forced scan trial solves
spent 6.76 s in the CPU scan block and 11.68 s in the GPU scan block;
the GPU device bucket was almost entirely dispatch/compile-like
(10.24 s dispatch, 0.001 s ready). Disabling scan reduced the trial
profiles to 4.46 s on CPU and 6.76 s on GPU. Later QH max_mode=2
profiles briefly favored a patched scan GPU default, but the 2026-06-02
office QH max_mode=4 optimizer trace reversed that policy: forced
non-scan trial solves reduced the short run from 89.23 s to 77.73 s and
the trial solve itself from 15.95 s to 4.56 s. Trial callbacks
therefore default to non-scan on CPU/GPU/CUDA/ROCm and scan only on TPU; use
VMEC_JAX_OPT_TRIAL_SCAN=1 or 0 to force either path.
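The trial-scan default can be read as a backend rule with an environment override. In the sketch below, resolve_trial_scan is a hypothetical name, not a vmec_jax function. The sketch recognizes only the literal values "1" and "0". How other values are handled is not documented here.

```python
import os
from typing import Mapping, Optional


def resolve_trial_scan(backend: str, env: Optional[Mapping[str, str]] = None) -> bool:
    """Decide whether relaxed trial solves use the lax.scan runner."""
    env = os.environ if env is None else env
    override = env.get("VMEC_JAX_OPT_TRIAL_SCAN", "").strip()
    if override == "1":
        return True
    if override == "0":
        return False
    # Documented default: scan only on TPU; CPU/GPU/CUDA/ROCm use the non-scan loop.
    return backend.strip().lower() == "tpu"
```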
For historical raw LASYM=true input.basic_non_stellsym_pressure non-scan solves
(20 iterations, single-grid, no warmup, detailed timing), the same pass
measured 7.98 s on CPU and 8.84 s on GPU. Force assembly was comparable
(1.29 s CPU, 1.27 s GPU), GPU update was faster (0.14 s versus
0.41 s), and GPU preconditioner apply remained slower (0.44 s versus
0.27 s). The current scan-lifted profile above has moved that short-case
target to preflight and scan-runner setup rather than force assembly alone.
A larger historical QH mode-3 exact-Jacobian callback on office showed
that the pre-JVP-only projected-replay residual path was the wrong default for
GPU. With projected replay enabled the one-callback profile took 69.06 s
wall time (63.53 s profile time, 33.40 s replay, 10.00 s
projected residual tangents). The same callback with projected replay
disabled took 21.59 s wall time (16.29 s profile time, 4.29 s
replay, 2.38 s residual tangents). Later JVP-only/basepoint-carry
profiles changed that conclusion for non-LASYM callbacks, so production GPU
exact callbacks now enable projected replay for stellarator-symmetric cases
with at least 24 columns. VMEC_JAX_OPT_PROJECTED_REPLAY_RESIDUALS remains
available as an explicit diagnostic override.
For raw input.nfp2_QI follow-up profiling, keep the production-like scan
measurement separate from phase attribution. The scan path is best inspected
with XProf traces because the force/preconditioner/update work is inside one
lax.scan program. The non-scan path supports lightweight JSON phase timing
with --vmec-timing:
PYTHONPATH=. python tools/diagnostics/gpu_cpu_performance_matrix.py \
--mode fixed-boundary --backend cpu --backend gpu --keep-going \
--input examples/data/input.nfp2_QI \
--iters 80 \
--solver-mode accelerated \
--single-grid \
--raw-solver-policy \
--use-scan \
--outdir outputs/performance_profiles/qi_raw_scan80
JAX_PLATFORMS=cpu JAX_ENABLE_X64=1 PYTHONPATH=. python tools/diagnostics/profile_fixed_boundary.py \
--input examples/data/input.nfp2_QI \
--iters 20 \
--simple-profile \
--no-warmup \
--no-multigrid \
--no-auto-cli-policy \
--solver-mode accelerated \
--solver-device cpu \
--vmec-timing \
--vmec-timing-detail \
--json-out outputs/performance_profiles/qi_raw_cpu_phase20.json
JAX_PLATFORM_NAME=gpu JAX_ENABLE_X64=1 PYTHONPATH=. python tools/diagnostics/profile_fixed_boundary.py \
--input examples/data/input.nfp2_QI \
--iters 20 \
--simple-profile \
--no-warmup \
--no-multigrid \
--no-auto-cli-policy \
--solver-mode accelerated \
--solver-device gpu \
--vmec-timing \
--vmec-timing-detail \
--json-out outputs/performance_profiles/qi_raw_gpu_phase20.json
The fixed-boundary raw-solve source path is:
vmec_jax/driver.py::run_fixed_boundary: accelerated/single-grid policy.
vmec_jax/solve.py::solve_fixed_boundary_residual_iter: the VMEC loop and the _run_vmec2000_scan scan runner.
vmec_jax/solve.py::_compute_forces: per-iteration force assembly and residual norms.
vmec_jax/vmec_forces.py::vmec_forces_rz_from_wout and vmec_residual_internal_from_kernels: bcovar/force/tomnsps assembly.
vmec_jax/vmec_tomnsp.py::tomnsps_rzl: the Fourier-transform hot kernel.
The fixed-boundary profiler now keeps any diagnostics.timing block in its
JSON, and compare_profile_reports.py / gpu_cpu_performance_matrix.py
surface vmec_compute_forces_s, vmec_preconditioner_s, and
vmec_update_s when those timings exist.
The comparison table reports ratios for total runtime, compile/replay/cache
time when those timings exist, callback count, observed RSS peak, solve count,
accepted-point replay count, and cache growth. The JSON output is stable enough
for CI dashboards or follow-up scripts that track whether a GPU regression is
coming from tape replay, extra callbacks, cache retention, or a cold
compile-like phase.
It also emits a cheap bottleneck_hint for each report by selecting the
largest exposed phase timing (QI/Boozer first call, VMEC force assembly,
preconditioner/update, accepted-point replay, compile, or cache). Treat this
as a triage pointer only: it cannot identify work hidden inside a fused
lax.scan unless the source profile exported phase timings.
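The triage rule above amounts to taking the argmax over exposed phase timings. This sketch assumes the report is a flat mapping from phase names to seconds, with None for phases that were not measured, such as work fused inside lax.scan. Here bottleneck_hint is a hypothetical stand-in for the tool's field. The example keys reuse the timing names listed earlier.

```python
from typing import Mapping, Optional, Tuple


def bottleneck_hint(
    phase_seconds: Mapping[str, Optional[float]],
) -> Optional[Tuple[str, float]]:
    """Return (phase, seconds) for the largest exposed phase, or None."""
    # Unmeasured (None) or zero phases cannot be selected.
    exposed = {name: s for name, s in phase_seconds.items() if s is not None and s > 0.0}
    if not exposed:
        return None
    worst = max(exposed, key=lambda name: exposed[name])
    return worst, exposed[worst]
```

A phase hidden inside a fused scan reports None and can never win. This is why the hint is only a pointer.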
For repeatable CPU/GPU matrix launches, use the wrapper below. Its default
--backend auto does not set JAX_PLATFORMS or JAX_PLATFORM_NAME; the
child process inherits the user’s active JAX/GPU selection. Add explicit
--backend cpu or --backend gpu only when the comparison should launch
separate CPU/GPU processes.
PYTHONPATH=. python tools/diagnostics/gpu_cpu_performance_matrix.py \
--mode fixed-boundary \
--backend auto \
--input examples/data/input.nfp4_QH_warm_start \
--use-input-niter \
--public-cli-policy \
--allow-multigrid \
--outdir outputs/performance_profiles/qh20_auto
PYTHONPATH=. python tools/diagnostics/gpu_cpu_performance_matrix.py \
--mode exact-callback \
--backend cpu --backend gpu --keep-going \
--problem qh --max-mode 2 --callback jacobian --repeats 3 \
--perturb-scale 1e-4 --inner-max-iter 80 --trial-max-iter 40 \
--method scipy_matrix_free --dynamic-replay-mode whole_scan \
--vmec-timing \
--outdir outputs/performance_profiles/qh_m2_cpu_gpu
For fixed-boundary production-policy profiles, omit both --use-scan and
--no-use-scan so the child profiler follows backend-aware production scan
selection. Add --use-input-niter to use the input deck’s full
NITER/NITER_ARRAY schedule instead of the wrapper’s short --iters
budget.
The wrapper delegates to profile_fixed_boundary.py or
profile_exact_optimizer.py and writes one child JSON per backend plus a
matrix JSON. The printed table reports wrapper wall time, child profile wall
time, and replay time when the child report exposes it. The matrix JSON embeds
the normalized compare_profile_reports.py summary, so dashboards can track
total_runtime_s, replay_time_s, accepted_point_replay_count,
cache_entry_growth, and RSS peak without parsing profiler-specific output
shapes. Each child process also writes *.stdout.log and *.stderr.log
files next to its JSON report so failed GPU jobs can be diagnosed without
rerunning the matrix.
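Because the matrix JSON embeds a normalized summary, a dashboard mainly needs per-backend metric ratios. This minimal sketch assumes each backend's summary has already been loaded as a flat dict keyed by the metric names above. The actual nesting of the matrix JSON is not described here, so the loading step is omitted. backend_ratios is hypothetical.

```python
from typing import Dict, Iterable, Mapping, Optional

TRACKED = (
    "total_runtime_s",
    "replay_time_s",
    "accepted_point_replay_count",
    "cache_entry_growth",
)


def backend_ratios(
    gpu: Mapping[str, float],
    cpu: Mapping[str, float],
    keys: Iterable[str] = TRACKED,
) -> Dict[str, Optional[float]]:
    """GPU/CPU ratio per metric; None if either side is missing or the CPU value is zero."""
    out: Dict[str, Optional[float]] = {}
    for key in keys:
        g, c = gpu.get(key), cpu.get(key)
        out[key] = None if g is None or c is None or c == 0 else g / c
    return out
```

A ratio above 1.0 means the GPU child used more of that resource than the CPU child.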
Use --dry-run before scheduling cluster jobs; it prints and records the
exact child commands and backend environment overrides without importing JAX in
the child. Use --replay-column-chunk or
--dynamic-replay-bucket to make tape/replay tuning explicit in the report.
Use --dynamic-replay-mode and --method to compare accepted-point replay
paths and dense/matrix-free optimizer behavior under the same child-launch
policy.
For cold exact-callback bucket audits, start with one perturbed callback and a
small iteration budget before launching larger GPU sweeps:
PYTHONPATH=. python tools/diagnostics/gpu_cpu_performance_matrix.py \
--mode exact-callback \
--backend cpu --backend gpu --keep-going \
--problem qh --max-mode 2 --callback jacobian --repeats 1 \
--perturb-scale 1e-4 --inner-max-iter 20 --trial-max-iter 20 \
--method scipy_matrix_free --vmec-timing --sync-replay-timing \
--outdir outputs/performance_profiles/qh_m2_cold_bucket_smoke
The matrix table surfaces exact_s, tape_build_s,
tape_unattr_s, replay_s, replay_dispatch_s,
replay_ready_s, init_tangent_s, resid_tangent_s, callbacks,
and replays when the child exact-callback report exposes those metrics.
When the child reports contain the detailed buckets, the wrapper also prints
separate Trial scan timing, Scan cache details, and
Projected replay / JVP details tables and embeds the same sections in the
matrix JSON. Use those sections to distinguish scan-cache misses from
projected-replay residual tangent cost before changing solver kernels. The
scan-cache diagnostic helpers also expose stable cache-miss categories
(iteration_budget, tolerance, scan_policy,
stage_transition, fallback_policy, etc.) so profiler summaries can
report why a trial solve missed the cached scan runner without relying on raw
tuple offsets.
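Named cache-key fields make stable miss categories possible. A comparison by field name survives key reordering, while tuple offsets do not. The sketch below uses a hypothetical ScanCacheKey dataclass. Its field set is an assumption chosen to mirror the categories listed above, not vmec_jax's actual key.

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass(frozen=True)
class ScanCacheKey:
    """Hypothetical cache key for a compiled trial-solve scan runner."""

    iteration_budget: int
    tolerance: float
    scan_policy: str
    stage: int
    fallback_policy: str


# Map key fields to reported miss categories (a stage change reports as stage_transition).
_MISS_CATEGORY = {
    "iteration_budget": "iteration_budget",
    "tolerance": "tolerance",
    "scan_policy": "scan_policy",
    "stage": "stage_transition",
    "fallback_policy": "fallback_policy",
}


def miss_categories(cached: ScanCacheKey, requested: ScanCacheKey) -> List[str]:
    """List the category of every field that differs between two keys."""
    return [
        _MISS_CATEGORY[f.name]
        for f in fields(ScanCacheKey)
        if getattr(cached, f.name) != getattr(requested, f.name)
    ]
```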
Use --sync-replay-timing only for targeted cold-bucket diagnostics: it
adds block_until_ready synchronization so dispatch and device-ready buckets
are attributable, but that synchronization is not representative of production
sweep throughput.
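The dispatch/ready split works for any JAX call because JAX dispatches asynchronously. A call returns once the work is enqueued, and jax.Array.block_until_ready() waits for the device result. This minimal sketch assumes fn returns a single array (use jax.block_until_ready for pytrees). The helper name is hypothetical and is not part of the diagnostics scripts.

```python
import time
from typing import Any, Callable, Tuple


def dispatch_ready_split(fn: Callable[..., Any], *args: Any) -> Tuple[Any, float, float]:
    """Return (result, dispatch_seconds, ready_seconds) for one call.

    dispatch: time until fn returns (tracing, compilation, and launch under JAX);
    ready: additional wait until the device result is materialized.
    """
    t0 = time.perf_counter()
    out = fn(*args)
    t1 = time.perf_counter()
    out.block_until_ready()  # jax.Array method; blocks until the computation finishes
    t2 = time.perf_counter()
    return out, t1 - t0, t2 - t1
```

For example, `out, dispatch_s, ready_s = dispatch_ready_split(jax.jit(f), x)`. On a cold call, jit compilation happens before the call returns, so it lands in dispatch_s. The forced wait is also why this option distorts sweep throughput: it stops the host from queuing the next step while the device is still busy.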
Malformed VMEC_JAX_REPLAY_COLUMN_CHUNK values now fall back to the automatic
replay memory guard rather than aborting the Jacobian callback; set
VMEC_JAX_REPLAY_COLUMN_CHUNK=off or 0 only when chunking should be
disabled for a targeted profiling run. The optimizer-specific
VMEC_JAX_LASYM_REPLAY_COLUMN_CHUNK override follows the same safe parsing
policy, so malformed values fall back to the backend/input auto heuristic.
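The safe-parsing policy can be sketched as a three-way result: an explicit chunk size, chunking disabled, or the automatic guard. parse_replay_column_chunk and the AUTO sentinel are hypothetical. Treating negative values as malformed is an assumption.

```python
import os
from typing import Mapping, Optional, Union

AUTO = "auto"  # sentinel: let the backend/input memory guard choose


def parse_replay_column_chunk(
    env: Optional[Mapping[str, str]] = None,
    name: str = "VMEC_JAX_REPLAY_COLUMN_CHUNK",
) -> Union[int, str, None]:
    """Return a positive chunk size, None (chunking disabled), or AUTO."""
    env = os.environ if env is None else env
    raw = env.get(name, "").strip().lower()
    if not raw:
        return AUTO
    if raw == "off":
        return None
    try:
        chunk = int(raw)
    except ValueError:
        return AUTO  # malformed value: fall back instead of aborting the callback
    if chunk == 0:
        return None
    return chunk if chunk > 0 else AUTO
```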
Use --trace-outdir for TensorBoard/XProf traces and
--device-memory-profile-out for JAX device-memory snapshots when GPU memory
or launch overhead is the bottleneck. Use --no-auto-cli-policy only when
you want raw solver throughput; omit it when measuring the public
run_fixed_boundary policy that users see through the CLI/API.
Exact optimizer profiling¶
Use tools/diagnostics/profile_exact_optimizer.py to time the exact
optimization callback stack for the QA, QH, and QP quasisymmetry objectives:
PYTHONPATH=. python tools/diagnostics/profile_exact_optimizer.py \
--problem qa --max-mode 2 --max-nfev 2 \
--trial-max-iter 300 --trial-ftol 1e-10
Use --problem qp for the quasi-poloidal fixed-boundary profile preset. QI
uses Boozer-space residuals and promotion gates, so profile it with the
dedicated tools/diagnostics/profile_qi_boozer_gpu.py path before launching
full QI optimization sweeps.
The callback profile reports separate timings for relaxed trial solves, exact
tape construction, checkpoint-tape JVP replay, residual tangent projection, and
wout writing. The profiler does not compute initial aspect/QS metrics by
default because that requires an exact solve that is immediately cleared before
the measured callback/run; add --initial-metrics only when you want that
sanity check and do not need cold-start timing purity. On the current
exact-adjoint implementation, the dominant
term for max_mode=2 and max_mode=3 is
jacobian_tape_replay. Ordinary fixed-boundary solves can benefit from GPU
lax.scan after warmup. For exact optimization, accepted-point Jacobians use
the discrete-adjoint tape path on both CPU and GPU by default. May 2026
office RTX A4000 profiling showed that forced scan accepted-point Jacobians
are still useful for parity/profiling experiments but can be much slower than
tape on production-like cold GPU callbacks. Set
VMEC_JAX_OPT_EXACT_PATH=tape or VMEC_JAX_OPT_EXACT_PATH=scan to force
one accepted-point path for parity or profiling. Relaxed trial residuals are
backend-aware: CPU/GPU/CUDA/ROCm default to the VMEC-control non-scan loop, and
TPU defaults to scan. May 2026 bounded QH mode-2 profiles measured local CPU
non-scan trial callbacks at 3.64 s total for two repeats with sane
residuals; forced CPU scan was slower and produced less reliable finite-step
diagnostics. A 2026-06-02 office RTX A4000 QH mode-4 optimizer trace then
showed the non-scan GPU trial path reducing the trial solve from 15.95 s to
4.56 s. Set VMEC_JAX_OPT_TRIAL_SCAN=1 or 0 to force either path for
diagnostics.
solver_device=None, "auto", and "default" inherit JAX’s active
backend; pass solver_device="cpu" or "gpu" only when you want an explicit
override.
The tape implementation is backend-aware. CPU accepted-point Jacobians keep the
full-tape default. GPU/CUDA/ROCm accepted-point Jacobians default to JVP-only
tapes while preserving basepoint carries so the fast basepoint replay path
remains available. VMEC_JAX_OPT_JVP_ONLY_EXACT_TAPE=0 restores the full
GPU tape, =1 forces JVP-only on any backend, and
VMEC_JAX_JVP_ONLY_EXACT_TAPE_BASEPOINT_CARRIES=0 disables the GPU carry
default for profiling. The diagnostics wrappers expose the pair as
--jvp-only-exact-tape --jvp-only-basepoint-carries so CPU/GPU matrices can
record the exact environment without hand-written shell exports.
Representative May 2026 callback timings after the backend-adaptive replay bucket, scalar-gradient tangent-cache, and GPU replay-chunk changes were:
| Case | Device/path | Budget | Wall time |
|---|---|---|---|
| QH | | | |
| QH | same | | |
| QH | | | |
| QH | local CPU, tape exact | | |
| QH | local CPU, tape exact | | |
These short cases show that GPU is now competitive or faster on some cold
callbacks, but not uniformly faster for production-like mode-2 dense
Jacobians. The production default is therefore still the tape exact path on
GPU, while CPU remains the conservative recommendation for small dense
least-squares optimizations. Forced scan exact remains available through
VMEC_JAX_OPT_EXACT_PATH=scan for targeted diagnostics; a mode-2 forced-scan
GPU probe was stopped after the cold compile exceeded the practical profiling
budget, so it is not a production default. The remaining exact-callback
bottleneck is still split across accepted-point solve/tape construction,
checkpoint-tape replay, initial tangent construction, and residual tangent
projection; further speedups should target reuse/fusion across those stages
rather than raw force-kernel throughput alone.
For same-process warmup studies, repeat a callback at the same point or repeat the whole short optimizer run while keeping compiled executables warm:
JAX_PLATFORM_NAME=gpu PYTHONPATH=. python tools/diagnostics/profile_exact_optimizer.py \
--problem qh --max-mode 2 --callback jacobian --repeats 2 \
--inner-max-iter 80 --trial-max-iter 40 --solver-device gpu
JAX_PLATFORM_NAME=gpu PYTHONPATH=. python tools/diagnostics/profile_exact_optimizer.py \
--problem qh --max-mode 3 --max-nfev 2 --run-repeats 3 \
--inner-max-iter 120 --trial-max-iter 120 --solver-device gpu
For realistic accepted-point studies, perturb the parameter vector on each repeat. This keeps compiled helper shapes warm while forcing a new equilibrium tape/state at each point, matching the cost structure of a real optimizer trajectory more closely than same-point repeats:
JAX_PLATFORM_NAME=gpu PYTHONPATH=. python tools/diagnostics/profile_exact_optimizer.py \
--problem qh --max-mode 2 --callback jacobian --repeats 3 \
--perturb-scale 1e-4 --inner-max-iter 80 --trial-max-iter 40 \
--solver-device gpu --vmec-timing --json-out qh_m2_gpu_new_points.json
With --vmec-timing, the callback profile also splits
exact_tape_build into solver compute-force, preconditioner, update, and
unattributed tape-building overhead terms. Add --vmec-timing-detail when
the preconditioner bucket is the bottleneck; it further reports
exact_tape_solver_preconditioner_apply and
exact_tape_solver_preconditioner_mode_scale. The detailed mode adds extra
synchronization and should be used for targeted diagnostics, not production
sweeps.
The replay diagnostics also include
replay_jvp_columns_* counters for the accepted-point tangent replay path:
which path ran (identity, dynamic-basepoint scan, segmented dynamic-basepoint
scan, dynamic linearization, dynamic-scan linearization, generic per-trace, or
generic scan), how many tangent-column leaf calls were replayed, how many input
columns they covered, and whether column chunking split the callback. This
keeps GPU profiling focused on accepted-point replay/tangent construction
instead of conflating scan-cache misses, replay chunks, and residual projection
in one aggregate callback timer.
Add --sync-replay-timing when the question is whether a cold callback is
spending time in replay dispatch/compile-like overhead or in device-ready
execution. This exposes *_tape_replay_dispatch,
*_tape_replay_ready, *_initial_tangents_vmap_dispatch, and
*_initial_tangents_vmap_ready buckets in the JSON profile. Dense
Jacobian residual projection now reports the same split through
jacobian_residual_tangents_dispatch and
jacobian_residual_tangents_ready; projected replay reports
jacobian_projected_replay_residual_tangents_dispatch and
jacobian_projected_replay_residual_tangents_ready. Keep this option off
for normal CPU/GPU comparison sweeps because the explicit synchronization
changes the measured workload.
A 2026-06-01 office rerun on RTX A4000/JAX 0.6.2 rechecked QH
max_mode=2 dense exact-Jacobian callbacks with perturbed accepted points,
--inner-max-iter 80, and single-GPU sequential timing. For non-LASYM
projected replay, larger chunks were faster than the older conservative
8-column policy: two callbacks took 81.0 s with chunk 24 versus 83.5 s
with chunk 8, and the replay-dispatch bucket dropped from 30.5 s to
25.1 s. Additional cold checks gave 73.2 s for QH max_mode=3 with
chunk 48 versus 81.4 s with chunk 8, and 69.7 s for a reduced
max_mode=4 check with chunk 40 versus 74.0 s with chunk 8. The default
GPU policy therefore keeps LASYM at conservative 8-column chunks, but uses
larger bounded chunks for non-LASYM projected replay. The remaining
bottlenecks are accepted-tape build, replay dispatch/compile-like overhead,
and dense residual-tangent projection.
A 2026-06-02 follow-up on QH max_mode=4 rechecked smaller and disabled
non-LASYM replay chunks after the GPU trial-solve policy change. With the
same inner_max_iter=80 exact-Jacobian callback budget, chunk 20 took
69.4 s and no explicit chunking took 71.0 s. Both were slower than the
current bounded default policy for this case, so the production heuristic was
left unchanged. The actionable GPU target was therefore reducing projected
replay dispatch and residual-tangent projection overhead, not shrinking the
replay chunk size further.
The same mode-4 callback also ruled out two simple replay-mode promotions.
Opting into the fused replay/projection helper with chunking disabled took
69.9 s and spent 38.6 s inside fused projected replay, so it did not
improve on the default bounded chunk policy. Forcing
VMEC_JAX_DYNAMIC_REPLAY_MODE=whole_scan was worse: 81.5 s total with
50.0 s in projected replay. Keep both modes as diagnostics rather than
defaults for high-mode QH until a future implementation changes the underlying
dispatch structure.
The next source change kept the same bounded replay chunks but changed how
chunked projected replay is assembled. For accelerator non-LASYM callbacks
where the parameter count exceeds the replay chunk size, vmec-jax now
projects residual tangents immediately after each replay chunk and concatenates
the resulting Jacobian blocks. This avoids materializing one full
n_params x state_size tangent block before residual projection. A QH
max_mode=4 callback on office with inner_max_iter=80 dropped from
the previous ~69–71 s range to 61.5 s by default; an explicit opt-in
probe of the same path measured 60.1 s. The new timing bucket is
jacobian_chunked_projected_replay_projection_total. Set
VMEC_JAX_OPT_CHUNKED_PROJECTED_REPLAY_PROJECTION=0 to restore the old
full-tangent projection path for diagnostics. A short QH max_mode=4
optimizer run with max_nfev=3 then confirmed the callback result at the
solver level: wall time dropped from the previous dense SciPy 94.8 s trace
to 78.4 s while reaching the same final objective scale
(1.74024e-1). The remaining high-mode GPU buckets are chunked
replay/projection, accepted exact-tape build, and relaxed trial solves.
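The memory effect of chunked projected replay shows up even with plain linear maps. Each chunk of unit parameter tangents is replayed to state tangents, projected to residual tangents at once, and then discarded. As a result, at most chunk state-sized vectors are alive at any time. In this sketch, state_tangent and project are hypothetical callables standing in for tape replay and residual projection. The returned counters only loosely mirror the replay_jvp_columns_* fields.

```python
from typing import Callable, Dict, List, Tuple

Vector = List[float]


def chunked_projected_jacobian(
    state_tangent: Callable[[Vector], Vector],
    project: Callable[[Vector], Vector],
    n_params: int,
    chunk: int,
) -> Tuple[List[Vector], Dict[str, object]]:
    """Return Jacobian columns (one residual-tangent list per parameter) and counters."""
    if chunk <= 0:
        raise ValueError("chunk must be positive")
    columns: List[Vector] = []
    leaf_calls = 0
    for start in range(0, n_params, chunk):
        stop = min(start + chunk, n_params)
        # Replay only this chunk of unit tangent directions through the state map.
        block = [
            state_tangent([1.0 if k == j else 0.0 for k in range(n_params)])
            for j in range(start, stop)
        ]
        leaf_calls += 1
        # Project right away; the state-sized block is dropped before the next chunk.
        columns.extend(project(t) for t in block)
        del block
    counters = {"leaf_calls": leaf_calls, "columns": len(columns), "chunked": leaf_calls > 1}
    return columns, counters
```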
The same diagnostics now cover QP. On office at commit 80ccd51, QP
max_mode=4 with 80 boundary DOFs took 63.8 s for one exact Jacobian
callback, with 31.9 s in
jacobian_chunked_projected_replay_projection_total and 22.5 s in tape
build. A short QP max_mode=4 optimizer run with max_nfev=3 took
70.5 s for one accepted Jacobian plus two trial solves. QP max_mode=5
with 120 boundary DOFs took 68.2 s for one exact Jacobian, with 36.2 s
in chunked projected replay/projection and 22.8 s in tape build. The
mode-5 result confirms that the current chunked projection path scales
reasonably with parameter count; the next production target remains reducing
the replay/projection dispatch and cold accepted-tape build costs.
The same QP max_mode=4 short optimizer was then profiled with
method="auto_scalar". With cost-only scalar trial filtering enabled it took
90.6 s and reduced the objective to 6.32; with cost-only trials disabled
it took 78.1 s and reduced the objective to 6.55. Both scalar-adjoint
runs made more progress than the dense three-evaluation diagnostic, but the
no-cost-only variant removed three extra forward trial solves with only a small
objective penalty. Therefore auto_scalar now leaves
scalar_cost_only_trials disabled by default and keeps it as an explicit
globalization knob.
A follow-up QH max_mode=2 GPU profile on office with
--inner-max-iter 80, --trial-max-iter 40,
VMEC_JAX_DYNAMIC_REPLAY_MODE=basepoint,
VMEC_JAX_REPLAY_COLUMN_CHUNK=0, and --sync-replay-timing confirmed
that the standard dense residual projection is dispatch/tracing dominated:
three perturbed exact-Jacobian callbacks took 13.04 s, 3.98 s, and
3.27 s. The cumulative jacobian_residual_tangents bucket was
2.263 s, split into 2.204 s dispatch and only 0.059 s
device-ready time. The next GPU exact-callback optimization should therefore
reduce Python/JAX dispatch and callback construction around accepted-point
replay/projection before spending effort on residual-projection kernel math.
A 2026-06-02 office profile on the same RTX A4000/JAX 0.6.2 setup then
isolated the accepted-tape preconditioner cost. For QH max_mode=2 with
24 boundary DOFs, three perturbed GPU exact-Jacobian callbacks dropped from
68.16 s to 65.04 s when
VMEC_JAX_OPT_EXACT_TRIDI_PRECOMPUTE_MAX_DOFS=24 was enabled; the first warm
callback dropped from 4.39 s to 3.63 s. For QH max_mode=3 with
48 DOFs and a 60-iteration budget, two callbacks dropped from 64.98 s to
63.70 s with the threshold set to 48; the warm callback dropped from
3.57 s to 3.38 s. The production default therefore precomputes tridi
coefficients for accelerator exact tapes up to 48 DOFs and leaves larger
parameter spaces on the legacy solver policy unless explicitly overridden.
The 2026-05-28 mode-2 GPU policy check also compared regular projected replay
against the fused replay/projection helper. Regular projected replay took
12.66 s for one QH mode-2 exact-Jacobian callback
(jacobian_projected_replay_total=6.16 s), while the fused helper took
13.38 s (jacobian_fused_projected_replay_total=6.25 s). Fusion is
therefore opt-in through VMEC_JAX_OPT_FUSED_PROJECTED_REPLAY=1; the default
uses regular projected replay for non-LASYM GPU mode-2+ cases.
For production cache-growth audits, use the same accepted-point callback mode
with JSON output and explicit budgets. The report records per-repeat phase
deltas, optimizer/global JIT cache cardinalities before and after each repeat,
RSS growth, cumulative callback profile, and a budget_status block. A
budget breach exits with status 2 by default; pass --budget-action warn
to keep the run informational:
JAX_PLATFORMS=cpu PYTHONPATH=. python tools/diagnostics/profile_exact_optimizer.py \
--problem qh --max-mode 2 --callback jacobian --repeats 3 \
--perturb-scale 1e-4 --inner-max-iter 80 --trial-max-iter 40 \
--solver-device cpu --vmec-timing \
--budget-total-wall-s 45 --budget-repeat-wall-s 20 \
--budget-tape-build-wall-s 20 --budget-replay-wall-s 15 \
--budget-residual-tangent-wall-s 10 --budget-accepted-replays 3 \
--budget-cache-entry-growth 12 --budget-rss-growth-mb 1024 \
--json-out /tmp/qh_m2_cpu_callback_cache.json
JAX_PLATFORM_NAME=gpu PYTHONPATH=. python tools/diagnostics/profile_exact_optimizer.py \
--problem qh --max-mode 2 --callback jacobian --repeats 3 \
--perturb-scale 1e-4 --inner-max-iter 80 --trial-max-iter 40 \
--solver-device gpu --vmec-timing \
--budget-total-wall-s 90 --budget-repeat-wall-s 45 \
--budget-tape-build-wall-s 45 --budget-replay-wall-s 30 \
--budget-residual-tangent-wall-s 20 --budget-accepted-replays 3 \
--budget-cache-entry-growth 12 --budget-rss-growth-mb 4096 \
--json-out /tmp/qh_m2_gpu_callback_cache.json
Use --callback accepted when you only want the accepted-point residual/tape
build without dense Jacobian replay. Keep --clear-between-repeats off for
cache-growth audits; enabling it intentionally drops optimizer/JIT caches
between repeats and measures cold callback behavior instead. Malformed
VMEC_JAX_DYNAMIC_REPLAY_BUCKET values fall back to the backend-adaptive
default instead of forcing the CPU-sized bucket on GPU diagnostics. The
--budget-tape-build-wall-s, --budget-replay-wall-s,
--budget-residual-tangent-wall-s, and --budget-accepted-replays limits
are intended for regression guards around the accepted-point tape/replay lane:
they catch extra tape replays or dense residual-tangent projection regressions
even when total callback wall time is noisy.
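The budget gate reduces to comparing observed metrics with limits. Any breach maps to exit status 2 unless the action is warn. This sketch uses hypothetical metric names, and the real budget_status block layout may differ. Skipping limits for metrics that were not measured is an assumption.

```python
from typing import Dict, Mapping, Tuple


def budget_status(
    metrics: Mapping[str, float],
    limits: Mapping[str, float],
    action: str = "fail",
) -> Tuple[Dict[str, dict], int]:
    """Return (breaches, exit_code); exit_code is 2 on breach unless action == 'warn'."""
    breaches = {
        name: {"observed": metrics[name], "limit": limit}
        for name, limit in limits.items()
        if name in metrics and metrics[name] > limit
    }
    exit_code = 2 if breaches and action != "warn" else 0
    return breaches, exit_code
```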
For the standalone sweep scripts, worker subprocesses also inherit the parent
JAX backend by default. Use JAX_PLATFORMS=cpu or
--worker-jax-platforms cpu only when an explicit CPU-only worker process is
desired.
After each accepted-point Jacobian, the optimizer drops the heavy adjoint tape
but keeps a single solved-state cache entry. This keeps RSS bounded while
avoiding a duplicate exact VMEC replay for final objective evaluation and
wout writing. The SciPy callback path also reuses this solved-state cache
when it asks for a residual at a point whose exact Jacobian was just built,
avoiding an otherwise unnecessary relaxed forward replay. If
save_wout(..., state=result["_state_final"]) is used immediately after
run(), no additional equilibrium solve is performed.
The tape Jacobian callback also returns the residual primal from the same
jax.linearize used for residual tangent projection. History recording and
the custom Gauss-Newton gradient path reuse that residual instead of evaluating
the accepted-point residual block a second time after every dense exact
Jacobian.
SciPy can also revisit the same rejected trust-region trial point. The
optimizer therefore keeps a small residual-only LRU cache for relaxed trial
callbacks. The cache stores NumPy residual vectors, not VMEC states or JAX
tapes, so it removes repeated forward solves without retaining large XLA
buffers. To audit the exact callback sequence, run the diagnostic script with
--trace-callbacks; the JSON history then includes a callback_trace block
that labels each residual/Jacobian callback as an exact-state cache hit, a
trial-residual cache hit, a fresh trial solve, or an exact tape replay.
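Below is a minimal sketch of a residual-only LRU cache. It assumes parameter vectors are keyed by their exact float values as a tuple; a real implementation might hash array bytes instead. Only residual tuples are stored, mirroring the point above that no solver state or tapes are retained. The class name is hypothetical.

```python
from collections import OrderedDict
from typing import Callable, Sequence, Tuple

Key = Tuple[float, ...]


class TrialResidualCache:
    """Small LRU cache mapping a parameter vector to its residual vector."""

    def __init__(self, maxsize: int = 4) -> None:
        self.maxsize = maxsize
        self._data: "OrderedDict[Key, Tuple[float, ...]]" = OrderedDict()
        self.hits = 0
        self.misses = 0

    def get_or_solve(
        self, x: Sequence[float], solve: Callable[[Sequence[float]], Sequence[float]]
    ) -> Tuple[float, ...]:
        key = tuple(float(v) for v in x)
        if key in self._data:
            self._data.move_to_end(key)  # mark as most recently used
            self.hits += 1
            return self._data[key]
        self.misses += 1
        value = tuple(solve(x))  # fresh (relaxed) trial solve
        self._data[key] = value
        if len(self._data) > self.maxsize:
            self._data.popitem(last=False)  # evict the least recently used entry
        return value
```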
The lowest accepted-point callback count currently comes from the custom
method="gauss_newton" loop: it uses relaxed residuals only for line-search
trial points and one exact Jacobian per accepted point. The SciPy trust-region
path remains more robust for the documented full QA/QH continuation examples,
but it may request additional Jacobians around trust-region updates. A bounded
CPU diagnostic on QH max_mode=1 with inner_max_iter=trial_max_iter=80
and max_nfev=2 gave:
| Method | Wall time | Jacobian calls | Notes |
|---|---|---|---|
| | | 2 | one cached exact-state residual hit |
| | | 1 | same first accepted step, fewer exact callbacks |
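The callback accounting behind that comparison can be illustrated with a textbook Gauss-Newton loop with backtracking. This minimal sketch assumes a small dense least-squares problem and hypothetical callback names. It is not vmec_jax's method="gauss_newton" implementation. The sketch makes one exact residual-plus-Jacobian call per accepted point and uses cheaper trial residuals only for line-search points. In the real optimizer, those trial residuals come from relaxed solves.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Vec = List[float]


def _solve_dense(A: List[Vec], b: Vec) -> Vec:
    """Small dense solve by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [list(row) + [b[i]] for i, row in enumerate(A)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for k in range(col, n + 1):
                M[r][k] -= f * M[col][k]
    x = [0.0] * n
    for i in range(n - 1, -1, -1):
        x[i] = (M[i][n] - sum(M[i][k] * x[k] for k in range(i + 1, n))) / M[i][i]
    return x


def gauss_newton(
    exact_residual_and_jacobian: Callable[[Vec], Tuple[Vec, List[Vec]]],
    trial_residual: Callable[[Vec], Vec],
    x0: Sequence[float],
    max_steps: int = 20,
    tol: float = 1e-20,
) -> Tuple[Vec, Dict[str, int]]:
    x = [float(v) for v in x0]
    counts = {"jacobian": 0, "trial_residual": 0}
    for _ in range(max_steps):
        # One exact callback per accepted point; it also returns the residual primal.
        r, J = exact_residual_and_jacobian(x)
        counts["jacobian"] += 1
        cost = sum(v * v for v in r)
        if cost <= tol:
            break
        m, n = len(r), len(x)
        JTJ = [[sum(J[k][i] * J[k][j] for k in range(m)) for j in range(n)] for i in range(n)]
        minus_JTr = [-sum(J[k][i] * r[k] for k in range(m)) for i in range(n)]
        step = _solve_dense(JTJ, minus_JTr)
        alpha, accepted = 1.0, False
        while alpha > 1e-8:
            trial = [xi + alpha * si for xi, si in zip(x, step)]
            counts["trial_residual"] += 1
            if sum(v * v for v in trial_residual(trial)) < cost:
                x, accepted = trial, True
                break
            alpha *= 0.5  # backtracking line search
        if not accepted:
            break
    return x, counts
```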
GPU exact-optimizer diagnostics¶
The current GPU bottleneck is the exact Jacobian replay path, not ordinary fixed-boundary force evaluation. The replay is a long sequence of linearized VMEC iteration steps, and cold GPU processes pay heavy XLA compile costs. For GPU profiling, always separate the first run from cache-warm runs:
JAX_PLATFORM_NAME=gpu python tools/diagnostics/profile_exact_optimizer.py \
--problem qa --max-mode 1 --inner-max-iter 20 \
--trial-max-iter 20 --solver-device default --max-nfev 1 \
--trace-callbacks --json-out gpu_trace.json
April 2026 callback diagnostics for the full input-deck QH max_mode=1 case
show where the GPU path loses today. Local CPU used JAX 0.9.2 on an Apple
workstation; office used JAX 0.6.2 on an NVIDIA RTX A4000 host.
| Callback / process | Wall time | Backend | Dominant terms |
|---|---|---|---|
| Trial residual, local CPU | | CPU | trial solve |
| Trial residual, forced GPU | | GPU | trial solve |
| Exact residual/tape, local CPU | | CPU | tape build |
| Exact residual/tape, forced GPU | | GPU | tape build |
| Exact residual/tape, GPU process with CPU device | | GPU process, CPU device | tape build |
| Exact residual/tape, CPU-only process on | | CPU | tape build |
| Dense Jacobian, local CPU | | CPU | tape build |
| Dense Jacobian, forced GPU | | GPU | tape build |
| Dense Jacobian, GPU with demand allocation, inner-10 QH smoke | | GPU | replay |
| Full QH | | GPU | |
This makes the present conclusion narrower and more actionable: the exact
optimizer is not GPU-ready just because the fixed-boundary force kernels are in
JAX. The bottleneck is the accepted-point tape build/replay path. A CPU
jax.default_device context inside a GPU-initialized process is still much
slower than a CPU-only process, but vmec_jax does not force CPU execution for
GPU-enabled users. Use explicit CPU-only workers for controlled CPU studies,
and explicit GPU backends for GPU profiling.
A follow-up QH max_mode=2 GPU trace with max_nfev=4 and
inner_max_iter=trial_max_iter=120 showed the same bottleneck. The warm
run spent about 74 s total: four accepted-point Jacobian callbacks consumed
about 49 s end-to-end, while three relaxed trial solves consumed about
24 s. The trace contained one exact-state residual cache hit and no
repeated trial residuals, so the next GPU optimization lane is reducing
accepted-point tape build/replay cost rather than adding more residual caching.
The perturbed accepted-point profiler separates same-tape cache hits from real
new optimizer points. On office with QH max_mode=2,
inner_max_iter=80, trial_max_iter=40, and --perturb-scale 1e-4, the
default tape path gave three GPU dense-Jacobian callbacks of about
13.8 s, 7.8 s, and 6.9 s. The mean profile was dominated by exact
tape construction (5.4 s per point), with tape replay around 2.1 s per
point. The same CPU run gave about 9.3 s, 5.8 s, and 5.6 s. The
scan exact path remained unsuitable for this workload: the first perturbed GPU
scan Jacobian took about 118 s.
After adding solver-phase timing to the same perturbed callback, the GPU
accepted-point tape build split into about 3.0 s of VMEC update work,
1.3 s of preconditioner work, 1.4 s of unattributed trace/build
overhead, and only 0.10 s of force evaluation per new point. The matching
CPU tape build was about 2.0 s per point and was instead dominated by force
evaluation (1.26 s per point). This identifies the next GPU target more
precisely: reduce host-dispatched update/preconditioner/tape bookkeeping and
replay overhead, not the already-fast GPU force kernels.
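Splits like these are easier to compare across backends as per-phase fractions. The helper below is a minimal sketch and is not part of vmec_jax. It is applied to the GPU per-point numbers quoted above.

```python
def phase_shares(phases):
    """Return (total seconds, {phase: fraction of total}, dominant phase)."""
    total = sum(phases.values())
    shares = {name: seconds / total for name, seconds in phases.items()}
    dominant = max(phases, key=phases.get)
    return total, shares, dominant


# Per-accepted-point GPU tape-build split from the paragraph above (seconds).
gpu_tape_build = {
    "vmec_update": 3.0,
    "preconditioner": 1.3,
    "trace_build_overhead": 1.4,
    "force_evaluation": 0.10,
}

total, shares, dominant = phase_shares(gpu_tape_build)
for name, share in sorted(shares.items(), key=lambda kv: -kv[1]):
    print(f"{name:22s} {share:6.1%}")
print(f"total {total:.2f} s, dominant phase: {dominant}")
```

With these inputs, the VMEC update is about 52% of the GPU build and force evaluation is under 2%. This is the quantitative form of the conclusion above.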
The first mitigation keeps dynamic exact-tape trace arrays on device until the
compact replay payload is assembled. On the warm office GPU QH
max_mode=2 profile above, that reduced mean tape-build time from about
5.75 s to 5.32 s per accepted point and update bookkeeping from about
3.02 s to 2.70 s. It does not solve the main GPU gap; it narrows the
next target to the update/preconditioner replay graph itself.
Splitting the update timer confirmed that conclusion: on the same QH
max_mode=2 GPU profile, exact_tape_solver_update_state accounted for
essentially all of exact_tape_solver_update (about 2.7 s per accepted
point), while trace-build/finalize bookkeeping was below 1 ms per point.
The next implementation target is therefore fusing or scanning the primal
state-update/replay work, not further reducing Python trace dictionary overhead.
The next production patch skips the update_rms reduction when it is not
consumed. Exact optimizer dynamic-tape solves run with light history, no update
clipping, and non-verbose output, so that reduction was wasted work. On the QH
max_mode=2 GPU profile, this reduced mean tape-build time further to about
4.74 s per accepted point and update-state time to about 2.30 s. The
largest remaining per-point GPU costs are now replay (about 2.67 s),
state update (about 2.30 s), and preconditioner work (about 1.10 s).
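The pattern is to compute a diagnostic reduction only when some consumer reads it. The pure-Python sketch below assumes the RMS is consumed only by full history, update clipping, or verbose output, as described above. The function names and the history flag values are hypothetical. In a JAX implementation these flags would be static Python values, so a skipped reduction would not appear in the traced graph at all.

```python
import math


def update_rms(update):
    """Root-mean-square of an update vector."""
    return math.sqrt(sum(u * u for u in update) / len(update))


def needs_update_rms(history="light", clip_updates=False, verbose=False):
    """Hypothetical predicate: only full history, clipping, or verbose output read the RMS."""
    return history == "full" or clip_updates or verbose


def apply_update(state, update, *, history="light", clip_updates=False,
                 verbose=False, max_rms=1.0):
    """Apply an additive update; compute the RMS reduction only if it is consumed."""
    rms = None
    if needs_update_rms(history, clip_updates, verbose):
        rms = update_rms(update)
        if clip_updates and rms > max_rms:
            # Rescale so the clipped update has RMS exactly max_rms.
            update = [u * (max_rms / rms) for u in update]
        if verbose:
            print(f"update_rms={rms:.3e}")
    return [s + u for s, u in zip(state, update)], rms
```

An exact-optimizer style call (light history, no clipping, quiet) returns rms=None and never performs the reduction.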
Dynamic replay payload stacking is backend-aware: GPU uses on-device JAX stacks
to avoid unnecessary host materialization, while CPU keeps the lower-overhead
NumPy stack path. This avoids the CPU replay regression seen with unconditional
device stacking.
The next accepted-point GPU fix fuses the strict fixed-boundary velocity/state
update into one cached JIT helper for accelerator backends. On office for
the same QH max_mode=2 perturbed dense-Jacobian profile
(inner_max_iter=80, trial_max_iter=40), two new accepted-point
callbacks dropped from roughly 20–22 s to 16.5 s. The measured
solver-update component dropped from about 2.6 s per accepted point to
about 0.22 s per accepted point; tape replay is now again the largest
remaining GPU term. Explicit replay column chunks of 8 or 12 were much slower
on this 24-DOF case because they segmented the replay into multiple GPU
launch/compile groups; full-column replay remained best.
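The chunking result is consistent with simple launch arithmetic, assuming each replay chunk runs as its own launch/compile group. Under that assumption, a 24-DOF replay needs three groups with chunks of 8, two groups with chunks of 12, and one group with full-column replay. The trivial helper below (a hypothetical name) makes the count explicit.

```python
import math


def replay_launch_groups(n_dofs, chunk=None):
    """Column groups needed to replay n_dofs tangent columns.

    chunk=None (or chunk >= n_dofs) is full-column replay: a single group.
    """
    if chunk is None or chunk >= n_dofs:
        return 1
    return math.ceil(n_dofs / chunk)


for chunk in (8, 12, None):
    print(f"chunk={chunk}: {replay_launch_groups(24, chunk)} launch group(s)")
```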
The next smaller accepted-point cleanup caches the initial-state tangent matrix
for each VMEC theta-flip branch. With the magnetic-axis branch frozen by the
accepted-point tape, the initial guess is affine in the boundary coefficients,
so repeated Jacobian callbacks do not need to re-linearize that graph unless
the discrete flip branch changes. On the same QH max_mode=2 GPU profile,
three perturbed dense-Jacobian callbacks moved from about 20.7 s total to
about 19.9 s total, with matching Jacobian norms. Two attempted GPU replay
shortcuts were rejected as broad defaults at that point: precomputed
tridiagonal coefficients were correctness-tested but workload-dependent, and
stopping gradients through solver time-control scalars nearly doubled replay
time.
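On a fixed flip branch the initial guess is affine in the boundary coefficients, x0(p) = A_branch p + b_branch. Its tangent matrix A_branch therefore does not depend on p, so one build per branch suffices. The pure-Python sketch below shows that caching rule. BranchTangentCache and the toy 2x2 maps are illustrative assumptions, not vmec_jax internals.

```python
class BranchTangentCache:
    """One tangent matrix per discrete branch (illustrative sketch).

    On a fixed branch the initial guess is affine, x0(p) = A_branch @ p + b_branch,
    so d x0 / d p = A_branch for every p on that branch.
    """

    def __init__(self, build_tangent):
        self._build_tangent = build_tangent
        self._cache = {}
        self.builds = 0

    def get(self, branch):
        if branch not in self._cache:
            self._cache[branch] = self._build_tangent(branch)
            self.builds += 1
        return self._cache[branch]


def matvec(matrix, vector):
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def build_tangent(branch):
    """Toy 2x2 tangent; the sign flip stands in for a theta-flip branch."""
    return [[1.0 * branch, 0.0], [0.5, 2.0 * branch]]


cache = BranchTangentCache(build_tangent)
for p in ([1.0, 0.0], [0.0, 1.0], [2.0, 3.0]):
    print(matvec(cache.get(+1), p))  # same cached matrix for every p
cache.get(-1)  # a branch change triggers exactly one new build
print("builds:", cache.builds)
```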
May 2026 follow-up profiling used Python 3.11.15, JAX 0.10.0, and
jax[cuda13] on the same office RTX A4000 host. The GPU backend is
working, but it is still not a clear win for these small/medium exact
optimization callbacks. The important split is raw fixed-boundary throughput
versus accepted-point optimization replay:
| Callback / case | CPU warm | GPU warm | Dominant GPU terms |
|---|---|---|---|
| Raw QH fixed-boundary, 100 iterations | | | fixed launch/compile overhead dominates |
| QA | | | replay and residual tangent projection |
| QA | | | cached scalar cotangent and replay |
| QA | | | replay and residual tangent projection |
| QA | | | cached scalar cotangent and replay |
Those numbers are perturbed accepted-point repeats with warm compiled helper
shapes, inner_max_iter of 40–80, relaxed ftol appropriate for
profiling, and the default tape exact path unless explicitly marked otherwise.
They are not full production optimization timings. The scalar reverse-adjoint
gradient now uses a cached JIT scalar-objective cotangent hook
for the built-in QS residual factories rather than VJP-ing the full residual
vector from Python on every callback. That reduced the QA max_mode=1 GPU
gradient callback from about 9.8 s to 4.6 s after warmup, and reduced
the QA max_mode=3 GPU gradient callback from about 5.8 s to 3.0 s.
Dense Jacobians remain competitive at low DOF counts; the scalar-adjoint path is
now the better candidate for higher-mode or memory-limited optimizations.
After caching the fixed quasisymmetry angular quadrature grid, the same
office QH max_mode=2 accepted-point callback profile gave two perturbed
GPU dense-Jacobian callbacks of about 11.8 s and 4.3 s with the default
dynamic replay bucket. Forcing VMEC_JAX_DYNAMIC_REPLAY_BUCKET=1024 made the
same profile much worse (about 53.9 s and 7.2 s). The production
optimizer therefore no longer sets a coarse replay bucket automatically; use
large buckets only for controlled experiments on workloads where recompilation
dominates replay execution.
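The sketch below illustrates the bucket tradeoff under one assumption: the bucket rounds each dynamic replay length up to a multiple of the bucket size. The actual VMEC_JAX_DYNAMIC_REPLAY_BUCKET semantics may differ. Under that assumption, larger buckets mean fewer distinct compiled shapes but more padded steps to execute. The tape lengths are made up.

```python
import math


def bucketed_length(n_steps, bucket):
    """Round a dynamic replay length up to a multiple of `bucket` (assumed rule)."""
    return bucket * math.ceil(n_steps / bucket)


def padding_overhead(lengths, bucket):
    """Return (fraction of replayed steps that are padding, distinct compiled shapes)."""
    padded = [bucketed_length(n, bucket) for n in lengths]
    waste = (sum(padded) - sum(lengths)) / sum(padded)
    return waste, len(set(padded))


tape_lengths = [83, 97, 120, 64]  # toy accepted-point tape lengths
for bucket in (1, 32, 1024):
    waste, shapes = padding_overhead(tape_lengths, bucket)
    print(f"bucket={bucket:5d} distinct shapes={shapes} padded fraction={waste:.0%}")
```

For these toy lengths, a 1024-step bucket collapses every tape to one shape but pads about 91% of the replayed steps. This shows how a coarse bucket can cost more than the recompilation it saves.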
The next optimization keeps a cached JIT residual evaluator for
non-differentiating optimizer callbacks and avoids recomputing the QS objective
when the residual vector is already available. On the same office GPU QH
max_mode=2 short production path (max_nfev=2), wall time dropped from
about 15.3 s to 11.2 s. This does not change the discrete-adjoint
Jacobian path, which still linearizes the raw residual function where
derivatives are required.
The accepted-point exact tape path now enables precomputed Thomas coefficients
for small-DOF accelerator tapes only. This is intentionally not applied to
trial solves by default: the same switch made cold GPU trial scans slower. On
office with JAX 0.6.2 and one RTX A4000, two perturbed QA max_mode=1
dense Jacobian callbacks dropped from 88.9 s total to 72.6 s total,
while the Jacobian Frobenius norms matched to about 7.5e-11 relative
difference. A later QH max_mode=2 profile with 24 boundary DOFs showed the
opposite tradeoff: replaying the larger tape outweighed the preconditioner
savings. The default therefore enables this optimization only up to 12
optimization DOFs. Set VMEC_JAX_OPT_EXACT_TRIDI_PRECOMPUTE=0 to disable
this accepted-point optimization for diagnostics or =1 to force it on for a
specific backend, or set VMEC_JAX_OPT_EXACT_TRIDI_PRECOMPUTE_MAX_DOFS to
adjust the automatic small-DOF threshold.
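Precomputing Thomas coefficients is the textbook split of the Thomas algorithm into two steps:

- a matrix-only factor step, which produces the modified super-diagonal and the pivots;
- a cheap per-right-hand-side solve.

The pure-Python sketch below shows that split. The names thomas_factor and thomas_solve are hypothetical. vmec_jax's radial preconditioner works on batched JAX arrays, and its exact factorization may differ.

```python
def thomas_factor(a, b, c):
    """Precompute the right-hand-side-independent Thomas coefficients.

    Rows are a[i]*x[i-1] + b[i]*x[i] + c[i]*x[i+1] = d[i]; a[0] and c[-1]
    are ignored. Returns (a, c_prime, denom) for reuse by thomas_solve.
    """
    n = len(b)
    c_prime = [0.0] * n
    denom = [0.0] * n
    denom[0] = b[0]
    c_prime[0] = c[0] / denom[0] if n > 1 else 0.0
    for i in range(1, n):
        denom[i] = b[i] - a[i] * c_prime[i - 1]
        c_prime[i] = c[i] / denom[i] if i < n - 1 else 0.0
    return list(a), c_prime, denom


def thomas_solve(factors, d):
    """Solve one right-hand side with precomputed coefficients (no matrix work)."""
    a, c_prime, denom = factors
    n = len(d)
    d_prime = [0.0] * n
    d_prime[0] = d[0] / denom[0]
    for i in range(1, n):
        d_prime[i] = (d[i] - a[i] * d_prime[i - 1]) / denom[i]
    x = [0.0] * n
    x[-1] = d_prime[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d_prime[i] - c_prime[i] * x[i + 1]
    return x


# Factor once per matrix, then reuse the factors for every right-hand side.
factors = thomas_factor([0.0, 1.0, 1.0, 1.0], [4.0] * 4, [1.0, 1.0, 1.0, 0.0])
print(thomas_solve(factors, [6.0, 12.0, 18.0, 19.0]))  # close to [1, 2, 3, 4]
```

The tradeoff measured above shows up here too: the factor arrays are extra data that a tape must store and replay. This is consistent with precomputation paying off only when that cost stays small relative to the solves it shortens.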
The 2026-05-23 TensorBoard/Perfetto pass on office confirmed the current
GPU exact policy. A cold QH max_mode=2 accepted Jacobian
(inner_max_iter=80, trial_max_iter=30, 24 DOFs) took about 43.4 s:
19.6 s in tape replay, 10.1 s in tape build, and 10.1 s in
residual tangent projection, with about 1.1 GiB RSS growth. Repeating after
the persistent cache was warm reduced the same callback to about 15.6 s
(5.7 s tape build, 4.0 s replay, 2.3 s residual tangents). Forcing
VMEC_JAX_OPT_EXACT_PATH=scan on the same callback took about 100.5 s.
This keeps the production optimizer on the tape exact path for GPU; scan exact
is only a parity/profiling fallback.
The office TensorBoard profiler may emit
Can't import tensorflow.python.profiler.trace when TensorFlow’s Python hook
is not installed. That warning is non-fatal for these runs: the XPlane trace,
Perfetto/trace JSON, and device-memory profile still write successfully.
Fixed-boundary GPU diagnostics
For fixed-boundary inputs, the GPU force kernels are fast, but the best solver
policy depends on whether the workload is a raw fixed-boundary solve or an exact
optimization callback. A 2026-05-23 office RTX A4000 profile showed the raw
fixed-boundary scan path is clearly faster for input.nfp4_QH_warm_start:
500 GPU iterations took about 1.37 s with scan versus 8.01 s with
the non-scan VMEC-control loop. Public auto-selected GPU/CUDA/ROCm
fixed-boundary runs therefore default to scan; CPU remains non-scan by default,
and users can still pass use_scan=False to force the VMEC-control loop.
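The selection rule above is small enough to state as code. The sketch below mirrors the prose only, and resolve_use_scan is a hypothetical name. The backend string stands for a value such as the one returned by jax.default_backend(). The real run_fixed_boundary policy may consider more than the backend.

```python
ACCELERATOR_BACKENDS = {"gpu", "cuda", "rocm"}


def resolve_use_scan(use_scan, backend):
    """Backend-aware default for the fixed-boundary loop (sketch).

    use_scan=None means "follow the production policy"; an explicit
    True/False from the caller always wins.
    """
    if use_scan is not None:
        return bool(use_scan)
    return backend.lower() in ACCELERATOR_BACKENDS
```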
A deeper 2026-05-23 trace pass separated cold compilation/cache effects from
warmed raw iteration throughput. On office with JAX 0.6.2 and an RTX
A4000, input.nfp4_QH_warm_start took about 13.3 s in a first cold GPU
scan process and 13.6 s in a first cold GPU non-scan process. After the
persistent compilation cache was warm, a fresh scan process dropped to about
5.7 s. In one long-lived process, the warmed scan run took about 1.4 s
whereas warmed non-scan took about 8.5 s. For
input.nfp4_QH_finite_beta, warmed scan took about 2.7 s for the
profiled iteration budget while warmed non-scan took about 12.6 s. The
non-scan finite-beta GPU phase timer showed force assembly was fast
(~2.5 ms per iteration) but preconditioner/update launch overhead dominated
(~10.4 ms per iteration in the preconditioner path), so production raw GPU
runs should keep the scan default.
A later 2026-05-23 office pass removed per-iteration host placeholder
allocations from the accelerator R/Z preconditioner-apply wrapper. On the
same RTX A4000 stack, input.nfp4_QH_warm_start dropped from about
22.9 s to 13.9 s with identical convergence, while
input.nfp4_QH_finite_beta dropped from about 93.0 s to 58.6 s.
The follow-up fused preconditioner payload pass combined accelerator-side
lambda scaling, mode weighting, and preconditioned residual diagnostics. In a
fresh no-warmup profile this kept the warm-start case at about 13.5 s and
reduced input.nfp4_QH_finite_beta to about 50.2 s. Enabling the
precomputed Thomas path only for high-mode non-LASYM accelerator cases then
kept the low-mode QH warm-start case on its legacy path while reducing the
finite-beta QH default profile to about 50.7 s. The detailed timer
still attributes most remaining GPU time to precond_apply (about 16.2 s
on the finite-beta final stage), so the next GPU target remains restructuring
the radial tridiagonal preconditioner itself rather than the already-fast force
assembly kernels.
The 2026-05-24 follow-up batched the precomputed R/Z Thomas systems and then
combined the m=0 and m>0 radial blocks through one padded scan. On the
same office GPU, the converged finite-beta scan profile measured 31.6 s
cold, 19.8 s for a cache-population fresh process, 8.12 s for the next
fresh process, and 3.39 s when warmed in the same process. The low-mode QH
warm-start case measured 5.24 s in a cache-warm fresh process and 0.71 s
when warmed in the same process. These numbers move the production GPU lane
from raw kernel throughput to cache warmup and compile-shape stability.
Earlier converged QH non-scan comparison, max_iter=500:
| Backend | Policy | Wall time | Final total residual |
|---|---|---|---|
| CPU | default, non-scan | | |
| GPU | accelerated, forced non-scan | | |
Older GPU scan/non-scan comparison, same host, retained as a reminder that full-convergence policy changes must be re-profiled case by case:
| Case | Scan | Non-scan | Result |
|---|---|---|---|
| | | | non-scan faster |
| | | | non-scan faster |
| | | | non-scan faster |
| | | | non-scan faster |
The finish policy also distinguishes production input-deck runs from
explicit low-budget diagnostics. When the user explicitly supplies
max_iter, all accelerated/parity finish attempts combined are capped at
twice that budget and the run reports non-convergence if the cap is exhausted.
This avoids spending many hidden extra iteration blocks in profiling or sweep
scripts. When max_iter is not overridden, VMEC input-deck budgets retain
the robust finish behavior needed for parity-oriented production solves.
On office, the explicit max_iter=100 QH GPU diagnostic now stops after
two finish blocks ([100, 100]) and reports non-convergence at the cap; the
same low-budget diagnostic previously spent five finish blocks and took about
56.2 s, while the capped source-tree run took about 42.9 s.
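The capped finish schedule can be sketched under one assumption: each finish attempt is one block of max_iter iterations, as in the [100, 100] example. finish_blocks is a hypothetical helper, and the real scheduler may size blocks differently.

```python
def finish_blocks(max_iter, *, explicit_max_iter, attempts_wanted):
    """Plan finish-attempt iteration blocks (hypothetical sketch).

    With an explicit max_iter, all finish attempts together are capped at
    2 * max_iter; input-deck budgets keep every requested attempt.
    """
    cap = 2 * max_iter if explicit_max_iter else None
    blocks = []
    for _ in range(attempts_wanted):
        if cap is not None and sum(blocks) + max_iter > cap:
            break  # cap exhausted: report non-convergence instead of continuing
        blocks.append(max_iter)
    return blocks
```

With max_iter=100 supplied explicitly, five requested attempts collapse to [100, 100], matching the capped diagnostic above.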
Earlier April 2026 diagnostics with the scan-heavy GPU policy are retained below for historical context:
| Case | CPU | CPU warm cache | GPU empty cache | GPU warm cache |
|---|---|---|---|---|
| | | | | |
| | | | | |
| | | | | |
The current practical policy is therefore:
use the non-scan VMEC-control loop for auto-selected production fixed-boundary solves on CPU,
use the scan-lifted loop for auto-selected production fixed-boundary solves on GPU/CUDA/ROCm,
keep explicit --use-scan and --no-use-scan profiler controls available for case-by-case comparisons,
keep the persistent cache enabled by default so repeated GPU processes are not dominated by recompilation,
expose --solver-device cpu / solver_device="cpu" for users who want to force CPU execution inside a GPU-enabled process,
do not treat raw scan throughput as production performance unless the run also converges to the requested VMEC tolerance.
Raw solver throughput vs public policy overhead
The fixed-boundary profiler can now separate the requested raw solver path from
the public CLI-style policy. This matters because public defaults may add
dynamic scan probes, staged follow-up, or finish attempts around the requested
iteration budget. The profiler imports vmec_jax before importing JAX
directly so GPU allocator defaults and the persistent compilation-cache policy
match normal API/CLI runs. If neither --use-scan nor --no-use-scan is
provided, the profiler leaves use_scan=None and follows the same
backend-aware production policy as run_fixed_boundary. To benchmark the raw
accelerated scan path explicitly, use:
JAX_ENABLE_X64=1 python tools/diagnostics/profile_fixed_boundary.py \
--input examples/data/input.nfp4_QH_warm_start \
--iters 20 \
--simple-profile \
--no-multigrid \
--no-auto-cli-policy \
--solver-mode accelerated \
--use-scan \
--solver-device gpu \
--json-out /tmp/vmec_jax_qh20_raw_gpu.json
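A small wrapper can build the same argument list for each device so the command does not have to be retyped. This sketch reuses only the flags shown above and makes two assumptions:

- it runs from the repository root, so the relative script path resolves;
- JAX_ENABLE_X64=1 is passed through the environment, as in the shell example.

raw_profile_command is a hypothetical helper.

```python
import os
import sys


def raw_profile_command(input_path, iters, device, json_out, use_scan=True):
    """Build (argv, env) for the raw-path profiler invocation shown above."""
    cmd = [
        sys.executable, "tools/diagnostics/profile_fixed_boundary.py",
        "--input", input_path,
        "--iters", str(iters),
        "--simple-profile",
        "--no-multigrid",
        "--no-auto-cli-policy",
        "--solver-mode", "accelerated",
        "--use-scan" if use_scan else "--no-use-scan",
        "--solver-device", device,
        "--json-out", json_out,
    ]
    env = dict(os.environ, JAX_ENABLE_X64="1")
    return cmd, env
```

Pass each returned pair to subprocess.run(cmd, env=env, check=True), once per device. Give each device a different --json-out path so the CPU and GPU reports do not overwrite each other.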
May 2026 raw-path diagnostics showed:
| Case | Device | Iterations | Timing | Notes |
|---|---|---|---|---|
| | local CPU, JAX 0.9.2 | 20 | | raw accelerated scan |
| | | 20 | | old Python 3.10/JAX stack |
| | | 100 | | same host as GPU |
| | | 100 | | fixed overhead dominates |
| | | 20 | | warmed raw path |
| | | 20 | | warmed raw path |
The conclusion is narrower than “GPU is slow”: raw force iterations are fast
once warmed, but the available office stack still has high GPU fixed
overhead and uses Python 3.10 / JAX 0.6.2. The next GPU lane should use a
Python 3.11+ environment with current JAX, then target the remaining
compile/launch/replay overhead rather than changing physics tolerances or
forcing CPU fallback.
For high-mode runs, also profile the experimental reverse-adjoint scalar gradient callback:
python tools/diagnostics/profile_exact_optimizer.py \
--problem qa --max-mode 3 --inner-max-iter 300 \
--gradient-only --check-gradient
This path computes the gradient of 0.5 * ||r||^2 with one reverse replay
through the VMEC tape instead of propagating one tangent column per boundary
degree of freedom. It is exposed through
FixedBoundaryExactOptimizer.objective_and_gradient_fun and the opt-in
method="lbfgs_adjoint" / method="scalar_trust" optimizers. These are
currently profiling/experimental lanes, not the default QA/QH production path,
because the current reverse products are comparable to, but not consistently
faster than, the dense vectorized column replay for the present
max_mode <= 3 parameter counts.
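The reverse replay rests on the identity grad(0.5 * ||r||^2) = J^T r. A single vector-Jacobian product, with the residual itself as the cotangent, yields the whole gradient. Forward tangents instead need one Jv column per degree of freedom. The toy check below uses a hand-written two-parameter residual (not a VMEC residual) and compares J^T r against central finite differences.

```python
import math


def residual(p):
    x, y = p
    return [x * x + y - 1.0, math.sin(x) * y, 0.5 * y ** 3]


def jacobian(p):
    x, y = p
    return [
        [2.0 * x, 1.0],
        [math.cos(x) * y, math.sin(x)],
        [0.0, 1.5 * y ** 2],
    ]


def objective(p):
    return 0.5 * sum(ri * ri for ri in residual(p))


def adjoint_gradient(p):
    """grad(0.5*||r||^2) = J^T r: one vector-Jacobian product with cotangent r."""
    r = residual(p)
    jac = jacobian(p)
    return [sum(jac[i][k] * r[i] for i in range(len(r))) for k in range(len(p))]


def finite_difference_gradient(p, h=1e-6):
    """Two objective evaluations per parameter (central differences)."""
    grad = []
    for k in range(len(p)):
        plus = list(p)
        plus[k] += h
        minus = list(p)
        minus[k] -= h
        grad.append((objective(plus) - objective(minus)) / (2.0 * h))
    return grad


p0 = [0.3, -0.7]
print(adjoint_gradient(p0))
print(finite_difference_gradient(p0))
```

The cost contrast is the point. The adjoint gradient costs one reverse product regardless of the parameter count, while the finite-difference (or forward-tangent) route scales with the number of parameters.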
The lbfgs_adjoint wrapper now enforces a hard scalar-gradient callback
budget because SciPy’s internal L-BFGS-B line search can otherwise exceed the
requested max_nfev. It also uses a conservative default trust box
(lbfgs_step_bound=0.01 in scaled parameter space) because unbounded L-BFGS-B
can probe extremely distorted boundaries on the first line search.
scalar_trust uses the same scalar-adjoint gradient but accepts only
monotone trust-region steps with a hard callback budget; this makes profiling
predictable even when L-BFGS-B line searches are ineffective. A May 2026 QH
max_mode=1 diagnostic confirms the budget is respected, but the
scalar-adjoint optimizers remain slower and less effective than dense exact
least-squares on that small case. The next useful step is therefore better
matrix-free/scalar trust-region behavior, not switching the default
least-squares path.
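The two lbfgs_adjoint safeguards, a hard call budget and a trust box, can be sketched for a callback that returns (value, gradient), the convention SciPy expects with jac=True. BudgetedGradient, CallbackBudgetExceeded, and trust_box are hypothetical names, and the real lbfgs_adjoint wrapper may enforce the budget differently. The 0.01 default mirrors lbfgs_step_bound.

```python
class CallbackBudgetExceeded(RuntimeError):
    """Raised when the scalar-gradient callback budget is exhausted."""


class BudgetedGradient:
    """Wrap a (value, gradient) callback with a hard call budget (sketch).

    Tracks the best point seen so the caller can still return it when an
    optimizer's internal line search runs past the budget.
    """

    def __init__(self, fun_and_grad, max_calls):
        self.fun_and_grad = fun_and_grad
        self.max_calls = max_calls
        self.calls = 0
        self.best = None  # (value, x)

    def __call__(self, x):
        if self.calls >= self.max_calls:
            raise CallbackBudgetExceeded(f"budget of {self.max_calls} calls used")
        self.calls += 1
        value, grad = self.fun_and_grad(x)
        if self.best is None or value < self.best[0]:
            self.best = (value, list(x))
        return value, grad


def trust_box(x0, step_bound=0.01):
    """Per-parameter (low, high) bounds around x0 in scaled parameter space."""
    return [(xi - step_bound, xi + step_bound) for xi in x0]
```

With SciPy, this pairs with scipy.optimize.minimize(budgeted, x0, jac=True, method="L-BFGS-B", bounds=trust_box(x0)). Wrap that call in a try/except CallbackBudgetExceeded that falls back to budgeted.best.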
The accepted-point exact path uses the discrete-adjoint tape path by default on
both CPU and GPU. April and May 2026 diagnostics showed that a naive cold scan
exact Jacobian can be very slow, so scan remains an explicit
VMEC_JAX_OPT_EXACT_PATH=scan diagnostic path rather than the GPU default.
The production optimizer now avoids redundant residual-only exact executables,
uses backend-aware trial residuals, and reconstructs standard QS history
metrics from cached accepted-point residuals instead of re-solving the accepted
state after every Jacobian. In a fresh May 2026 office RTX A4000 profile of QH
max_mode=1 with inner_max_iter=trial_max_iter=20, forced scan took
102.9 s for one dense Jacobian callback while forced tape took 36.7 s;
the same local CPU-style QA max_mode=2 tape callback was 9.5 s. CPU
therefore remains faster for the small/medium diagnostics, so production runs
still document both CPU and GPU timings rather than claiming blanket GPU
speedups. GPU sweep production runs use calibrated optimizer budgets
(currently inner_max_iter =
trial_max_iter = 120 and ftol = trial_ftol = 1e-8 for deck-controlled
QA/QH cases), rather than the old four-evaluation diagnostic caps. Final
standalone verification runs can still use the VMEC input-deck NITER_ARRAY /
FTOL_ARRAY.
Replay and preconditioner JIT helper caches are retained across accepted points
and LRU-bounded. VMEC_JAX_SCAN_RUNNER_CACHE_SIZE and
VMEC_JAX_COMPUTE_FORCES_CACHE_SIZE default to 32 entries;
VMEC_JAX_STRICT_UPDATE_CACHE_SIZE defaults to 16 entries. Set any of
these to 0 to disable in-process retention when profiling RSS growth. Call
FixedBoundaryExactOptimizer.clear_caches() to release compiled replay
helpers explicitly after a long optimization batch. The optimizer still clears
heavyweight exact tapes between SciPy callbacks where needed to avoid RSS
growth.
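The cache-size knobs behave like a standard LRU bound. The sketch below uses functools.lru_cache, where maxsize=0 stores nothing, matching the "0 disables in-process retention" rule. env_cache_size, make_cached_builder, and count_builds are hypothetical helpers. The builder stands in for tracing and compiling a replay helper.

```python
import functools
import os


def env_cache_size(name, default):
    """Read a cache size from the environment; 0 disables retention."""
    return int(os.environ.get(name, default))


def make_cached_builder(build, env_name, default):
    """Wrap an expensive builder in an LRU cache sized by an env variable."""
    return functools.lru_cache(maxsize=env_cache_size(env_name, default))(build)


def count_builds(cache_size, keys):
    """Demo: how many real builds happen for a sequence of cache keys."""
    calls = []

    def build(key):
        calls.append(key)  # stands in for tracing and compiling a helper
        return ("compiled", key)

    cached = functools.lru_cache(maxsize=cache_size)(build)
    for key in keys:
        cached(key)
    cached.cache_clear()  # similar in spirit to clear_caches() after a batch
    return len(calls)
```

With the default size, a repeated shape key is built once. With size 0, every call rebuilds, which is the intended behavior when profiling RSS growth rather than speed.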
The same exact Jv/J.Tv products can be used in SciPy’s trust-region
least-squares solver with method="scipy_matrix_free":
python tools/diagnostics/profile_exact_optimizer.py \
--problem qa --max-mode 3 --inner-max-iter 300 \
--method scipy_matrix_free --lsmr-maxiter 4 --max-nfev 2
This matrix-free trust-region lane is useful for profiling, memory-pressure
fallbacks, and larger parameter-count cases, but it is not a global default
because it is problem dependent. Use method="auto" as an opt-in,
device-preserving policy rather than as a promise of the fastest wall time for
every case. The current automatic policy may select matrix-free LSMR for
high-mode, stellarator-symmetric QS/QI problems on CPU or the default CPU backend, where
cold-process and memory-pressure profiles motivated the lane. It still keeps
LASYM cases and explicit GPU runs on the dense SciPy path until case-specific
matrix-free GPU profiles show a benefit. The retained benchmark table below is
a useful counterexample: on that warm QA run, dense SciPy is still faster.
Validate the case you care about with the profiling commands before promoting
production timings. The trial residual path can be compared explicitly with
profile_exact_optimizer.py --trial-scan auto|on|off.
For high-parameter-count scalar-objective probes, use
method="auto_scalar" (alias "auto_adjoint"). This policy keeps the
same LASYM safeguards as "auto", but selects the safeguarded
scalar_trust reverse-adjoint path for high-mode, stellarator-symmetric
QS/QI CPU/default-backend cases. As of the 2026-06-01 production-policy
refresh, explicit GPU/CUDA/ROCm high-mode, stellarator-symmetric QS/QI
auto_scalar runs also use scalar_trust rather than falling back to
dense SciPy. Low-mode, LASYM, and ordinary method="auto" GPU runs stay on
dense SciPy unless the caller explicitly chooses another method. auto_scalar
does not enable cost-only trial filtering by default, since QH and QP GPU
profiles showed the extra forward trial solves can dominate short high-mode
runs. Enable that filter explicitly with scalar_cost_only_trials=True or
VMEC_JAX_OPT_SCALAR_COST_ONLY_TRIALS=1 when a rugged case needs trial
screening before exact gradient builds. This is the production entry point for
testing scalar-adjoint optimization without relying on environment variables.
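The two automatic policies can be summarized as a decision function. This sketch mirrors the prose only. select_method is a hypothetical name, "dense_scipy" is a placeholder label for the dense SciPy path, and what counts as high_mode is left to the caller because the threshold is not specified here. The real "auto" policy only may select matrix-free LSMR and can apply further checks.

```python
def select_method(method, *, high_mode, lasym, qs_or_qi, explicit_gpu):
    """Resolve the automatic policies to a concrete optimizer label (sketch)."""
    eligible = high_mode and qs_or_qi and not lasym
    if method == "auto":
        # Matrix-free LSMR is considered only for CPU/default-backend runs.
        return "scipy_matrix_free" if eligible and not explicit_gpu else "dense_scipy"
    if method in ("auto_scalar", "auto_adjoint"):
        # Since the 2026-06-01 refresh, eligible explicit GPU runs use it too.
        return "scalar_trust" if eligible else "dense_scipy"
    return method  # explicit methods pass through unchanged
```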
A 2026-05-20 matrix-free cleanup removed one redundant initialization AD pass:
residual_linear_operator now obtains the frozen-axis initial-state
transpose from the already-created jax.linearize object instead of tracing a
second jax.vjp through the same initial-state graph. A later 2026-05-31
production pass moved SIMSOPT-style tuple workflows onto the same packed-state
cotangent hooks used by the older direct QS factories. For larger
stellarator-symmetric CPU/default matrix-free operators it can also precompute
the small initial-state tangent block; by default this is limited to
64–128 boundary DOFs because a 48-DOF QA microprobe showed the build cost can
exceed the single J.Tv transpose it saves. Override with
VMEC_JAX_OPT_LINEAR_OPERATOR_INITIAL_TANGENTS=1 or tune
VMEC_JAX_OPT_LINEAR_OPERATOR_INITIAL_TANGENT_{MIN,MAX}_DOFS for profiling.
New profiles therefore separate
linear_operator_initial_tangent_projection from the older
linear_operator_initial_transpose bucket; if the transpose bucket is still
large, the case has fallen back to the conservative no-precompute path.
The same 2026-06-01 pass added a scalar-gradient initial-tangent projection
path. On GPU/CUDA/ROCm high-mode, stellarator-symmetric scalar-adjoint runs,
objective_and_gradient_fun now precomputes the affine initial-state tangent
columns once per flip branch for 24–256 DOFs and projects the reverse tape
cotangent through that block. This targets repeated accepted-point scalar
gradients where rebuilding the initial-state VJP was visible in profiles. Tune
or override with VMEC_JAX_OPT_SCALAR_GRADIENT_INITIAL_TANGENTS and
VMEC_JAX_OPT_SCALAR_GRADIENT_INITIAL_TANGENT_{MIN,MAX}_DOFS.
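Both initial-tangent knobs follow the same shape: an on/off override plus a DOF window. The sketch below rests on two assumptions:

- the override forces precompute on when set to "1" and off for any other set value;
- the window bounds are read from the _MIN_DOFS/_MAX_DOFS variables.

The backend and symmetry gating described above is omitted, and initial_tangents_enabled is a hypothetical helper.

```python
import os


def initial_tangents_enabled(n_dofs, switch_var, window_prefix, default_min,
                             default_max, environ=None):
    """Decide whether to precompute initial-state tangent columns (sketch)."""
    env = os.environ if environ is None else environ
    override = env.get(switch_var)
    if override is not None:
        return override == "1"  # assumed: "1" forces on, other values force off
    lo = int(env.get(window_prefix + "_MIN_DOFS", default_min))
    hi = int(env.get(window_prefix + "_MAX_DOFS", default_max))
    return lo <= n_dofs <= hi


# (switch variable, window prefix, default min DOFs, default max DOFs)
MATRIX_FREE = ("VMEC_JAX_OPT_LINEAR_OPERATOR_INITIAL_TANGENTS",
               "VMEC_JAX_OPT_LINEAR_OPERATOR_INITIAL_TANGENT", 64, 128)
SCALAR_GRADIENT = ("VMEC_JAX_OPT_SCALAR_GRADIENT_INITIAL_TANGENTS",
                   "VMEC_JAX_OPT_SCALAR_GRADIENT_INITIAL_TANGENT", 24, 256)
```

Under these defaults a 48-DOF problem skips the matrix-free precompute but qualifies for the scalar-gradient one. This matches the two windows described above.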
The validation profile for this change used the QA max_mode=3 48-DOF probe
with two VMEC iterations per accepted/trial solve. On the office GPU,
method="auto_scalar" selected scalar_trust and completed one cold
callback in 13.43 s; the profile showed gradient_initial_tangents and
gradient_initial_projection rather than gradient_initial_vjp. A
standalone two-repeat GPU scalar-gradient callback took 12.22 s for the
first point and 0.03 s for a same-point repeat, with the second repeat using
gradient_initial_tangents_cache_hit. A matched dense SciPy GPU callback
still took 29.84 s for this one-callback diagnostic, dominated by
jacobian_projected_replay_total; dense SciPy was 2.22x slower than the
new auto_scalar route. These numbers are diagnostics for the
small cold probe, not a production optimization benchmark, but they confirm the
new route avoids dense-Jacobian fallback on eligible GPU scalar-adjoint runs.
A larger QH max_mode=4 optimizer run on office then compared the
production dense SciPy path against method="scipy_matrix_free" and
method="auto_scalar" with max_nfev=3, inner_max_iter=80, and
trial_max_iter=80. Dense SciPy remained the best production choice for
this high-mode GPU case: it finished in 94.8 s with final objective
1.74e-1. scipy_matrix_free took 213 s and stopped at
2.50e-1 because repeated Jv/J.Tv products dominated the trace.
auto_scalar took 95.6 s but stopped at 2.85e-1 because the current
scalar trust-region globalization did not make enough progress in the same
budget. Therefore explicit GPU method="auto" still preserves dense SciPy
for high-mode QS/QI optimizations; scalar and matrix-free paths remain opt-in
diagnostics until longer production sweeps show comparable objective progress.
The follow-up patch tested JIT compilation of the repeated initial-state
construction inside that accepted-point path. FixedBoundaryExactOptimizer
can lazily JIT the parameter-to-packed-initial-state map with
VMEC_JAX_OPT_JIT_INITIAL_STATE=1. It is not enabled by default because
new cold CPU exact callbacks pay more compile/dispatch overhead than this small
projection saves. On a QH max_mode=2 CPU smoke test, a cold accepted callback
took 2.13 s with the default non-JIT path versus 3.23 s with the JIT
helper, and initial_guess_exact took 0.154 s versus 1.58 s.
The next hot bucket remains the compiled accepted VMEC tape solve/replay path,
not initialization.
April 2026 local CPU diagnostics with inner_max_iter=trial_max_iter=300 and
max_nfev=2:
| Case | Method | Wall time | Notes |
|---|---|---|---|
| QA | dense exact LSQ | | |
| QA | matrix-free LSQ, | | same objective as dense, slower wall time |
| QH | dense exact LSQ | | |
| QA | reverse scalar gradient | | experimental; current L-BFGS wrapper needs optimizer work |
The same QA dense run took 110.6 s in a cold process, so benchmark reports
must identify cold and warm timings separately.
Before using this lane for production results, validate the products on the case of interest:
python tools/diagnostics/profile_exact_optimizer.py \
--problem qa --max-mode 1 --inner-max-iter 20 \
--check-linear-operator
The forward product Jv is covered by the same dynamic replay path as the
dense exact Jacobian. The reverse product J.Tv uses split residual-block
transposes so inactive aspect/iota/QS blocks are not differentiated with zero
cotangents through singular axis branches. On current diagnostics,
Jv and J @ X match dense to near roundoff; QA J.Tv matches to
about 1e-6 relative error because current-driven iota still needs
axis-gauge cotangent cleanup, while QH J.Tv matches to near roundoff.
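A cheaper complement to the dense comparison is the adjoint (dot-product) test, which needs only the two products. For any v and w, <J v, w> must equal <v, J^T w> when J.Tv is the exact transpose of Jv. The toy operator below is hypothetical; a real check would call the optimizer's Jv and J.Tv products instead.

```python
import random


class DenseOperator:
    """Toy linear operator exposing only matvec (Jv) and rmatvec (J^T w)."""

    def __init__(self, rows):
        self.rows = rows

    def matvec(self, v):
        return [sum(a * b for a, b in zip(row, v)) for row in self.rows]

    def rmatvec(self, w):
        n = len(self.rows[0])
        return [sum(self.rows[i][k] * w[i] for i in range(len(self.rows)))
                for k in range(n)]


def dot(x, y):
    return sum(a * b for a, b in zip(x, y))


def adjoint_mismatch(op, n_in, n_out, trials=5, seed=0):
    """Worst relative gap between <J v, w> and <v, J^T w> over random probes."""
    rng = random.Random(seed)
    worst = 0.0
    for _ in range(trials):
        v = [rng.gauss(0.0, 1.0) for _ in range(n_in)]
        w = [rng.gauss(0.0, 1.0) for _ in range(n_out)]
        lhs = dot(op.matvec(v), w)
        rhs = dot(v, op.rmatvec(w))
        worst = max(worst, abs(lhs - rhs) / max(abs(lhs), abs(rhs), 1e-300))
    return worst
```

The test catches transpose mistakes such as missing cotangent blocks. It cannot catch a Jv that is consistently wrong in both directions, so the dense comparison above is still the reference check.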
Current performance (representative benchmarks)
The README-facing fixed-boundary CPU matrix is generated from
docs/_static/figures/readme_runtime_compare.csv and visualized in:
The current data should be read as a performance reality check, not as a broad
single-solve speedup claim. On the current local matrix, warm vmec_jax
beats VMEC2000 on 1 of 16 bundled fixed-boundary rows
(circular_tokamak_aspect_100, about 1.33x faster). The median warm
single-solve row is still about 4.4x slower than VMEC2000 on this host.
Cold runs are slower because they include XLA compilation and runtime setup.
A broader 2026-05-24 internal policy matrix compared the default fixed-boundary
policy against the explicit accelerated policy on all 35 bundled
fixed-boundary input decks. On the local CPU host, 34 rows completed within the
per-case timeout: 20 were faster with accelerated, 14 were slower, the
median warm speedup was 1.17x, and the geometric-mean warm speedup was
1.61x. This confirms that CPU policy remains case-dependent rather than a
single global scan/default choice. The largest CPU wins were low-latency or
simple-seed rows such as solovev (16.9x),
up_down_asymmetric_tokamak (15.3x), cth_like_fixed_bdy
(11.1x), and nfp4_QH_warm_start (7.17x). The largest CPU losses
were high-work or tokamak rows such as LandremanPaul2021_QA_lowres
(0.20x), DSHAPE (0.33x), shaped_tokamak_pressure (0.38x),
and purely_toroidal_field (0.42x). NuhrenbergZille_1988_QHS timed
out in both policies and remains a stress fixture rather than a promotion row.
The same 2026-05-24 matrix on office with JAX_PLATFORMS=cuda,cpu showed
the opposite policy conclusion for CUDA: 33 rows had valid warm
accelerated-vs-default speedups within CUDA, 32 of those were faster with
accelerated, the median warm speedup was 2.06x, and the geometric-mean
warm speedup was 3.38x. Only
LandremanPaul2021_QA_lowres was slower (0.33x), and
NuhrenbergZille_1988_QHS again timed out. Representative CUDA wins were
li383_low_res (12.6x), nfp1_QI (11.9x), minimal-seed NFP
families (~10-11x), and nfp4_QH_warm_start (10.3x). These numbers
are the reason public accelerator fixed-boundary runs prefer the scan-backed
accelerated path while CPU remains backend/case aware.
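The median and geometric-mean speedups quoted above summarize per-row ratios of default warm time to accelerated warm time. The geometric mean is the natural average for ratios because a 2x win and a 0.5x loss cancel. Rows that timed out under either policy are excluded. The sketch below uses toy timings, not the bundled matrix, and summarize_speedups is a hypothetical helper.

```python
import statistics


def summarize_speedups(default_s, accelerated_s):
    """Summarize warm speedups, defined as default time / accelerated time.

    Only cases timed under both policies are included (timeouts are absent).
    """
    cases = sorted(default_s.keys() & accelerated_s.keys())
    speedups = [default_s[case] / accelerated_s[case] for case in cases]
    return {
        "n": len(speedups),
        "faster": sum(s > 1.0 for s in speedups),
        "median": statistics.median(speedups),
        # Log-average: 2x and 0.5x combine to 1x.
        "geometric_mean": statistics.geometric_mean(speedups),
    }
```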
This does not contradict the optimization motivation: the exact-adjoint optimization path can avoid the many finite-difference VMEC subprocess columns that SIMSOPT+VMEC2000 needs. It does mean that single-solve CPU/GPU runtime remains an active performance lane before claiming broad VMEC2000 runtime wins.
Finite-beta CPU profile: May 2026
The examples/data/input.nfp4_QH_finite_beta case is a useful stress test
because it combines finite pressure/current profiles, mpol=5, ntor=5,
and a two-stage NS_ARRAY=[5, 51] schedule. On the local Apple CPU host,
VMEC2000 converges this case in about 3.3 s. Before the May 2026 profile
cleanup, the public vmec_jax CPU path took about 25.3 s cold for the
same converged multigrid run. After moving concrete profile evaluation and the
CPU force helper post-processing out of residual JAX dispatch fragments, the
same diagnostic run takes about 20.6 s cold while preserving the final
residual (~5.6e-13).
A later pass moved ordinary, non-differentiated multigrid interpolation through
the NumPy host path while keeping traced/autodiff interpolation on the JAX path.
On the same diagnostic command this reduced the local cold time to about
20.1 s and preserved the final residual (~5.6e-13). It also avoids
unnecessary one-time JAX dispatch during the radial NS_ARRAY stage transfer.
The 2026-05-23 CPU pass made the VMEC-output writer use the fast bcovar path by
default. On the same local host, the cold full CLI time for
input.nfp4_QH_finite_beta dropped from about 18.8-19.1 s to about
16.8 s while bundled wout parity tests continued to pass. The corresponding
input.nfp4_QH_warm_start full CLI time dropped from about 6.8 s to
about 4.6 s. Set VMEC_JAX_WOUT_FAST_BCOVAR=0 only for legacy
force-kernel output debugging.
The 2026-05-24 CPU pass fixed two steady-state host-loop overheads: NumPy
copies of the cached VMEC phase stack now keep the zero-cost phase-stack branch,
and the scalxc[:, modes.m] table used by real-space synthesis is cached for
static mode/surface grids. The full public CLI compute time on the local Apple
CPU host is now about 12.2 s for input.nfp4_QH_finite_beta and about
1.8 s for input.nfp4_QH_warm_start. End-to-end shell time is still
higher (about 16.3 s and 5.5 s respectively) because a cold process
includes Python/JAX startup and the first netCDF4 import/write for WOUT output.
An opt-in follow-up knob, VMEC_JAX_WOUT_REUSE_FINAL_BCOVAR=1, lets WOUT
assembly reuse the converged final force/bcovar payload already produced by the
solve when that payload is available. This is intended for profiling and cold
CLI runtime experiments on converged runs; the default remains an independent
fast-bcovar WOUT reconstruction so parity and diagnostics stay conservative.
Another opt-in knob, VMEC_JAX_WOUT_IO_WARMUP=1, starts the heavy netCDF4
import in a background thread while the solve is running. It is disabled by
default because the May 2026 cold CPU measurements showed that the background
import contended with the VMEC solve and was slightly slower overall.
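The warmup mechanism amounts to running an import on a background thread. The minimal sketch below uses importlib on a daemon thread. start_import_warmup is a hypothetical name, and the real knob may be implemented differently.

```python
import importlib
import threading


def start_import_warmup(module_name):
    """Start importing module_name on a daemon thread and return the thread.

    Join the thread before relying on the module. An ImportError is reported
    through the thread's exception hook rather than raised to the caller.
    """
    thread = threading.Thread(
        target=importlib.import_module, args=(module_name,), daemon=True
    )
    thread.start()
    return thread
```

A caller would start the import before the solve and join the thread before writing WOUT output. As measured above, the background import competes with the solve for the CPU, which is why the knob stays off by default.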
The same local CPU profiling pass also isolated solve-only cost with
profile_fixed_boundary.py and no wout write. A 2026-05-25 policy check
showed that host NumPy update assembly is still best for low-mode warm-start
QH: input.nfp4_QH_warm_start converged in 1.22 s solve-only with final
residual 1.1e-13. The larger finite-beta QH case is the opposite:
the host-update path spent 8.62 s in state update and took 21.6 s wall
time, while the workload-sensitive fused strict-update path spent only
0.33 s in state update and converged in 11.9 s wall time with final
residual 5.5e-13. This is still slower than the local VMEC2000 finite-beta
runtime, but it removes the largest pure-Python/host assembly term from the
current CPU profile.
A short CPU TensorBoard trace
showed about 440 pjit cache misses and 9 backend compile/load events
in the cold path, but the long-run steady bottleneck is now the compiled
vmec_forces_rz_from_wout / tomnsps_rzl transform stack.
With an explicit CPU persistent cache
(VMEC_JAX_COMPILATION_CACHE=1 and an isolated
VMEC_JAX_COMPILATION_CACHE_DIR), repeated fresh-process CLI runs improve
further, but CPU cache use remains opt-in because XLA:CPU cache entries are
native AOT executables and can emit host-feature mismatch warnings if reused
across incompatible machines or JAX/JAXLIB versions.
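Isolating the cache directory per host and runtime keeps an XLA:CPU executable built in one environment from being loaded in another. The sketch below prepares the environment for a fresh CPU benchmark process that opts in explicitly. Its fingerprint recipe, a SHA-256 hash of platform details and version strings, is an illustrative assumption and not the exact suffix vmec_jax computes.

```python
import hashlib
import os
import platform
import sys
from pathlib import Path


def illustrative_fingerprint(jax_version: str, jaxlib_version: str) -> str:
    """Hash host and runtime details into a short directory suffix (illustrative only)."""
    parts = [
        platform.machine(),
        platform.processor(),
        platform.system(),
        sys.version.split()[0],
        jax_version,
        jaxlib_version,
    ]
    return hashlib.sha256("|".join(parts).encode()).hexdigest()[:16]


def cpu_cache_env(root: Path, jax_version: str, jaxlib_version: str) -> dict:
    """Environment for a child process that opts in to the CPU persistent cache."""
    env = dict(os.environ)
    env["VMEC_JAX_COMPILATION_CACHE"] = "1"  # CPU cache use is opt-in
    env["VMEC_JAX_COMPILATION_CACHE_DIR"] = str(
        root / illustrative_fingerprint(jax_version, jaxlib_version)
    )
    return env


# "0.4.30" is a placeholder version string, not a tested configuration.
env = cpu_cache_env(Path("/tmp/vmec_jax_bench_cache"), "0.4.30", "0.4.30")
```

In a real script, the version strings would come from the installed packages, for example importlib.metadata.version("jax"). Pass env to something like subprocess.run(["vmec", "input.nfp4_QH_warm_start"], env=env) once per fresh-process measurement.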
The current bottleneck is therefore no longer the cubic current/pressure
profile helper or the wout force recomputation. The remaining gap is dominated
by per-iteration force synthesis/residual transforms and the host VMEC-control
loop. Future single-solve CPU work should target fewer materializations in
vmec_forces_rz_from_wout/tomnsps_rzl and a less fragmented cold-start
setup path before claiming VMEC2000-competitive finite-beta runtime.
On the office RTX A4000 GPU, ordinary scan fixed-boundary solves now show a
cleaner split between cold compilation and warmed throughput. With the current
Python 3/JAX stack, input.nfp4_QH_warm_start takes about 27 s in a fully
cold fresh process but 0.71 s when warmed in the same process. The heavier
input.nfp4_QH_finite_beta case takes about 31.6 s cold and 3.39 s
when warmed in the same process, with final total residual 5.51e-13. This
means the warmed GPU finite-beta solve is now VMEC2000-level for this case; the
remaining GPU production issue is fresh-process compile/cache latency.
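Cold and warmed numbers differ by an order of magnitude here, so benchmark scripts should report them separately rather than averaging them. The timing helper below is a minimal sketch. It assumes that `solve` is any zero-argument callable that runs one complete solve and blocks until its results are ready (for JAX outputs, call `.block_until_ready()` inside it). The helper name is hypothetical.

```python
import statistics
import time
from typing import Callable, Dict


def time_cold_and_warm(solve: Callable[[], object], warm_repeats: int = 3) -> Dict[str, object]:
    """Time the first call (cold: includes tracing/compilation) apart from repeat calls (warm)."""
    t0 = time.perf_counter()
    solve()
    cold_s = time.perf_counter() - t0

    warm_s = []
    for _ in range(warm_repeats):
        t0 = time.perf_counter()
        solve()
        warm_s.append(time.perf_counter() - t0)
    return {"cold_s": cold_s, "warm_median_s": statistics.median(warm_s), "warm_all_s": warm_s}


# Usage with a stand-in solve whose first call pays a one-time 50 ms "compile".
_calls = []


def fake_solve():
    if not _calls:
        time.sleep(0.05)
    _calls.append(1)


report = time_cold_and_warm(fake_solve)
```

This helper measures same-process warm time only. Fresh-process timings that rely on a persistent cache need separate process launches.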
Public GPU requests now enable the persistent compilation cache automatically
even when the user selects the GPU through solver_device="gpu" instead of a
JAX platform environment variable. The runtime GPU path mirrors the
import-time cache setup by setting both jax_compilation_cache_dir and the
GPU XLA autotune cache option before first solve compilation. A post-fix
office check showed that this alignment is necessary for correct default
configuration, but not sufficient to make fresh-process GPU solves reuse the
compiled scan executable on this stack. Re-running from the checked-out source
tree with PYTHONPATH=$PWD gave repeated QH warm-start processes of
16.3 s and 16.2 s, and repeated finite-beta QH processes of
55.9 s and 55.5 s. The dominant term remains
scan_device_dispatch_s (about 12.5 s for QH warm start and 41.9 s
for finite beta), so the next GPU lane is persistent executable reuse or
reducing scan compile/dispatch cost rather than more driver-side cache toggles.
The first compile/dispatch reduction that survived the finite-beta guardrail is
an adaptive accelerator chunk: quiet GPU scans with more than 512 iterations use
a fixed 512-iteration scan chunk, regardless of Fourier mode count. On the
office RTX A4000 host this is neutral for the low-mode input.nfp4_QH_warm_start case
(13.49 s default, 13.43 s at 512, 15.90 s for one full chunk) and
material for the high-mode finite-beta QH case (54.28 s default/full,
37.84 s at 256, 23.48 s at 512, 51.33 s at 1024), with unchanged
final residuals. That moves the finite-beta GPU path back to the same order as
the local CPU timing while preserving the warmed scan-throughput advantage.
The exact-optimizer profile has a different bottleneck from single fixed-boundary
solves. A 2026-05-24 bounded max_mode=3 two-evaluation profile on the
minimal-seed QA/QH/QP examples measured:
| case | backend | warm profile time | dominant measured term | note |
|---|---|---|---|---|
| QA | CPU | | | objective moved, iota stayed on the zero-transform branch |
| QA | CUDA | | | persistent cache reduces but does not remove cold cost |
| QH | CPU | | | no accepted movement in the two-evaluation diagnostic |
| QH | CUDA | | | same-process warm is materially better than fresh-process warm |
| QP | CPU | | | minimal seed stayed on the zero-transform branch |
| QP | CUDA | | | not production-competitive yet |
| QI | CPU | | Boozer/QI objective setup and exact solve | one-evaluation timing only, not a converged QI run |
| QI | CUDA | | Boozer/QI objective setup and exact solve | GPU is still slower for this cold QI profile |
These optimization profiles point to the next performance target: accepted-point exact tape build/replay, tangent construction, and zero-transform branch avoidance in simple-seed QP/QI starts. The fixed-boundary CUDA scan path is now often faster than the host loop, but full optimization does not become production-fast until exact callbacks and Boozer/QI objective setup are less cold-start dominated.
The finite-beta CPU timing above was checked against the local VMEC2000
executable ~/bin/xvmec2000 in a temporary directory. The VMEC2000 run took
about 3.3 s. The full public vmec_jax CPU CLI path measured about
12.2 s before the workload-sensitive strict-update policy; the latest
solve-only profile is 11.9 s on the local CPU host, with the state-update
term reduced from 8.62 s to 0.33 s. The resulting WOUT comparison remained in the expected
converged-equilibrium parity band: representative relative RMS differences were
~1.8e-6 for iota, ~3.5e-6 for rmnc, ~2.2e-5 for zmns,
~1.0e-5 for bmnc, and ~1.7e-4 for lmns after skipping the
near-axis cells.
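The relative-RMS definition and the number of skipped near-axis cells are not spelled out above. The sketch assumes one common choice: the RMS of the difference divided by the RMS of the reference, computed over radial rows js >= skip_axis. It treats arrays as shaped (ns, mnmax), with the magnetic axis in row 0.

```python
import numpy as np


def rel_rms(candidate: np.ndarray, reference: np.ndarray, skip_axis: int = 1) -> float:
    """Relative RMS difference over radial rows skip_axis..ns-1 (assumed definition)."""
    cand = np.asarray(candidate, dtype=float)[skip_axis:]
    ref = np.asarray(reference, dtype=float)[skip_axis:]
    denom = np.sqrt(np.mean(ref**2))
    if denom == 0.0:
        # All-zero reference (e.g. a symmetry-forbidden channel): report absolute RMS.
        return float(np.sqrt(np.mean(cand**2)))
    return float(np.sqrt(np.mean((cand - ref) ** 2)) / denom)


ref = np.ones((4, 3))
cand = ref.copy()
cand[0] += 100.0   # near-axis row, excluded by skip_axis=1
cand[1:] *= 1.001  # 0.1% error on every other surface
err = rel_rms(cand, ref)  # about 1e-3
```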
A 2026-05-23 cold-process cProfile pass on a bounded finite-beta run showed
the same split at shorter iteration budgets: package/JAX startup took about
1.9 s of a 4.9 s profiled process, backend compile/load about 0.9 s,
and the solver stage about 2.4 s. The largest steady solver terms were
vmec_forces_rz_from_wout (0.77 s),
vmec_bcovar_half_mesh_from_wout (0.47 s), and tomnsps_rzl
(0.19 s). That older bounded profile was low enough work that forcing
VMEC_JAX_HOST_UPDATE_ASSEMBLY=0 was slower; the current production policy
therefore uses the host path for low-work CPU stages and the fused strict-update
path only above the VMEC_JAX_HOST_UPDATE_CPU_WORK_LIMIT threshold.
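The stage-level choice between host NumPy update assembly and the fused strict-update path can be pictured as a threshold on an estimated work measure. The sketch below assumes a hypothetical work estimate of ns * mnmax * niter and an illustrative default limit. This page does not document vmec_jax's actual work metric or the default for VMEC_JAX_HOST_UPDATE_CPU_WORK_LIMIT.

```python
import os
from typing import Mapping, Optional


def choose_update_path(ns: int, mnmax: int, niter: int,
                       env: Optional[Mapping[str, str]] = None,
                       default_limit: int = 2_000_000) -> str:
    """Return 'host' or 'fused' update assembly for one CPU stage (illustrative policy)."""
    env = os.environ if env is None else env
    if env.get("VMEC_JAX_HOST_UPDATE_ASSEMBLY") == "0":
        return "fused"  # explicit opt-out of host NumPy update assembly
    limit = int(env.get("VMEC_JAX_HOST_UPDATE_CPU_WORK_LIMIT", default_limit))
    work = ns * mnmax * niter  # hypothetical work estimate, not the vmec_jax metric
    return "host" if work <= limit else "fused"
```

A threshold keeps low-work stages, such as the warm-start QH case, on the path with the lowest launch overhead. Heavy stages, such as finite-beta QH, move to the path whose per-iteration cost scales better.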
The figure rows and provenance are available as:
Regenerate the current fixed-boundary plot after a runtime sweep with:
python tools/diagnostics/readme_runtime_compare.py \
--cpu-summary outputs/fixed_runtime_accel_cpu_bundle_20260406_r2/summary.json \
--figure-kind fixed --plot-mode runtime \
--figure-out docs/_static/figures/readme_runtime_compare.png \
--csv-out docs/_static/figures/readme_runtime_compare.csv \
--json-out docs/_static/figures/readme_runtime_compare.json
When a same-host GPU sweep is available, add one or more --gpu-summary
paths from the matching run. The generated CSV/JSON/table then retain
separate CPU and GPU vmec_jax cold/warm runtime and memory columns, which
keeps the README plot from silently dropping GPU comparison data.
Profiling and diagnostics
Enable float64¶
VMEC2000 is float64-first. For parity, enable x64 in JAX:
export JAX_ENABLE_X64=1
JIT boundaries and compile latency¶
On CPU, compilation can dominate runtime for moderate problem sizes. vmec-jax uses:
a jitted geometry kernel (eval_geom),
non-jitted solver gradients by default (to reduce compile latency).
Solver functions accept jit_grad=True to trade longer compile time for faster
iterations.
To reduce initial compilation overhead during startup, you can disable JIT for the initial guess phase by setting:
export VMEC_JAX_DISABLE_JIT_INIT=1
This keeps the solver kernel JIT-compiled, but avoids compiling the initial boundary->state projection path (useful for short runs or rapid profiling).
To reduce per-iteration latency spikes in multigrid runs, vmec-jax can
precompile the force kernel at the start of each stage. This is enabled by
default when jit_forces=True; you can override it with:
export VMEC_JAX_JIT_PRECOMPILE=0
If you prefer to run a few iterations without JIT before compiling, set:
export VMEC_JAX_JIT_WARMUP_ITERS=2
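The warmup pattern runs a few cheap uncompiled iterations and only then pays for compilation. It can be sketched as a loop that swaps in a compiled step after k iterations. Here `compile_step` stands in for something like jax.jit applied to the step function. The loop and both callables are hypothetical, not vmec_jax internals.

```python
import os
from typing import Callable, Mapping, Optional


def run_with_jit_warmup(step: Callable, compile_step: Callable[[Callable], Callable],
                        state, n_iter: int, env: Optional[Mapping[str, str]] = None):
    """Run the first k iterations eagerly, then switch to a compiled step (illustrative)."""
    env = os.environ if env is None else env
    warmup = int(env.get("VMEC_JAX_JIT_WARMUP_ITERS", "0"))
    compiled = None
    for it in range(n_iter):
        if it < warmup:
            state = step(state)  # eager: no compile latency, slower per call
        else:
            if compiled is None:
                compiled = compile_step(step)  # compile once, on first use
            state = compiled(state)
    return state


compile_calls = []


def fake_compile(fn):
    compile_calls.append(fn)  # stands in for jax.jit(fn)
    return fn


final = run_with_jit_warmup(lambda x: x + 1, fake_compile, 0, n_iter=5,
                            env={"VMEC_JAX_JIT_WARMUP_ITERS": "2"})
```

If the run ends before the warmup window closes, compilation never happens. This is why the knob helps short runs and rapid profiling.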
Scan-mode iteration (fast path)¶
The scan-based loop lifts the VMEC2000 iteration into jax.lax.scan to reduce
Python overhead. Public fixed-boundary defaults are backend-aware: CPU uses the
profiled VMEC-control non-scan loop, while GPU/CUDA/ROCm uses scan for ordinary
raw fixed-boundary solves. You can explicitly request scan with:
--fast on the CLI,
use_scan=True in run_fixed_boundary, or
VMEC_JAX_USE_SCAN=1.
You can force the VMEC-control loop with use_scan=False or the diagnostics
flag --no-use-scan.
Important: scan parity is case-dependent on difficult large-ns stages.
Use scan when profiling a case where the scan loop has been validated. You can
always force the conservative path with --parity.
For LASYM fixed-boundary stages in explicit scan mode, the selector can use:
a timed scan/non-scan probe on CPU backends,
a short parity-only probe on accelerator backends.
This keeps explicit GPU scan experiments from paying the full warmed non-scan timing cost while still rejecting scan when the short parity probe disagrees.
Controls:
VMEC_JAX_DYNAMIC_SCAN_TIMED=1: force a timed probe even on accelerators.
VMEC_JAX_DYNAMIC_SCAN_TIMED=0: force parity-only probing.
VMEC_JAX_DYNAMIC_SCAN_ITERS=<int>: override the probe window (defaults to 10 on CPU, 3 on accelerators).
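The selector behavior described above can be sketched as follows. `run_probe(use_scan, iters)` is a hypothetical callable that returns (elapsed seconds, residual history) for a short probe window. The relative parity tolerance is an illustrative assumption, because the actual acceptance test is not spelled out here.

```python
import os
from typing import Mapping, Optional

import numpy as np


def select_scan(run_probe, backend: str, env: Optional[Mapping[str, str]] = None,
                rtol: float = 1e-6) -> bool:
    """Return True to use scan for the stage (illustrative probe policy)."""
    env = os.environ if env is None else env
    on_cpu = backend == "cpu"
    iters = int(env.get("VMEC_JAX_DYNAMIC_SCAN_ITERS", 10 if on_cpu else 3))

    timed_flag = env.get("VMEC_JAX_DYNAMIC_SCAN_TIMED")
    timed = on_cpu if timed_flag is None else timed_flag == "1"

    t_scan, r_scan = run_probe(True, iters)
    t_loop, r_loop = run_probe(False, iters)
    if not np.allclose(r_scan, r_loop, rtol=rtol, atol=0.0):
        return False  # parity probe disagrees: keep the conservative loop
    return t_scan < t_loop if timed else True


def fast_scan_probe(use_scan, iters):
    # Scan is faster here and its residual history matches the non-scan loop.
    return (0.5 if use_scan else 1.0), np.full(iters, 1e-3)


use_scan_on_cpu = select_scan(fast_scan_probe, "cpu", env={})
```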
For quiet accelerator scans, vmec-jax uses a backend-aware scan chunk.
Short runs use one full-length chunk. Long-budget runs use fixed
512-iteration chunks to reduce fresh-process compile/dispatch latency while
reusing the same compiled body inside the solve.
Controls:
VMEC_JAX_SCAN_CHUNK_SIZE=<int>: override the chunk target explicitly.
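The chunk policy for quiet accelerator scans amounts to a small decision rule plus a loop that reuses one compiled chunk body. The helpers below are a hypothetical sketch, not the vmec_jax functions.

```python
import os
from typing import List, Mapping, Optional


def scan_chunk_size(max_iter: int, env: Optional[Mapping[str, str]] = None,
                    long_run_chunk: int = 512) -> int:
    """Chunk length for a quiet accelerator scan (illustrative rule)."""
    env = os.environ if env is None else env
    override = env.get("VMEC_JAX_SCAN_CHUNK_SIZE")
    if override is not None:
        return max(1, min(int(override), max_iter))
    # Short budgets: one full-length chunk. Long budgets: fixed 512-iteration chunks.
    return max_iter if max_iter <= long_run_chunk else long_run_chunk


def chunk_schedule(max_iter: int, chunk: int) -> List[int]:
    """Split an iteration budget into chunks; only the last one may be shorter."""
    full, rest = divmod(max_iter, chunk)
    return [chunk] * full + ([rest] if rest else [])


schedule = chunk_schedule(2000, scan_chunk_size(2000, env={}))
```

In a literal split like this one, the shorter remainder chunk has a different static length, so it would trigger a second compilation. An implementation could avoid that by running a full chunk and masking the extra steps. This page does not document how vmec_jax handles the remainder.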
Debug dump env vars are incompatible with scan mode.
Experimental accelerated mode¶
vmec-jax now exposes an explicit experimental solver policy for the
non-parity performance track:
Python API: run_fixed_boundary(..., solver_mode="accelerated")
CLI: vmec input.name --solver-mode accelerated
Current behavior of this first slice:
auto-selected public fixed-boundary API/CLI runs use the profiled non-scan VMEC-control loop on CPU and the scan-lifted loop on GPU/CUDA/ROCm,
explicit solver_mode="accelerated" callers keep the historical scan default unless use_scan=False is supplied,
ordinary non-scan production runs skip the parity-oriented scan-selection probes; use the dynamic scan controls above only for targeted diagnostics,
when the caller does not explicitly request multigrid, accelerated fixed-boundary runs now default to a single final-grid stage. This avoids the per-stage interpolation and recompilation overhead that was dominating the heavy bundled fixed-boundary cases,
accelerated fixed-boundary stages still use a scalar total-residual target derived from the input ftol budget as a cheap in-block early stop: fsq_total_target = ftol * 3 for the three VMEC residual channels (fsqr, fsqz, fsql). However, the returned accelerated fixed-boundary run now accepts convergence only when the final-stage per-channel rule is satisfied, matching the requested FTOL literally on fsqr, fsqz, and fsql,
the experimental solver controls no longer rely on fixed absolute convergence thresholds. By default:
gradient-based stopping derives grad_tol from the initial gradient scale and machine precision,
the Gauss-Newton path derives its CG tolerance from the current residual progress against the same ftol budget,
the Gauss-Newton damping seed is derived from the local normal-equation curvature scale instead of a fixed literal damping floor,
residual-objective m=1 release thresholds now default to ftol instead of hardcoded residual cutoffs,
accelerated runs now request compact histories and a minimal resume payload by default, so the result object does not carry the full parity-era momentum/preconditioner cache unless the caller explicitly asks for it,
the CLI executable now has an extra fixed-boundary-only policy layer on top of accelerated mode:
the first attempt is the same fast final-grid solve used by the optimized Python API path,
if a staged input provides explicit NS_ARRAY/NITER_ARRAY and the fast final-grid attempt misses the target, the CLI replays that staged schedule automatically (accelerated coarse stages, strict parity final stage),
if the input is staged but does not provide NITER_ARRAY and the user explicitly forces accelerated mode, the CLI falls back to a reduced warm-start multigrid budget derived from the coarsest-to-finest ns ratio,
strict parity finish blocks then continue from state only, without reusing the parity-era nonlinear-controller caches,
free-boundary cases currently stay on the existing robust path; accelerated free-boundary control is not implemented yet,
the mode is intended to reduce control overhead while preserving final residual quality, not to reproduce the VMEC2000 iteration trace.
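The two convergence checks in accelerated mode differ in strictness. A total that meets ftol * 3 can still hide one channel above FTOL. The sketch below makes that concrete. It assumes that fsq_total is the sum of the three channels and that both checks use <=. The function names are hypothetical.

```python
from typing import Mapping

CHANNELS = ("fsqr", "fsqz", "fsql")


def in_block_early_stop(fsq: Mapping[str, float], ftol: float) -> bool:
    """Cheap scalar check inside accelerated blocks: summed residual vs ftol * 3."""
    fsq_total = sum(fsq[ch] for ch in CHANNELS)
    return fsq_total <= ftol * 3


def accept_final(fsq: Mapping[str, float], ftol: float) -> bool:
    """Final acceptance: every VMEC residual channel must meet FTOL on its own."""
    return all(fsq[ch] <= ftol for ch in CHANNELS)


# fsqr alone exceeds ftol, but the total still passes the scalar target.
fsq = {"fsqr": 2.5e-12, "fsqz": 1e-13, "fsql": 1e-13}
stops_early = in_block_early_stop(fsq, ftol=1e-12)
accepted = accept_final(fsq, ftol=1e-12)
```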
Use the dedicated comparison harness to evaluate it against the current default solver policy:
python tools/diagnostics/benchmark_accelerated_mode.py \
--baseline-mode default \
--candidate-mode accelerated \
--candidate-cli-fixed-boundary-mode \
--kind fixed \
--jax-platforms cpu
The harness reports:
cold and warm runtime,
peak process memory,
final fsq_total,
convergence flags,
reference-wout relRMS metrics when bundled references are available.
Early March 2026 smoke results on the local CPU host:
input.up_down_asymmetric_tokamak: about 4.1x warm speedup with a materially smaller memory footprint than the current default path,
input.circular_tokamak: approximately neutral in runtime, with good final quality (~1.2e-5 reference-wout relRMS),
input.LandremanPaul2021_QA_lowres: approximately neutral with the current ftol-derived total target,
free-boundary accelerated mode is currently a control-path alias for the robust baseline, not a new fast free-boundary controller.
Serial fixed-boundary follow-up measurements from
outputs/accelerated_fixed_boundary_reassessment_20260309/summary.json
show why the single-grid default is now the accelerated fixed-boundary policy:
input.LandremanSenguptaPlunk_section5p3_low_res: 45.48 s current default vs 0.198 s accelerated single-grid and 0.232 s accelerated explicit multigrid; the accelerated single-grid route converges and is dramatically faster than both,
input.LandremanPaul2021_QA_lowres: 8.18 s current default vs 7.31 s accelerated single-grid and 8.10 s accelerated explicit multigrid; the accelerated single-grid route now carries the full staged iteration budget and converges at ~3.0e-13,
input.LandremanPaul2021_QA_reactorScale_lowres: 21.15 s warmed on the optimized CLI track versus 43.20 s for VMEC2000 on the current bundled CPU benchmark, showing the same controller policy carries over to a heavier reactor-scale 3D case.
The fixed-boundary CLI path is now best understood as a controller stack, not a single algorithm:
easy inputs stay on the fast final-grid optimized path,
staged inputs can automatically escalate into their input-defined stage schedule before paying the cost of parity finish blocks,
only the genuinely hard cases should reach the final strict continuation phase.
The last pre-PR cleanup on codex/nonparity-performance did not change the
controller policy again. Instead, it trimmed overhead around the existing fast
path:
performance-oriented non-verbose staged runs now default to the lighter history footprint, not just the explicitly accelerated subset,
ordinary free-boundary runs now skip extra axis syntheses that were only needed for VMEC_JAX_DUMP_SCALPOT diagnostics,
the VMEC-like dense free-boundary solve path now reuses cached LU factorizations when SciPy is available, with a NumPy fallback otherwise.
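Reusing a factorization pays off when the same dense matrix is solved against several right-hand sides. The sketch below shows the pattern of a SciPy factorization with a NumPy fallback. The class name is hypothetical, and a new solver must be built whenever the matrix changes.

```python
import numpy as np

try:
    from scipy.linalg import lu_factor, lu_solve
except ImportError:  # SciPy is optional
    lu_factor = lu_solve = None


class CachedDenseSolver:
    """Factor a dense matrix once and reuse it for many right-hand sides (illustrative)."""

    def __init__(self, a: np.ndarray):
        self._a = np.asarray(a, dtype=float)
        self._lu = lu_factor(self._a) if lu_factor is not None else None

    def solve(self, b: np.ndarray) -> np.ndarray:
        if self._lu is not None:
            return lu_solve(self._lu, b)  # O(n^2) per solve after one O(n^3) factorization
        return np.linalg.solve(self._a, b)  # NumPy fallback: refactors on every call


a = np.array([[4.0, 1.0], [2.0, 3.0]])
solver = CachedDenseSolver(a)
x1 = solver.solve(np.array([1.0, 2.0]))
x2 = solver.solve(np.array([0.0, 1.0]))  # reuses the same factorization when SciPy is present
```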
For an up-to-date side-by-side comparison on your machine, use the bundled driver example:
python examples/fixed_boundary_driver_tracks.py \
examples/data/input.circular_tokamak \
--quiet --json
On the current branch, that example produced the following local CPU smoke
result for input.circular_tokamak:
parity track: 28.863 s with fsq_total ~2.04e-14,
optimized CLI-style track: 3.445 s with fsq_total ~2.85e-14.
That example uses the same public Python driver entry point as library users,
but it enables cli_fixed_boundary_mode=True on the optimized path so the
controller matches the executable behavior exactly.
Latest serial bundled fixed-boundary reassessment (April 2026)¶
Historical note: this April 2026 accelerated-branch snapshot is retained to
explain why the optimized controller exists, but it is not the current public
VMEC2000 comparison. Use the README-facing CSV/JSON in the previous section
for current release claims. This snapshot used NS=151 single-grid inputs
(examples_single_grid/data/) and compared
solver_mode="accelerated" warm runtimes against VMEC2000.
Results are in outputs/bench_accel_20260413/summary.json.
vmec_jax accelerated mode (warm) vs VMEC2000 on NS=151 single-grid:
ITERModel: 0.44 s vs 1.72 s (3.9x faster)
up_down_asymmetric_tokamak: 1.62 s vs 7.03 s (4.3x faster)
B2_A80: 0.16 s vs 0.81 s (5.1x faster)
LandremanSenguptaPlunk_section5p3_low_res: 0.23 s vs 1.00 s (4.3x faster)
circular_tokamak: 0.62 s vs 1.50 s (2.4x faster)
shaped_tokamak_pressure: 0.77 s vs 1.99 s (2.6x faster)
solovev: 0.26 s vs 1.23 s (4.7x faster)
nfp4_QH_warm_start: 0.96 s vs 1.71 s (1.8x faster)
basic_non_stellsym_pressure: 3.61 s vs 8.38 s (2.3x faster)
circular_tokamak_aspect_100: 0.43 s vs 0.84 s (2.0x faster)
purely_toroidal_field: 1.44 s vs 1.37 s (roughly neutral)
LandremanPaul2021_QA_lowres: 57.26 s vs 35.18 s (1.6x slower)
LandremanPaul2021_QA_reactorScale_lowres: 53.81 s vs 34.57 s (1.6x slower)
LandremanPaul2021_QH_reactorScale_lowres: 51.14 s vs 32.98 s (1.6x slower)
LandremanPaul2021_QA_lowres1: 32.33 s vs 16.03 s (2.0x slower)
cth_like_fixed_bdy: 23.66 s vs 4.19 s (5.7x slower)
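10 of 16 cases are faster in warm accelerated mode than VMEC2000, purely_toroidal_field is roughly neutral, and 5 are slower (LandremanPaul2021_QA_lowres, LandremanPaul2021_QA_lowres1, the QA/QH reactor-scale cases, and cth_like_fixed_bdy). These 5 slower cases are the main remaining fixed-boundary performance targets.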
Note: cold (first-run) times include XLA compilation and are 5–30x the warm times. Cold compilation is a one-time cost per JIT-distinct input configuration.
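The faster/neutral/slower tally can be recomputed directly from the warm timings listed for this snapshot. The sketch assumes a ±10% neutral band, which is an illustrative choice rather than a project rule.

```python
# (case, vmec_jax accelerated warm seconds, VMEC2000 seconds) from the April 2026 snapshot
ROWS = [
    ("ITERModel", 0.44, 1.72),
    ("up_down_asymmetric_tokamak", 1.62, 7.03),
    ("B2_A80", 0.16, 0.81),
    ("LandremanSenguptaPlunk_section5p3_low_res", 0.23, 1.00),
    ("circular_tokamak", 0.62, 1.50),
    ("shaped_tokamak_pressure", 0.77, 1.99),
    ("solovev", 0.26, 1.23),
    ("nfp4_QH_warm_start", 0.96, 1.71),
    ("basic_non_stellsym_pressure", 3.61, 8.38),
    ("circular_tokamak_aspect_100", 0.43, 0.84),
    ("purely_toroidal_field", 1.44, 1.37),
    ("LandremanPaul2021_QA_lowres", 57.26, 35.18),
    ("LandremanPaul2021_QA_reactorScale_lowres", 53.81, 34.57),
    ("LandremanPaul2021_QH_reactorScale_lowres", 51.14, 32.98),
    ("LandremanPaul2021_QA_lowres1", 32.33, 16.03),
    ("cth_like_fixed_bdy", 23.66, 4.19),
]


def classify(jax_s: float, vmec_s: float, neutral_band: float = 0.10) -> str:
    """Label a row by speedup = VMEC2000 time / vmec_jax time."""
    speedup = vmec_s / jax_s
    if speedup > 1.0 + neutral_band:
        return "faster"
    if speedup < 1.0 / (1.0 + neutral_band):
        return "slower"
    return "neutral"


tally = {"faster": 0, "neutral": 0, "slower": 0}
for _name, jax_s, vmec_s in ROWS:
    tally[classify(jax_s, vmec_s)] += 1
```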
Earlier reassessments from March 2026:
results recorded in outputs/accelerated_cli_fixed_boundary_full_20260311_r2/summary.json:
all 16 bundled fixed-boundary cases converge on both the baseline and optimized paths,
the optimized path is now faster on 13 of 16 cases and roughly neutral on the remaining 3,
the earlier runtime-regression blocker on the bundled CPU matrix is gone.
Final-wout accuracy is a separate gate from residual convergence. The
earlier full fixed-boundary audit is recorded in
outputs/fixed_wout_audit_20260310_r3/summary.json, and the later staged-3D
controller fixes improved several non-axisymmetric cases materially:
strong final-wout agreement on the current shipped showcase cases: ITERModel (max relRMS 6.01e-06), shaped_tokamak_pressure (1.55e-07), circular_tokamak (1.03e-05),
targeted staged non-axisymmetric follow-up fixes then brought the reactor QA/QH Fourier-channel errors down to the branch target range on direct comparisons:
LandremanPaul2021_QA_lowres now reaches about rmnc 5.83e-05, zmns 2.83e-04, lmns 4.75e-03;
LandremanPaul2021_QA_reactorScale_lowres reaches about rmnc 2.49e-05, zmns 1.61e-04, lmns 2.86e-03;
Optimization and gradient benchmarking¶
For optimization workflows, runtime alone is not the right metric. The project needs to track at least four numbers per case:
primal solve time,
explicit-gradient time,
implicit-gradient time,
peak memory for each of the above.
The current repo already has the pieces needed to build this:
explicit-diff and implicit-diff examples in examples/optimization/,
implicit solver support in vmec_jax/implicit.py,
profiling hooks in tools/diagnostics/.
What is still missing is a canonical benchmark matrix and reporting format. The recommended first benchmark set is:
circular_tokamak for small axisymmetric behavior,
ITERModel for a larger axisymmetric case,
LandremanPaul2021_QA_lowres for a representative 3D case.
For each case, record:
backend,
primal runtime,
explicit-gradient runtime,
implicit-gradient runtime,
peak memory,
final objective / residual quality.
This will tell us where the real optimization bottlenecks are, and it will
also guide downstream integration work for booz_xform_jax and neo_jax.
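A reporting format for the proposed matrix could be as simple as one record per case and backend, serialized to JSON. The schema and helper names below are hypothetical, not an existing vmec_jax format. The peak-memory helper uses the standard Unix resource module.

```python
import json
import resource
import sys
from dataclasses import asdict, dataclass
from typing import Optional


def peak_rss_mb() -> float:
    """Peak resident set size of this process so far, in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in KiB on Linux.
    return peak / 2**20 if sys.platform == "darwin" else peak / 2**10


@dataclass
class OptimizationBenchmarkRow:
    """One row of the proposed optimization benchmark matrix (hypothetical schema)."""

    case: str
    backend: str
    primal_s: float
    explicit_grad_s: Optional[float]
    implicit_grad_s: Optional[float]
    primal_peak_mb: float
    explicit_grad_peak_mb: Optional[float]
    implicit_grad_peak_mb: Optional[float]
    final_objective: Optional[float]
    final_fsq_total: float


def rows_to_json(rows) -> str:
    return json.dumps([asdict(r) for r in rows], indent=2, sort_keys=True)
```

ru_maxrss is a process-lifetime high-water mark. Measuring the primal, explicit-gradient, and implicit-gradient peaks separately therefore means running each measurement in its own process.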
the runtime picture is now favorable on the bundled CPU matrix, but the branch remains experimental because the non-parity scope and GPU/default policy questions are broader than this one fixed-boundary CPU result.
a later wout audit found that much of the remaining QA/QH benchmark error was coming from symmetry-forbidden geometry channels being exported for lasym=False: zeroing rmns and zmnc in wout for symmetric runs reduced the bundled 3D quality metric from about 3.37e-01 to 4.19e-02 on LandremanPaul2021_QA_lowres, from 3.56e+00 to 3.14e-02 on LandremanPaul2021_QA_reactorScale_lowres, and from 4.61e+00 to 2.22e-02 on LandremanPaul2021_QH_reactorScale_lowres in outputs/fixed_wout_3d_audit_20260311_r1/summary.json,
the next narrowed audit then focused on the staged current-driven 3D continuation policy itself: for 3-stage lasym=True current-driven runs, the first attempt used a mixed controller that kept the entry/final stages conservative and accelerated only the interior stage,
that same audit kept the remaining QH reactor-scale parity error bounded: LandremanPaul2021_QH_reactorScale_lowres reached about rmnc 6.12e-05, zmns 2.60e-04, lmns 9.97e-03,
that change materially reduced the remaining non-axisymmetric lambda drift: basic_non_stellsym_pressure improved from about 3.46e-01 to 3.46e-02 max relRMS while still running faster than baseline (about 23.69 s baseline vs 19.36 s optimized in the targeted audit), and the QA/QH reactor-scale cases held at about 3.14e-02 and 2.22e-02 with large runtime wins,
the remaining bundled 3D quality gap is therefore much narrower and now mostly a lambda-field accuracy question rather than a broad geometry or force-balance mismatch,
a final targeted controller split closed most of that remaining gap: lasym=False current-driven 3D CLI runs now go straight to staged multigrid on the conservative non-scan residual path,
with that split, the latest targeted audit reached: LandremanPaul2021_QA_lowres about 4.20e-03 max relRMS at about 100.6 s warmed runtime, LandremanPaul2021_QA_reactorScale_lowres about 6.42e-04 at about 125.1 s, and LandremanPaul2021_QH_reactorScale_lowres about 6.00e-05 at about 180.2 s, while basic_non_stellsym_pressure remained the last branch-specific lambda outlier,
the follow-on strict-FTOL pass removed that branch-specific regression by keeping lasym=True current-driven 3D staged runs fully on the conservative controller: basic_non_stellsym_pressure now lands back at about 2.98e-02 max relRMS, matching the current baseline quality instead of worsening it, with essentially neutral warmed runtime (about 22.24 s baseline vs 22.31 s optimized),
a subsequent CPU-focused profiling pass showed that the next avoidable cost was not force balance itself but controller-side JAX overhead: moving ptau sign-change detection to a host NumPy implementation removed it from the hot list, and a follow-on host update-assembly path for accelerated lasym=False CPU CLI solves moved the next hotspot out of the repeated signed-mode conversion path. On the targeted reassessment artifact outputs/host_updates_benchmark_20260312/summary.json, LandremanPaul2021_QA_lowres improved from about 34.83 s baseline vs 38.64 s optimized before the host update path to about 34.83 s baseline vs 31.17 s optimized after it, while keeping about 4.20e-03 max relRMS against the VMEC2000 reference. The sensitive basic_non_stellsym_pressure case also held baseline-level quality and remained slightly faster (about 9.12 s baseline vs 8.95 s optimized),
rerunning the full warmed bundled lasym=False fixed-boundary CPU matrix on that new head produced the cleanest branch-level fixed-boundary result so far: all 13 bundled lasym=False cases converged on both paths, and the optimized CLI controller was faster on all 13. Representative rows from outputs/fixed_lasym_false_matrix_20260312/summary.json include LandremanPaul2021_QA_reactorScale_lowres (51.31 s baseline vs 38.56 s optimized), LandremanPaul2021_QH_reactorScale_lowres (60.10 s vs 46.33 s), ITERModel (12.73 s vs 5.00 s), and cth_like_fixed_bdy (4.71 s vs 0.97 s),
the README-facing VMEC2000 comparison was then rerun separately on the same host in outputs/readme_fixed_runtime_vmec2000_accel_cpu_20260312/summary.json: all 13 bundled lasym=False fixed-boundary cases converged, but the optimized branch is still faster than VMEC2000 on only the smallest shipped cases (solovev and circular_tokamak_aspect_100). The reactor-scale QA/QH cases are now close enough to compare honestly on one plot, but they still run somewhat slower than VMEC2000 on CPU,
carrying the same “reduce host-controlled overhead” approach into lasym=False free-boundary showed the next safe win: batching the boundary real-space syntheses in vmec_jax/free_boundary.py:_sample_external_boundary_arrays cut the representative input.cth_like_free_bdy cProfile total from about 60.41 s to about 58.21 s while keeping the NESTOR reuse tests green,
the next free-boundary profile then showed that _vmec_nonsingular_terms_from_bexni and _vmec_nonsingular_gsource_from_bexni were still rebuilding basis-only helper tables on every call; caching those tables on the free-boundary basis cut the same representative cProfile total further to about 32.67 s and brought the direct warmed CPU benchmark for cth_like_free_bdy to about 11.25 s in outputs/freeb_cth_runtime_20260312/summary.json while preserving convergence,
one more pass then removed six separate second-derivative boundary synthesis calls from _sample_external_boundary_arrays and replaced them with two stacked batched syntheses. On the same representative free-boundary case, that moved the cProfile total further to about 31.04 s and improved the warmed CPU benchmark to about 10.41 s in outputs/freeb_cth_runtime_20260312_r2/summary.json while keeping the direct NESTOR reuse tests green,
the next step replaced the remaining JAX-backed boundary synthesis in that host-only external-sampling path with a cached NumPy phase-stack helper. That keeps the same VMEC trig algebra but removes more JAX lowering/indexing overhead from _sample_external_boundary_arrays itself. On the same representative case, the cProfile total dropped again to about 30.20 s with _sample_external_boundary_arrays down to about 5.78 s, and the warmed CPU benchmark improved further to about 9.86 s in outputs/freeb_cth_runtime_20260312_r4/summary.json,
a deeper 2026-03-13 profiling pass then combined full pytest -q / Sphinx validation with kernel-level auditing:
the accelerated fixed-boundary HLO dump for input.LandremanPaul2021_QA_lowres showed the bcovar kernel still contains a large gather/scatter footprint (roughly 96 gathers and 55 scatters in the dumped HLO), which means the next fixed-boundary wins are more likely to come from refactoring bcovar/state-enforcement kernels than from controller tweaks,
cProfile on the representative free-boundary case input.cth_like_free_bdy showed the remaining dominant host-side cost was still inside free_boundary.py:_sample_external_boundary_arrays plus JAX-backed parity conversions and the full-profile currents_from_bcovar path used only to recover the edge ctor scalar for NESTOR,
replacing those host-only parity conversions with pure NumPy helpers in vmec_parity.py and adding a specialized vmec_lforbal.plascur_edge_from_bcovar helper cut the same warmed 200-iteration input.cth_like_free_bdy benchmark from about 8.00 s to about 3.45 s on the same CPU host while keeping the free-boundary regression tests green,
after that change, the same cProfile run dropped from about 8.00 s to about 6.06 s total Python-side wall time, and the remaining dominant CPU hotspots became the update/preconditioner side (_apply_vmec_scale_m1_precond_rhs, _enforce_fixed_boundary_and_axis, and the free-boundary update block) rather than the old external-boundary sampling path itself.
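The cached phase-stack and batched-synthesis changes rest on one observation. For a fixed boundary grid, the trigonometric tables depend only on the mode numbers and angles, so they can be built once. Several coefficient sets can then be synthesized with a single matrix product. The NumPy sketch below assumes the stellarator-symmetric VMEC convention R = Σ rmnc cos(m θ - n ζ), Z = Σ zmns sin(m θ - n ζ), with n already multiplied by nfp. The function names are hypothetical.

```python
import functools

import numpy as np


@functools.lru_cache(maxsize=4)
def phase_stack(xm: tuple, xn: tuple, ntheta: int, nzeta: int):
    """Cached cos/sin tables of shape (mnmax, ntheta * nzeta) for one boundary grid."""
    theta = np.linspace(0.0, 2.0 * np.pi, ntheta, endpoint=False)
    zeta = np.linspace(0.0, 2.0 * np.pi, nzeta, endpoint=False)
    th, ze = np.meshgrid(theta, zeta, indexing="ij")
    arg = np.outer(np.asarray(xm, float), th.ravel()) - np.outer(np.asarray(xn, float), ze.ravel())
    return np.cos(arg), np.sin(arg)


def synthesize_batched(cos_coeffs, sin_coeffs, xm, xn, ntheta, nzeta):
    """Synthesize a stack of fields with one matmul per parity.

    cos_coeffs, sin_coeffs: (nfields, mnmax) coefficients multiplying cos(...) and sin(...).
    Returns an array of shape (nfields, ntheta, nzeta).
    """
    cosmn, sinmn = phase_stack(tuple(xm), tuple(xn), ntheta, nzeta)
    out = np.asarray(cos_coeffs) @ cosmn + np.asarray(sin_coeffs) @ sinmn
    return out.reshape(-1, ntheta, nzeta)


# R, Z, dR/dtheta and dZ/dtheta of a circular cross-section in one batched pass.
xm, xn = (0, 1), (0, 0)
m = np.asarray(xm, dtype=float)
rmnc, zmns, zero = np.array([1.0, 0.3]), np.array([0.0, 0.3]), np.zeros(2)
cos_c = np.stack([rmnc, zero, zero, m * zmns])   # R, Z, R_theta, Z_theta
sin_c = np.stack([zero, zmns, -m * rmnc, zero])
fields = synthesize_batched(cos_c, sin_c, xm, xn, ntheta=8, nzeta=1)
```

The derivative rows reuse the same tables with coefficients scaled by ±m. This is why stacking them costs one product instead of one synthesis call per field.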
Full same-host readiness sweep¶
After the targeted free-boundary optimizations, the branch was rerun on the full shipped example matrix plus the external DIII-D axisymmetric free-boundary references. The freshest final-head artifacts are:
fixed-boundary optimized CLI / automatic Python readiness matrix: outputs/readiness_fixed_all_20260313/summary.json,
fixed-boundary VMEC2000-vs-optimized warmed runtime matrix: outputs/fixed_runtime_vmec2000_accel_cpu_warm_20260313/summary.json,
free-boundary VMEC2000-vs-default warmed runtime matrix: outputs/free_runtime_vmec2000_cpu_warm_20260313/summary.json.
Current state:
fixed-boundary: all 16 rows converged on the optimized branch path,
fixed-boundary accelerated/default comparison: faster on 13 rows, roughly neutral on 1, slower on 2, with all 16 satisfying the requested final-stage FTOL,
free-boundary: all 5 rows now converge on the shipped default path,
the previous shipped holdout cth_like_free_bdy_lasym_small was replaced with a convergent lasym=True CTH-like fixture built from the stable bundled free-boundary case,
the free-boundary DIII-D rows remain the main runtime blockers, but the latest axisymmetric CPU pass reduced input.DIII-D_lasym_false from about 173.82 s warmed earlier in the branch to about 113.78 s warmed in the final runtime matrix on the same host,
the full 21-case CPU matrix is still slower than VMEC2000 on this host, so the performance story is now “better vmec_jax defaults and broader convergence,” not “broad CPU wins over VMEC2000.”
So the branch is useful and now converges across the shipped 21-case matrix, but any merge/default decision should be framed around vmec_jax usability and coverage rather than raw same-host VMEC2000 CPU speed.
boundary R/Z synthesis and first-derivative synthesis in _sample_external_boundary_arrays cut the representative input.cth_like_free_bdy profile from about 60.41 s total wall time to about 58.21 s while keeping the direct NESTOR regression tests green,
a follow-on experiment that added a final-grid parity polish to the staged 3D accelerated path was rejected because it raised runtime substantially without improving those benchmarked quality numbers.
Representative warmed CPU baseline-vs-optimized points from the current full matrix:
ITERModel: 8.86 s baseline vs 4.48 s optimized,
LandremanPaul2021_QA_lowres: 29.65 s baseline vs 35.02 s optimized, with quality_max_rel_rms ~ 4.20e-03,
LandremanPaul2021_QA_reactorScale_lowres: 47.17 s baseline vs 45.52 s optimized, with quality_max_rel_rms ~ 6.42e-04,
LandremanPaul2021_QH_reactorScale_lowres: 53.28 s baseline vs 53.05 s optimized, with quality_max_rel_rms ~ 6.00e-05,
LandremanSenguptaPlunk_section5p3_low_res: 15.86 s baseline vs 7.48 s optimized,
up_down_asymmetric_tokamak: 24.58 s baseline vs 2.97 s optimized.
Representative warmed CPU VMEC2000-vs-vmec_jax points from that historical
readiness matrix:
solovev: VMEC2000 0.61 s vs vmec_jax 0.09 s,
circular_tokamak_aspect_100: 2.99 s vs 0.59 s,
cth_like_fixed_bdy: 1.34 s vs 1.04 s,
LandremanPaul2021_QA_reactorScale_lowres: 39.27 s vs 44.48 s,
LandremanPaul2021_QH_reactorScale_lowres: 46.34 s vs 53.93 s,
DIII-D_lasym_false: 19.80 s vs 113.78 s,
cth_like_free_bdy: 1.79 s vs 6.96 s.
Same-host CPU/GPU reassessment on a reference GPU workstation is now complete for the same 16-case bundled fixed-boundary set:
both CPU and GPU converge on all 16 cases,
GPU is already faster on the heavier 3D cases (LandremanPaul2021_QA_lowres, LandremanPaul2021_QA_lowres1, LandremanPaul2021_QA_reactorScale_lowres, LandremanPaul2021_QH_reactorScale_lowres, cth_like_fixed_bdy),
CPU still wins on the remaining 11 smaller or more launch-latency-dominated cases.
Representative same-host CPU/GPU warmed comparisons:
LandremanPaul2021_QA_reactorScale_lowres: CPU 23.72 s vs GPU 13.06 s,
LandremanPaul2021_QH_reactorScale_lowres: CPU 31.39 s vs GPU 16.56 s,
cth_like_fixed_bdy: CPU 1.45 s vs GPU 0.89 s,
circular_tokamak: CPU 0.70 s vs GPU 1.97 s,
solovev: CPU 0.16 s vs GPU 0.55 s.
That makes the current branch state clearer than the earlier stress-case benchmark story:
the optimized fixed-boundary CLI path is a mixed but useful warmed CPU result: faster on 13 of 16 rows versus the prior/default branch path, roughly neutral on 1, and slower on 2,
relative to VMEC2000, the current README-facing CPU matrix still wins on only 1 of 16 bundled fixed-boundary rows,
the GPU path is functional and convergent on the same bundled matrix, but is faster only on selected heavier 3D rows; do not claim uniform GPU superiority.
Exact-optimizer replay profiling on f0225ff
The exact optimizer has a different performance profile from a standalone
fixed-boundary solve because accepted-point Jacobians replay a differentiated
VMEC tape for every active boundary direction. A fresh 2026-05-12 QH
max_mode=2 profile used
tools/diagnostics/profile_exact_optimizer.py --callback jacobian --repeats 2
--perturb-scale 1e-4 on CPU and on an RTX A4000 GPU, then compared the JSON
reports with tools/diagnostics/compare_profile_reports.py.
The local CPU report took 11.29 s for two new Jacobian points. The GPU
report took 54.83 s for the same two points, with the first GPU point
dominated by cold accepted-point setup and the second warm point taking
4.44 s. The GPU/CPU ratio was 4.86x overall and 4.87x for replay
time; RSS was similar. The largest profile terms were:
CPU: jacobian_tape_replay 4.89 s (43%), jacobian_initial_tangents 2.75 s (24%), jacobian_residual_tangents 2.21 s (20%).
GPU: jacobian_tape_replay 23.79 s (43%), exact_solve_with_tape_total 12.89 s (24%), exact_tape_build 11.34 s (21%).
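The percentages and GPU/CPU ratios above are plain arithmetic on the per-term timers. The sketch below reproduces them from two hypothetical dictionaries holding the quoted numbers; the actual JSON layout written by profile_exact_optimizer.py and read by compare_profile_reports.py is not shown on this page, so the key names and the "total" entry are assumptions.

```python
# Hypothetical timer dictionaries (seconds) holding the numbers quoted above;
# the real profile report JSON layout may differ.
cpu = {
    "total": 11.29,
    "jacobian_tape_replay": 4.89,
    "jacobian_initial_tangents": 2.75,
    "jacobian_residual_tangents": 2.21,
}
gpu = {
    "total": 54.83,
    "jacobian_tape_replay": 23.79,
    "exact_solve_with_tape_total": 12.89,
    "exact_tape_build": 11.34,
}


def shares(report):
    """Each named term as a whole-number percentage of the report total."""
    total = report["total"]
    return {k: round(100.0 * v / total) for k, v in report.items() if k != "total"}


def ratio(num, den, key):
    """GPU/CPU ratio for one timer key present in both reports."""
    return round(num[key] / den[key], 2)


print(shares(cpu))
print(shares(gpu))
print(ratio(gpu, cpu, "total"), ratio(gpu, cpu, "jacobian_tape_replay"))
```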
This narrows the GPU lane: optimizing residual convergence alone will not fix optimization runtime. The next GPU work should target accepted-point tape replay and tangent batching/reuse, while keeping the CPU exact path unchanged.
Backend-adaptive replay bucketing
The first concrete GPU fix from that profile is backend-adaptive dynamic tape
bucketing. CPU replay remains fastest with the previous bucket size 32.
On the office RTX A4000, the same dense exact-Jacobian profile is much
faster when GPU replay pads dynamic tapes to bucket size 128. This is now
the default for CUDA/ROCm/GPU backends; users can still override it with
VMEC_JAX_DYNAMIC_REPLAY_BUCKET for experiments.
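A minimal sketch of backend-adaptive bucketing, assuming the bucket is chosen from a backend name and that a dynamic tape is padded up to the next multiple of the bucket size. The function names and backend strings are illustrative, not vmec_jax internals; only the default sizes (32 on CPU, 128 on GPU) and the VMEC_JAX_DYNAMIC_REPLAY_BUCKET override come from the text above.

```python
import os

# Assumed backend labels; vmec_jax's own backend detection is not shown here.
GPU_BACKENDS = {"gpu", "cuda", "rocm"}


def replay_bucket(backend, env=None):
    """Bucket size for dynamic tape replay: 128 on GPU backends, 32 otherwise.

    VMEC_JAX_DYNAMIC_REPLAY_BUCKET, when set, overrides the default.
    """
    env = os.environ if env is None else env
    override = env.get("VMEC_JAX_DYNAMIC_REPLAY_BUCKET")
    if override:
        return int(override)
    return 128 if backend.lower() in GPU_BACKENDS else 32


def padded_length(tape_len, bucket):
    """Round a dynamic tape length up to the next multiple of the bucket."""
    return -(-tape_len // bucket) * bucket


print(padded_length(150, replay_bucket("cpu", env={})))   # -> 160 (bucket 32)
print(padded_length(150, replay_bucket("cuda", env={})))  # -> 256 (bucket 128)
```

One plausible reading of why the larger GPU bucket helps: fewer distinct padded shapes mean fewer replay variants to compile and launch, at the price of some padded work per replay.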
Validation used the same QH max_mode=2 accepted-point profile as above:
previous GPU default: 54.83 s for two perturbed Jacobian points, jacobian_tape_replay = 23.79 s;
backend-adaptive GPU default: 15.79 s for two perturbed Jacobian points, jacobian_tape_replay = 5.43 s;
CPU baseline: still 11.29 s for the same profile, so the optimized GPU profile now takes about 1.40x the CPU time overall and 1.11x the CPU time in replay.
This does not make GPU universally faster yet. It removes the largest replay regression for small/medium exact optimizations and leaves tape construction, residual tangent projection, and accepted-point replay fusion as the next GPU targets.
Follow-up profiling on 2026-05-12 confirmed the same bottleneck split. A tiny
cold fixed-boundary smoke case on the office host still favored CPU
(13.80 s CPU versus 68.86 s GPU), so cold single solves remain a CPU
use case. A local CPU exact-Jacobian smoke spent about 96% of callback time
in replay/tangent work: jacobian_tape_replay 3.34 s,
jacobian_initial_tangents 2.85 s, and
jacobian_residual_tangents 2.75 s out of a 9.31 s callback. The
next GPU and CPU target is therefore accepted-point tangent fusion/reuse, not
ordinary force-kernel optimization.
The 2026-05-18 scalar-gradient patch narrows the initial tangent/projection
cost without changing the dense least-squares Jacobian path:
objective_and_gradient_fun now projects reverse tape cotangents through an
initial-state VJP on tangent-cache misses, caches that VJP by the same
shape/flip-branch key, and still reuses the dense tangent map when a previous
Jacobian callback already built it. On a local CPU QH max_mode=2 bounded
profile with two perturbed gradient callbacks
(inner_max_iter=40, trial_max_iter=20), total callback time dropped from
20.47 s to 18.68 s; the warm repeat dropped from 0.69 s to
0.56 s. A bounded QH max_mode=1 check with inner_max_iter=8
matched the dense reference gradient (||g_adj - J.T r|| = 2.10e-12,
relative 7.21e-14).
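The gradient check quoted above rests on the identity grad(0.5 * ||r||^2) = J.T r: a single reverse (VJP) sweep with cotangent r yields the same vector as materializing the dense Jacobian and multiplying by its transpose. The toy residual below is an assumption chosen for illustration, not the VMEC residual; it only demonstrates the identity and the absolute/relative agreement check.

```python
import math

# Toy residual r_i(x) = sin(a_i . x) - y_i (an assumption for illustration,
# not the VMEC force residual).
A = [[1.0, 0.5], [0.3, -0.7], [2.0, 0.1]]
y = [0.2, -0.1, 0.4]


def inner(x):
    """Return the arguments a_i . x for every residual row."""
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]


def residuals(x):
    return [math.sin(u) - yi for u, yi in zip(inner(x), y)]


def grad_dense(x):
    """Materialize the full J (m x n), then form J.T r."""
    r = residuals(x)
    J = [[math.cos(u) * a for a in row] for u, row in zip(inner(x), A)]
    return [sum(J[i][j] * r[i] for i in range(len(r))) for j in range(len(x))]


def grad_vjp(x):
    """One reverse sweep: pull the cotangent r back without storing J."""
    r = residuals(x)
    g = [0.0] * len(x)
    for u, row, ri in zip(inner(x), A, r):
        w = ri * math.cos(u)  # cotangent times d(sin u)/du
        for j, a in enumerate(row):
            g[j] += w * a
    return g


x0 = [0.4, -0.2]
g_dense, g_adj = grad_dense(x0), grad_vjp(x0)
abs_err = math.sqrt(sum((p - q) ** 2 for p, q in zip(g_adj, g_dense)))
rel_err = abs_err / math.sqrt(sum(q * q for q in g_dense))
print(f"||g_adj - J.T r|| = {abs_err:.2e}, relative {rel_err:.2e}")
```

The dense form needs one tangent per column of J (24 boundary directions in the QH case), while the reverse form needs a single sweep, which is why the scalar path avoids materializing J when only the gradient is required.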
The 2026-05-25 office GPU follow-up rechecked the current QH
max_mode=2 exact-optimization hotspot with a deliberately short
inner_max_iter=10 accepted solve. The dense least-squares Jacobian still
materializes a 44353 x 24 residual derivative, so the one-callback GPU
profile is dominated by accepted-point tape construction and residual-tangent
projection:
| Path | Wall time | Main profile terms |
|---|---|---|
| Dense Jacobian | | |
| Forced projected replay | | |
| Scalar adjoint, 4 optimizer callbacks, old backtracking | | no accepted improvement from the initial objective |
| Scalar adjoint, L-BFGS/fast backtracking, 4 callbacks | | objective |
| Scalar adjoint, L-BFGS/fast backtracking, 8 callbacks | | objective |
| Scalar adjoint, default radius plus best-state finalization cache | | objective |
| Scalar adjoint, radius re-expansion after backtracked accepts | | objective |
| Scalar adjoint, opt-in cost-only trial filter | | objective |
| Matrix-free SciPy, | | objective |
That historical test ruled out the earlier projected-replay implementation for
24 DOFs. After the JVP-only/basepoint-carry tape changes, the updated bounded
QH mode-2 comparison slightly favors projected replay for non-LASYM GPU
callbacks, while SciPy’s generic matrix-free trust-region path still spends too
much time in repeated linear-operator products at low LSMR budgets. In this QH
mode-2 GPU microcase, the scalar-adjoint path remains useful because it avoids
dense Jacobian materialization and now accepts monotone steps after aggressive
backtracking. Keeping the best exact
scalar state also removes one final accepted-point solve for runs whose last
probe is rejected.
Backtracked accepts now re-expand the next trust radius to the previous rejected
scale instead of permanently collapsing to the accepted microscopic step. A
cost-only trial filter was also tested, but it is not the default because this
QH GPU case spends about 2.2 s per trial solve and is slower overall despite
halving exact-gradient builds. It remains available for diagnostics with
scalar_cost_only_trials=True on FixedBoundaryExactOptimizer.run (or the
same argument through least_squares_solve). The legacy environment variable
VMEC_JAX_OPT_SCALAR_COST_ONLY_TRIALS=1 remains available for profiling
scripts, and tools/diagnostics/profile_exact_optimizer.py exposes the same
switch as --scalar-cost-only-trials. The path still needs problem-aware
step proposal/globalization before it can replace dense SciPy least-squares for
README-quality optimization results.
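The re-expansion rule for backtracked accepts can be sketched as follows. This is a hypothetical loop, not the FixedBoundaryExactOptimizer code: accept stands in for an exact trial solve plus the monotone-decrease test, and the shrink factor and backtrack limit are assumed values.

```python
def backtrack_with_reexpansion(radius, accept, shrink=0.5, max_backtracks=10):
    """Hypothetical trust-radius backtracking with re-expansion after accepts.

    accept(r) stands in for an exact trial solve at step radius r followed by
    the monotone-decrease test. Returns (accepted_radius or None, next_radius).
    """
    trial = radius
    last_rejected = None
    for _ in range(max_backtracks):
        if accept(trial):
            # Start the next iteration from the last rejected scale instead of
            # collapsing to the (possibly microscopic) accepted step.
            next_radius = last_rejected if last_rejected is not None else radius
            return trial, next_radius
        last_rejected = trial
        trial *= shrink
    return None, trial


print(backtrack_with_reexpansion(1.0, lambda r: r <= 0.01))
```

With an initial radius of 1.0 and a trial that only succeeds below 0.01, the step is accepted at 0.0078125 and the next radius restarts from the last rejected scale, 0.015625, rather than from the accepted step.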
Additional controller finding from March 2026:
the existing fully non-VMEC scan path was re-probed as a possible next accelerated controller, but it is not yet robust enough to become the default accelerated path: on representative fixed-boundary cases it is much faster, but it can diverge badly in fsq_total and in final wout quality. The current accelerated mode therefore stays on the masked VMEC-control scan until a more stable device-resident controller is in place.
If you want an automatic parity probe when using scan, set:
export VMEC_JAX_SCAN_PARITY_GUARD=1
This runs a short scan-vs-non-scan probe at the start of each stage and falls back to the non-scan loop if a mismatch is detected. It is off by default because it adds extra compilation and iteration overhead.
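A sketch of the probe-and-fallback idea behind VMEC_JAX_SCAN_PARITY_GUARD. The callables, the probe length and the relative tolerance are assumptions for illustration; this page does not document how vmec_jax defines a mismatch.

```python
def choose_stage_path(state, run_scan, run_loop, probe_iters=5, rtol=1e-10):
    """Hypothetical parity guard: compare a short scan probe with the loop.

    run_scan(state, n) and run_loop(state, n) stand in for the two solver
    paths and return per-iteration residual norms (for example fsq totals).
    Returns the path to use for the rest of the stage and whether a
    mismatch was detected.
    """
    scan_hist = run_scan(state, probe_iters)
    loop_hist = run_loop(state, probe_iters)
    mismatch = len(scan_hist) != len(loop_hist) or any(
        abs(a - b) > rtol * max(abs(a), abs(b), 1e-300)
        for a, b in zip(scan_hist, loop_hist)
    )
    return ("loop" if mismatch else "scan"), mismatch
```

Running both paths for the probe is exactly the extra compilation and iteration cost that keeps the guard off by default.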
Scan chunking (fixed NSTEP blocks)
To avoid retracing for variable tail lengths, the scan loop executes in fixed
chunks of length NSTEP (the VMEC input parameter). Iterations beyond
NITER are masked by the in-scan hold condition, so the extra work is a
no-op and does not affect parity.
Controls:
VMEC_JAX_VMEC2000_CHUNKED=1 (default): enable chunked scan.
VMEC_JAX_SCAN_CHUNK_SIZE=<int>: override the chunk length (defaults to NSTEP).
This reduces compilation cache misses when the stage transition changes
NITER but keeps the same NSTEP cadence.
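The chunking arithmetic is simple: a stage of NITER iterations runs in ceil(NITER / chunk) fixed-length chunks, and the trailing steps of the last chunk are held by the mask. The helpers below are hypothetical, but they show why the compiled chunk depends only on the chunk length, so a stage transition that changes NITER while keeping NSTEP can reuse it.

```python
import math


def chunk_plan(niter, nstep, chunk_size=None):
    """Hypothetical chunk arithmetic for a fixed-length scan.

    Returns (chunk length, number of chunks, number of masked steps). The chunk
    length defaults to NSTEP; VMEC_JAX_SCAN_CHUNK_SIZE would override it.
    """
    chunk = chunk_size or nstep
    n_chunks = math.ceil(niter / chunk)
    masked = n_chunks * chunk - niter
    return chunk, n_chunks, masked


def hold_mask(chunk_index, chunk, niter):
    """True for steps that update the state, False for held (masked) steps."""
    start = chunk_index * chunk
    return [start + k < niter for k in range(chunk)]


print(chunk_plan(1250, 200))         # (200, 7, 150)
print(sum(hold_mask(6, 200, 1250)))  # 50 live steps in the last chunk
```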
Live NSTEP printing (debug callback)
VMEC2000-style iteration rows (scan and non-scan) are printed using a JAX debug callback by default. This keeps the output VMEC-like without inserting extra Python-side synchronization on every step.
Defaults:
Live printing is enabled when verbose and vmec2000_control are on.
The VMEC2000-control scan path defaults to debug_callback for ordered host-side row formatting.
debug_print remains available for simpler device-side debug printing.
Disable live printing with:
export VMEC_JAX_SCAN_PRINT=0
If you want to reduce any remaining host-callback overhead, increase NSTEP
in the input file: fewer printed rows mean fewer host callbacks.
Quiet scan runs (--quiet / verbose=False) automatically switch to a
minimal history mode: only fsqr/fsqz/fsql (and therefore w_history) are
kept. Per-iteration print scalars (r00, w_mhd) and time-step histories
are skipped to reduce host/device traffic. Override with:
export VMEC_JAX_SCAN_MINIMAL=0 # keep full scan diagnostics
export VMEC_JAX_SCAN_MINIMAL=1 # force minimal histories
In fast mode (performance_mode=True / --fast), scan_minimal is the
default for quiet runs (verbose=False) unless explicitly overridden by
VMEC_JAX_SCAN_MINIMAL. When verbose=True, scan keeps the extra scalar
histories needed for VMEC-style printing.
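The precedence described above can be summarized as a small resolver. This is a hypothetical sketch: it assumes an explicit VMEC_JAX_SCAN_MINIMAL=0/1 always wins, including over verbose=True, which the text implies but does not state for every mode.

```python
import os


def resolve_scan_minimal(verbose, env=None):
    """Hypothetical precedence for minimal scan histories.

    An explicit VMEC_JAX_SCAN_MINIMAL=0/1 wins; otherwise quiet runs keep only
    fsqr/fsqz/fsql and verbose runs keep the scalars needed for printing.
    """
    env = os.environ if env is None else env
    value = env.get("VMEC_JAX_SCAN_MINIMAL")
    if value in ("0", "1"):
        return value == "1"
    return not verbose
```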
Advanced knobs (not required for normal use):
VMEC_JAX_SCAN_PRINT_MODE=debug_callback (default for the VMEC2000-control scan)
VMEC_JAX_SCAN_PRINT_MODE=debug_print (device-side debug printing)
VMEC_JAX_SCAN_PRINT_ORDERED=1 to force ordered prints (may reduce parallelism)
DFT tomnsps (GEMM path)
VMEC2000’s tomnsps analysis transform is now implemented as a two-stage
DFT using the precomputed fixaray trig/weight tables:
theta stage: multiply by cosmui/sinmui (endpoint-weighted, including mscale),
zeta stage: multiply by cosnv/sinnv (with nscale, and with the n*NFP factor in cosnvn/sinnvn for derivative terms).
The core contractions are done with batched dot_general calls so XLA can
lower them into GEMM kernels. This follows the VMEC++ basis approach (see
References [5-6]) while keeping VMEC2000 parity.
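The two-stage contraction can be checked on a small grid: projecting first over theta (a matrix product with the weighted theta table) and then over zeta gives the same coefficients as the direct double sum. The tables below are simplified stand-ins for cosmui and cosnv (endpoint weights only, no mscale/nscale normalization), and matmul is a plain-Python placeholder for the batched dot_general call.

```python
import math


def matmul(a, b):
    """Plain-Python GEMM, (p x q) @ (q x r); a placeholder for dot_general."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(a):
    return [list(col) for col in zip(*a)]


ntheta, nzeta, mpol, ntor, nfp = 6, 4, 3, 2, 2
theta = [math.pi * i / (ntheta - 1) for i in range(ntheta)]     # [0, pi], endpoints included
zeta = [2 * math.pi * k / (nfp * nzeta) for k in range(nzeta)]  # one field period
w = [0.5 if i in (0, ntheta - 1) else 1.0 for i in range(ntheta)]  # endpoint weights

# Simplified stand-ins: weighted cos(m*theta) and cos(n*NFP*zeta) tables.
cosmui = [[w[i] * math.cos(m * theta[i]) for m in range(mpol)] for i in range(ntheta)]
cosnv = [[math.cos(n * nfp * zeta[k]) for n in range(ntor + 1)] for k in range(nzeta)]

# Arbitrary real-space field f(theta_i, zeta_k) to project.
f = [[math.sin(theta[i]) + 0.3 * i * math.cos(nfp * zeta[k]) for k in range(nzeta)]
     for i in range(ntheta)]

# Stage 1 (theta): (mpol x ntheta) @ (ntheta x nzeta)
g = matmul(transpose(cosmui), f)
# Stage 2 (zeta): (mpol x nzeta) @ (nzeta x (ntor + 1))
two_stage = matmul(g, cosnv)

# Reference: direct double sum for every (m, n) coefficient.
direct = [
    [sum(f[i][k] * cosmui[i][m] * cosnv[k][n]
         for i in range(ntheta) for k in range(nzeta))
     for n in range(ntor + 1)]
    for m in range(mpol)
]
```

The direct form costs on the order of mpol * (ntor + 1) * ntheta * nzeta multiply-adds, while the staged form costs about mpol * ntheta * nzeta + mpol * nzeta * (ntor + 1), and each stage is a dense matrix product that XLA can lower to a GEMM kernel.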
Recent updates to the DFT path:
Stacked theta contractions: multiple force kernels are concatenated into a single cosine and sine projection per iteration, reducing the number of dot_general launches.
Derivative-factor fusion: the \(n\,\mathrm{NFP}\) factor for cosnvn/sinnvn is applied after the zeta contraction, so the same cosnv/sinnv basis can be reused for derivative blocks.
Stacked zeta contractions: cosine- and sine-basis transforms for the derivative and non-derivative blocks are grouped to reduce kernel dispatches.
NumPy force path follows the same policy: the host CPU force shim no longer forces the FFT branch. It uses the DFT/GEMM path by default and still honors VMEC_JAX_TOMNSPS_FFT=0/1. On the finite-beta input.nfp4_QH_finite_beta CPU diagnostic, this reduced a 200-iteration force block from 1.067 s to 0.868 s while preserving the residual trajectory, and the full 3000-iteration-budget direct solve completed in 13.05 s with 10.28 s in force assembly.
An FFT-based path remains available for experiments:
VMEC_JAX_TOMNSPS_FFT=1 enables the FFT implementation (not default).
Preconditioner weight caching
The 1D radial preconditioner uses angular weights
\(w_i=\mathrm{cosmui3}_{i,0}/\mathrm{mscale}_0\) on the VMEC internal grid.
These weights are now cached in the trig table as wint3_precond and reused
whenever the preconditioner diagonal is refreshed. This avoids rebuilding the
same weight tensor in every refresh call and keeps the preconditioner refresh
path purely algebraic in bsq, r12, sqrtg, ru12, and zu12.
Free-boundary WP1 micro-benchmark
For free-boundary staging, use the dedicated benchmark script:
python tools/benchmarks/bench_free_boundary_wp1.py \
--input examples/data/input.DIII-D_lasym_false \
--interp-points 20000 \
--interp-repeats 5
This reports:
metadata validation/load time,
full mgrid tensor load time,
interpolation throughput and sampled |B_ext| stats.
Solver note: this benchmark isolates external-field staging cost. The sampling toggle below is diagnostic-only; it does not describe overall free-boundary solver maturity. You can disable that sampling with:
export VMEC_JAX_FREEB_SAMPLE_EXTERNAL=0
WP2 free-boundary runtime controls
Current free-boundary coupling defaults to the VMEC2000-aligned dense vacuum operator described in Free-Boundary Plan. The earlier lightweight spectral solve remains useful for historical context and debugging, but it is not the promoted parity path. To keep dense-coupling runtime bounded:
mgrid field tensors are cached by path in-process (avoids per-iteration NetCDF reloads),
Poisson spectral denominators are stage-static,
ivacskip reuses the prior potential (skips the solve) when ivac != 1.
Control flags:
export VMEC_JAX_FREEB_COUPLE_EDGE=1 # default: on
export VMEC_JAX_FREEB_SAMPLE_EXTERNAL=1 # default: on
If profiling free-boundary solver-only cost, disable sampling diagnostics:
export VMEC_JAX_FREEB_SAMPLE_EXTERNAL=0
Direct-coil CPU/GPU micro-benchmark snapshot
The direct-coil benchmark matrix isolates the new coil-provider free-boundary path from the broader example matrix:
python tools/benchmarks/bench_freeb_direct_coil_matrix.py \
--quick --include-gpu \
--out /tmp/freeb_matrix_office_gpu_followup/summary.json
On the 2026-05-28 office CPU/CUDA rerun, the matrix used concrete-platform
probing so CUDA rows were recorded even under JAX_PLATFORMS=cpu,cuda. The
best bounded direct-solve row was direct_solve_jit_forces. The warm solve
was still CPU-favorable (0.0525 s CPU versus 0.2346 s CUDA,
4.46x GPU/CPU), but the detailed timers show that the remaining GPU cost is
no longer the direct Biot-Savart field sample or the dense NESTOR solve:
| Timing bucket | CPU | CUDA | CUDA/CPU |
|---|---|---|---|
| setup | | | |
| residual metrics | | | |
| accepted-control fsq1 | | | |
| preconditioner apply | | | |
| force evaluation | | | |
| finalize | | | |
The next GPU optimization target is therefore accepted-control and preconditioner/update dispatch amortization plus reusable setup/precompute state, not Biot-Savart sampling or the final dense vacuum solve. Scalar-defer is deliberately not the default yet because the residual scalars still drive VMEC control flow, history, and free-boundary acceptance logic.
The follow-up instrumentation splits setup_total_s into static-grid
rebuild, free-boundary policy, boundary/profile construction, cache-key hashing,
ptau constants, mode-index constants, and update constants. Use those keys
in warm_solver_timing.timing or in the matrix cpu_gpu_comparison block
before implementing the next setup cache; the split is intentionally
measurement-only and does not change solver numerics.
When VMEC_JAX_TIMING=1 is enabled, the direct-coil solve benchmark also
emits a normalized phase_timing_comparison block. It ranks the warm setup,
force-evaluation, residual-metric, accepted-control fsq1, preconditioner,
update, and finalize buckets and reports cold-to-warm improvement for each
bucket. This is the preferred quick triage view for structural control-loop
staging and preconditioner/residual-scalar dispatch work because it is stable
even when the raw timing dictionary grows new sub-buckets.
When both CPU and GPU rows are present, the matrix summary also writes
gpu_bottleneck_summary. This is a short, sorted list of warm phases for
which GPU is slower than CPU, including the GPU/CPU ratio and absolute
gpu_minus_cpu_s. Use this field as the first-pass triage table before
opening detailed XLA or Perfetto traces.
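A sketch of how a GPU-slower phase list in the spirit of gpu_bottleneck_summary can be derived from two warm timing dictionaries. The phase names, the input layout and the sort order (largest absolute GPU excess first) are assumptions; the example numbers are the named warm phases from the 2026-05-28 fresh-clone rerun described later on this page, converted to seconds.

```python
def gpu_bottleneck_summary(cpu_warm, gpu_warm):
    """List warm phases where GPU is slower than CPU, largest excess first.

    cpu_warm / gpu_warm map a phase name to warm seconds (hypothetical layout).
    """
    rows = []
    for phase in cpu_warm.keys() & gpu_warm.keys():
        cpu_s, gpu_s = cpu_warm[phase], gpu_warm[phase]
        if gpu_s > cpu_s:
            rows.append({
                "phase": phase,
                "gpu_cpu_ratio": gpu_s / cpu_s if cpu_s > 0 else float("inf"),
                "gpu_minus_cpu_s": gpu_s - cpu_s,
            })
    return sorted(rows, key=lambda row: row["gpu_minus_cpu_s"], reverse=True)


# Named warm phases from the 2026-05-28 tiny direct-coil --jit-forces rerun (seconds).
cpu = {"setup": 0.0145, "force_evaluation": 0.0049,
       "residual_metrics": 0.00064, "preconditioner": 0.00055}
gpu = {"setup": 0.0202, "force_evaluation": 0.0100,
       "residual_metrics": 0.0238, "preconditioner": 0.0205}
summary = gpu_bottleneck_summary(cpu, gpu)
for row in summary:
    print(f"{row['phase']:<18} {row['gpu_cpu_ratio']:6.1f}x  +{row['gpu_minus_cpu_s'] * 1e3:.1f} ms")
```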
The first CUDA probe after adding that split reported a warm setup_total_s
of 40.4 ms on the tiny direct-coil case. The dominant setup sub-buckets
were boundary/profile construction (18.6 ms), update constants
(12.3 ms), and unattributed setup (5.9 ms). Cold setup was dominated
by compilation-adjacent boundary/profile and axis-reset work, so cold-start
analysis should be kept separate from warm accepted-point optimization timing.
Accelerator host-forward runs now use the existing NumPy row-enforcement setup
path for the initial state when VMEC_JAX_HOST_SETUP_ENFORCE=auto. This
avoids tiny eager device dispatches in setup without changing traced
differentiable solves; set the variable to 0 to disable the policy or 1
to force it in non-traced profiling runs. On the follow-up office CUDA probe,
warm time improved from 0.180 s to 0.169 s and
setup_update_constants_s dropped from 12.3 ms to 4.7 ms. The
remaining warm GPU targets are boundary/profile setup, residual scalar
materialization, accepted-control fsq1, and preconditioner dispatch.
The next pushed patch extends the host flux-profile fast path to concrete
default-APHI iota profiles, so the non-traced setup path no longer falls
back to small eager JAX profile work for common AI inputs. A local CPU
quick benchmark at that head reported the tiny direct-coil --jit-forces
warm solve at about 0.026 s with setup_boundary_profiles_s≈1.5 ms.
The matched office CPU/CUDA matrix still showed the GPU row as
CPU-favorable (0.0521 s CPU versus 0.2318 s CUDA), with force assembly
near parity (9.11 ms CPU versus 9.68 ms CUDA). This confirms the next
larger performance patch should cache or stage the non-traced setup/control
payload rather than continuing to optimize the direct coil field kernel.
The follow-up host-profile setup policy adds
VMEC_JAX_HOST_PROFILE_SETUP=auto (0 disables, 1 forces) so
non-traced accelerator setup can also use the host profile path. On the same
office quick matrix the tiny --jit-forces direct-coil row improved to
0.0552 s warm on CPU and 0.1625 s warm on CUDA (2.95x GPU/CPU).
The GPU setup/profile bucket dropped to 5.6 ms and force assembly stayed
slightly faster than CPU (8.28 ms CUDA versus 9.98 ms CPU). The
remaining named GPU buckets are residual scalar materialization
(18.6 ms), accepted-control fsq1 (12.8 ms), and preconditioner
dispatch (10.8 ms).
The next promoted accelerator-forward policy stages the preconditioned
fsq1 norm reductions on the host for non-traced accelerator solves:
VMEC_JAX_HOST_FSQ1_NORMS=auto (0 disables, 1 forces). The policy
is disabled for traced/AD solves, so exact-gradient paths keep the JAX-native
reduction. Later office policy-ablation runs showed that host-staging the
primary residual metrics was not robustly beneficial: the leaner device-product
path measured 0.181 s warm on CUDA versus 0.224 s for the old
host-residual default in the tiny direct-coil JIT-forces row. Therefore
VMEC_JAX_HOST_RESIDUAL_METRICS=auto now keeps residual products on device
and only materializes the three residual scalars; set the variable to 1 to
force the older six-scalar host-staging path for diagnostics.
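The auto/0/1 switches above share a pattern that can be sketched as a tri-state resolver. The function is hypothetical and fits policies such as VMEC_JAX_HOST_FSQ1_NORMS and VMEC_JAX_HOST_PROFILE_SETUP; it does not describe VMEC_JAX_HOST_RESIDUAL_METRICS, whose auto setting now keeps residual products on device. The rule that traced/AD solves always bypass host staging follows the fsq1 description above and is assumed to hold for the other switches.

```python
import os


def resolve_host_policy(var, on_accelerator, traced, env=None):
    """Hypothetical resolver for auto/0/1 host-staging switches.

    "0" disables, "1" forces, and "auto" (or unset) enables the host path only
    for non-traced accelerator runs. Traced/AD solves never use the host path,
    so exact gradients keep the JAX-native computation.
    """
    env = os.environ if env is None else env
    value = env.get(var, "auto").strip().lower()
    if traced or value == "0":
        return False
    if value == "1":
        return True
    return bool(on_accelerator)
```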
The production-like timing-light row, which disables detailed timing
synchronization, moved in the same direction: the tiny --jit-forces direct
coil row measured 0.0528 s warm on CPU and 0.1857 s warm on CUDA
(3.52x GPU/CPU). This confirms the policy is not merely improving a
diagnostic bucket, but the tiny case is still launch/synchronization dominated.
Do not claim GPU speedup for this row; use it as the regression target for the
next residual/preconditioner dispatch fusion pass.
A 2026-05-28 fresh-clone office rerun with the normalized
phase_timing_comparison confirmed the same target after the latest merge:
CPU warm time was 0.0424 s and CUDA warm time was 0.1466 s for the tiny
direct-coil --jit-forces row. The CPU named phases were setup
(14.5 ms), force evaluation (4.90 ms), residual metrics
(0.64 ms), and preconditioner (0.55 ms). The CUDA named phases were
residual metrics (23.8 ms), preconditioner (20.5 ms), setup
(20.2 ms), and force evaluation (10.0 ms). This keeps the next
structural performance lane focused on residual scalar synchronization,
preconditioner dispatch/application fusion, and reusable setup context; the
direct-coil field kernel itself is not the limiting bucket for this case.
A follow-up same-head office matrix on PR head 599668a added the
--include-timing-light row to separate production-like wall time from
timing synchronization. The tiny direct-coil --jit-forces row measured
0.0569 s warm on CPU and 0.1898 s warm on CUDA with detailed timing
enabled. The timing-light row measured 0.0719 s warm on CPU and
0.1462 s warm on CUDA. The detailed CUDA buckets again showed force
evaluation near CPU parity (10.7 ms versus 10.4 ms) while residual
metrics (20.1 ms), preconditioner dispatch/application (15.7 ms), and
setup (29.1 ms) dominate. The performance lane is therefore still
structural control-loop staging and preconditioner/residual-scalar fusion, not
Biot-Savart evaluation.
Two opt-in policies were checked and are deliberately not promoted. Allowing
the host-update path on accelerators with VMEC_JAX_HOST_UPDATE_ON_ACCELERATOR=1
made the tiny CUDA row slower in this matrix, and setting
VMEC_JAX_BADJAC_INITIAL_STATE_PROBE_ITERS=0 did not produce a robust
speedup after the accepted-control payload fusion. --include-timing-light
rows also showed that production-like wall time remains GPU-slower, so the
remaining work is structural: stage or fuse the VMEC-control residual scalar
materialization, accepted-control fsq1, and preconditioner dispatch.
The 2026-05-31 post-merge timing split added correctness-critical finalize
sub-buckets to the same benchmark matrix. A local CPU quick run with
--include-badjac-probe0 --include-timing-light reported a warm direct-coil
solve_total_s=0.1184 and finalize_s=0.00876; finalize split into
finalize_nestor_recompute_s=0.00651,
finalize_residual_recompute_s=0.00218, and
update_state_ready_s=4.25e-6. A fresh office CPU/CUDA clone at
bc00ff4 completed all 22 --include-gpu --include-policy-ablation rows.
The default CUDA direct solve remained GPU-slower (9.42x CPU), dominated by
non-JIT force evaluation (0.602 s), preconditioner work (0.369 s),
setup/axis reset (0.334 s), and final residual recompute (0.296 s).
With JIT forces enabled, force evaluation became faster than CPU (0.82x),
but the warm row was still 2.48x CPU because preconditioner refresh/apply
(0.0438 s, 2.55x CPU) and residual/control/finalization dispatch
remained exposed.
The policy-ablation rerun after the direct-coil replay-helper commits reached
the same conclusion. All CPU/GPU rows completed, but disabling host residual
metrics, host fsq1 norms, host profile setup, or all three together did not
produce a single promoted GPU fix. The tiny non-JIT-force direct solve remained
about 10.48x CPU, and JIT-force policy-ablation rows remained
2.65x–3.08x CPU. Treat the current benchmark evidence as a regression
target for controller/preconditioner/finalization staging, not as a GPU speedup
claim for the direct-coil free-boundary row.
The 2026-06-01 PR-head rerun at 8a0ce5c2 confirmed the bounded
preconditioner reuse patch removes the duplicate refresh after a
bcovar-seeded preconditioner build. A fresh office clone completed all
22 CPU/GPU quick rows and the timed direct rows reported
precond_refresh_seed_reuse_count=1. On CPU, the JIT-force direct solve
reported warm_min=0.0677 s and solve_total=0.0567 s. On CUDA, the
same row reported warm_min=0.183 s and solve_total=0.116 s:
JIT-force evaluation was faster than CPU (0.88x), finalization was
essentially parity (0.98x), and the remaining warm GPU penalties were the
seeded preconditioner refresh (2.01x), preconditioner apply (18.8x),
accepted-control scalar work (4.09x), and residual-metric synchronization
(43.8x). The best policy-ablation row was
direct_solve_jit_forces_host_policies_off at 2.43x CPU. The next
performance target is therefore not Biot-Savart or force assembly; it is
structural staging/fusion of the preconditioner apply and residual/control
scalar synchronization.
A 2026-06-05 follow-up policy-ablation matrix kept the same conclusion but
changed one default: host-staged residual metrics are no longer the accelerator
auto path. The tiny direct-coil JIT-forces row measured 0.224 s warm
on CUDA with the old host-residual policy, 0.181 s when residual products
were kept on device, and 0.250 s/0.230 s for the tridiagonal
precompute/lax-solve variants. VMEC_JAX_HOST_RESIDUAL_METRICS=1 remains
available as a diagnostic override, but the production target is now reducing
preconditioner/update/finalization launch overhead.
The same benchmark family now writes a final_recompute_guard block for
direct-solve children. This is deliberately a benchmark-only correctness
guard, not a solver shortcut: it records final accepted-state residuals,
pre-update residuals, their maximum absolute delta, final NESTOR metric deltas,
and the measured finalize recompute cost. A quick CPU probe after the guard
was added reported nonzero final/pre-update residual deltas, so final NESTOR
and final residual recomputes remain required by default. Any future cached
finalization path must first make this guard pass with parity evidence before
changing production behavior.
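A minimal sketch of the comparison at the heart of final_recompute_guard, assuming residuals are reported as named scalars. It omits the NESTOR metric deltas and recompute cost that the real block also records, and the function name and output keys are illustrative.

```python
def final_recompute_guard(final_residuals, pre_update_residuals, atol=0.0):
    """Hypothetical check: may a cached finalization replace the recompute?

    Compares named residual scalars (e.g. fsqr/fsqz/fsql) and reports the
    maximum absolute delta; any delta above atol means the final recompute
    is still required.
    """
    deltas = {
        k: abs(final_residuals[k] - pre_update_residuals[k])
        for k in final_residuals.keys() & pre_update_residuals.keys()
    }
    max_delta = max(deltas.values(), default=0.0)
    return {"max_abs_delta": max_delta, "recompute_required": max_delta > atol}
```

With the default atol=0.0 any nonzero delta, like the ones seen in the quick CPU probe, keeps the recompute on.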
Historical bundled example runtime/memory matrix (March 2026)
For repeatable runtime/memory sweeps across the bundled inputs, use:
python tools/diagnostics/example_runtime_memory_matrix.py \
--backend both \
--vmec-exec /path/to/xvmec2000
Recent artifacts from this tool:
outputs/example_runtime_memory_matrix_cpu_20260306/summary.json: all bundled examples on a reference CPU host, including VMEC2000 timings.
outputs/example_runtime_memory_matrix_gpu_20260306_summary.json: all bundled examples on a reference CUDA host (CUDA JAX).
outputs/example_runtime_memory_matrix_gpu_freeb_20260306_rerun_summary.json: corrected GPU rerun for the bundled free-boundary cases after staging bundled mgrid files in the benchmark clone.
Historical snapshot highlights from this March 2026 matrix:
Fixed-boundary scan performance on the reference GPU host improved materially after the accelerator-aware scan probe and quiet-scan chunking changes:
input.circular_tokamak now runs in about 13.8 s / 1.97 GiB.
input.LandremanPaul2021_QA_lowres now runs in about 33.9 s / 2.66 GiB.
input.up_down_asymmetric_tokamak now runs in about 16.5 s / 1.60 GiB.
input.basic_non_stellsym_pressure now runs in about 141.1 s / 3.68 GiB.
input.LandremanSenguptaPlunk_section5p3_low_res now runs in about 77.1 s / 2.13 GiB.
Fixed-boundary lasym=True cases on the reference CPU host remain:
input.up_down_asymmetric_tokamak: about 6.7 s / 0.89 GiB versus VMEC2000 about 0.74 s.
input.basic_non_stellsym_pressure: about 29.7 s / 3.22 GiB versus VMEC2000 about 2.02 s.
input.LandremanSenguptaPlunk_section5p3_low_res: about 46.8 s / 4.07 GiB versus VMEC2000 about 0.69 s.
Bundled free-boundary cases remain the dominant default-path outliers:
| Input | Reference CPU host | Reference GPU host | VMEC2000 |
|---|---|---|---|
| input.DIII-D_lasym_false | ~428.2 s / 7.36 GiB | ~1602.3 s / 6.23 GiB | ~14.4 s |
| input.cth_like_free_bdy | ~41.8 s / 1.64 GiB | ~155.8 s / 2.30 GiB | ~2.48 s |
| input.cth_like_free_bdy_lasym_small | ~37.6 s / 1.47 GiB | ~103.5 s / 1.97 GiB | ~0.63 s |
Recent parity-path free-boundary GPU work narrowed the large-ns force-kernel overhead:
deferring non-scan scalar-history materialization was effectively neutral on the smaller input.cth_like_free_bdy case (about 111.3 s warm on 70fc418 versus about 111.4 s warm on f35ce44),
passing only the free-boundary bsqvac edge slice into the force kernel, instead of rebuilding a mostly-zero (ns, ntheta, nzeta) array every iteration, materially improves the heavy axisymmetric case. On the reference GPU host, a parity-path max_iter=10 probe of input.DIII-D_lasym_false dropped:
compute_forces from about 5.79 s total (0.579 s/iter on 70fc418) to about 2.58 s total (0.258 s/iter on f35ce44),
preconditioner from about 0.675 s to about 0.324 s,
update from about 0.914 s to about 0.535 s.
The current GPU path is not yet a universal speedup:
same-host CPU/GPU benchmarking shows GPU is already faster on the heavier QA/QH reactor-scale cases, but not yet on the smallest axisymmetric cases,
the current mixed result is now backend-selection limited rather than convergence limited: all 16 bundled fixed-boundary cases converge on both CPU and GPU.
Why the GPU can still be slower than the CPU
This is a consequence of the current solver architecture, not a statement that VMEC-like equilibria are fundamentally better suited to CPUs. The short version is:
the fast path is the scan-lifted path, where JAX can keep long stretches of work on-device,
the parity path is still a host-controlled VMEC2000-style iteration,
many of the slowest benchmark rows are exactly those parity-path solves, especially free-boundary cases.
In more detail:
VMEC2000 parity requires a host-controlled nonlinear loop
The conservative path preserves VMEC2000-style semantics such as:
Garabedian time-step control,
Jacobian sign checks,
same-iteration restarts,
free-boundary ivac/ivacskip/nvacskip cadence,
per-iteration diagnostics and VMEC-style tables,
stage transitions and cache refresh rules.
In the current implementation, those decisions still happen in Python on the host. Each iteration therefore launches several short JAX kernels, waits for scalar decisions, then launches the next block. CPUs tolerate that control pattern much better than GPUs because the launch/synchronization cost is smaller.
The kernels are mostly moderate-size float64 kernels, not giant batched GPU kernels
For parity we run in float64, matching VMEC2000 numerics. On many of the shipped examples the per-iteration grids are only moderate in size, so the GPU never reaches the kind of occupancy that would amortize launch overhead. The work is also heavy in transforms, synthesis, and tensor assembly (bcovar, tomnsps, force kernels), which are often memory-traffic bound rather than one large dense GEMM.
The result is that the CPU can look surprisingly competitive, because it is executing the same float64 algebra with lower orchestration overhead and without paying for many small host->device transitions.
Free-boundary parity is the worst case for the current GPU stack
Free-boundary adds more than just one extra kernel. It adds:
external/vacuum field staging,
extra edge-force coupling,
free-boundary reuse/refresh cadence,
more restart-sensitive control flow,
larger edge/state tensors on some axisymmetric cases.
The timing probes in this repo show that compute_forces dominates on the current parity free-boundary GPU path. For example, on the reference GPU host:
input.cth_like_free_bdy with performance_mode=False spends about 0.278 s/iter in compute_forces, while preconditioning and update are much smaller.
input.DIII-D_lasym_false is even more sensitive to force-path data movement because of its large ns.
Data movement and edge-coupling details matter a lot on large free-boundary cases
Recent profiling made this explicit. Passing only the free-boundary bsqvac edge slice into the force kernel, instead of rebuilding a mostly zero (ns, ntheta, nzeta) array each iteration, was nearly neutral on the smaller input.cth_like_free_bdy case but materially improved the large axisymmetric case. On a parity-path max_iter=10 probe of input.DIII-D_lasym_false on the reference GPU host:
compute_forces dropped from about 0.579 s/iter to about 0.258 s/iter,
preconditioner dropped from about 0.067 s/iter to about 0.032 s/iter,
update dropped from about 0.091 s/iter to about 0.054 s/iter.
That is a good example of the current situation: the GPU is not losing because of the physics model itself, but because the parity path still contains control-flow and data-shaping patterns that are cheap on CPU and expensive on GPU.
Compilation and warmup amplify the gap on short runs
JAX/XLA compile cost is front-loaded. On short solves, or on runs that only execute a small number of iterations per stage, compile and cache warmup can dominate the wall time. This hurts accelerator results more than CPU results because the GPU path has higher startup overhead and stricter sensitivity to retracing/recompilation.
Differentiability and parity constraints limit aggressive GPU-only shortcuts
vmec-jax is not trying to be a separate non-parity GPU solver. We are preserving:
end-to-end differentiability,
VMEC2000-compatible iteration behavior where parity is required,
VMEC-style outputs and diagnostics.
That rules out some easy GPU wins that would change ordering, skip diagnostics, or replace the parity controller with a different nonlinear algorithm. The current performance work is therefore focused on moving more of the existing algorithm into longer device-resident regions without changing the numerical contract.
When the GPU already helps today
The GPU story is already much better when the solve can remain on the fast scan path, or when repeated runs can amortize compile cost. That is why the fixed-boundary scan cases improved materially after:
accelerator-aware scan probing,
larger quiet-scan chunks,
reduced launch overhead in the scan path.
The next large gains on GPU are therefore expected to come from the same direction on the parity/free-boundary side and the accepted-point replay path: keeping more of the force/residual/control pipeline on-device for longer stretches, and reducing per-iteration host orchestration.
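The following minimal sketch shows the "longer device-resident regions" idea. A fixed-length chunk of iterations runs inside one jitted lax.scan, and the host checks convergence only between chunks. relax_step, run_chunk, and solve are hypothetical toy names, not vmec_jax functions, and the damped update is a placeholder for a real solver step.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def relax_step(x, target):
    # Toy damped update standing in for one solver iteration.
    return x + 0.5 * (target - x)


@partial(jax.jit, static_argnames="chunk")
def run_chunk(x, target, chunk):
    """Run `chunk` iterations on device; return the state and the last residual."""
    def body(carry, _):
        new = relax_step(carry, target)
        return new, jnp.max(jnp.abs(new - target))

    x, residuals = lax.scan(body, x, None, length=chunk)
    return x, residuals[-1]


def solve(x0, target, chunk=32, tol=1e-6, max_chunks=50):
    """Host loop that synchronizes once per chunk instead of once per iteration."""
    x = x0
    for n in range(max_chunks):
        x, res = run_chunk(x, target, chunk=chunk)
        if float(res) < tol:  # the only host <-> device sync point
            return x, (n + 1) * chunk
    return x, max_chunks * chunk
```

The trade-off is that up to chunk - 1 iterations can run past the convergence point. A solver that must match VMEC2000 iteration counts would need to mask or replay those steps, which this sketch does not attempt.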
Fused tridiagonal solver (scan only)¶
The scan preconditioner can use XLA’s fused tridiagonal solver with
pretransposed coefficients (dl/d/du). Current bounded profiles show this
is the fastest measured CPU scan path. GPU exact-optimization trial callbacks
now default to the non-scan loop; when scan is forced for GPU diagnostics, the
older scan preconditioner with precomputed Thomas coefficients remains the
measured faster scan variant. The explicit environment overrides remain
available for parity/performance bisection:
VMEC_JAX_SCAN_PRECOND_LAXTRIDI=0: force the older Thomas scan path.
VMEC_JAX_SCAN_PRECOND_LAXTRIDI=1: explicitly request the fused scan path.
VMEC_JAX_SCAN_PRECOND_PRECOMPUTE=0: disable the GPU scan precomputed Thomas-coefficient default.
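The sketch below contrasts the two kinds of scan preconditioner on a toy system. One is a Thomas-algorithm solve written as two lax.scan sweeps. The other is JAX's fused jax.lax.linalg.tridiagonal_solve, using the same dl/d/du naming. thomas_solve is a hypothetical helper. vmec_jax's pretransposed coefficient layout and its precomputed Thomas coefficients are not reproduced here.

```python
import jax.numpy as jnp
from jax import lax
from jax.lax.linalg import tridiagonal_solve


def thomas_solve(dl, d, du, rhs):
    """Thomas algorithm via lax.scan.

    dl[i] = A[i, i-1] (dl[0] unused), d[i] = A[i, i], du[i] = A[i, i+1]
    (du[-1] unused); rhs has shape (m, k).
    """
    a = dl.at[0].set(0.0)  # makes the first forward step reduce to 1 / d[0]

    def forward(carry, row):
        c_prev, r_prev = carry
        a_i, b_i, c_i, r_i = row
        denom = b_i - a_i * c_prev
        c_new, r_new = c_i / denom, (r_i - a_i * r_prev) / denom
        return (c_new, r_new), (c_new, r_new)

    init = (jnp.zeros((), d.dtype), jnp.zeros_like(rhs[0]))
    _, (cp, rp) = lax.scan(forward, init, (a, d, du, rhs))

    def backward(x_next, row):
        c_i, r_i = row
        x_i = r_i - c_i * x_next
        return x_i, x_i

    # reverse=True walks from the last row up; outputs keep the original row order.
    _, x = lax.scan(backward, jnp.zeros_like(rhs[0]), (cp, rp), reverse=True)
    return x


m = 8
dl = jnp.concatenate([jnp.zeros(1), -jnp.ones(m - 1)])  # dl[0] = 0 as XLA expects
d = 4.0 * jnp.ones(m)
du = jnp.concatenate([-jnp.ones(m - 1), jnp.zeros(1)])  # du[-1] = 0
rhs = jnp.arange(2 * m, dtype=jnp.float32).reshape(m, 2)

x_thomas = thomas_solve(dl, d, du, rhs)
x_fused = tridiagonal_solve(dl, d, du, rhs)  # one fused XLA call
```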
Boundary decomposition cache + JAX-friendly initial guess¶
boundary_from_indata now caches the boundary decomposition across runs
using the input file path + mtime (or a coefficient fingerprint when the path
is unavailable). This trims repeated host work in workflows that solve the same
input file multiple times in a single process.
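Here is a minimal sketch of the caching rule described above. It keys on (absolute path, mtime) when the input file exists and falls back to a hash of the coefficients otherwise. _BOUNDARY_CACHE, _cache_key, and cached_decomposition are hypothetical names, and this is not the implementation of boundary_from_indata.

```python
import hashlib
import os

import numpy as np

_BOUNDARY_CACHE = {}  # hypothetical process-wide cache


def _cache_key(path=None, coeffs=None):
    """Key on (abs path, mtime) when the file exists, else on a coefficient hash."""
    if path is not None and os.path.exists(path):
        return ("path", os.path.abspath(path), os.stat(path).st_mtime_ns)
    if coeffs is None:
        raise ValueError("need an existing path or coefficient array")
    arr = np.ascontiguousarray(coeffs, dtype=np.float64)
    digest = hashlib.sha256(arr.tobytes() + str(arr.shape).encode()).hexdigest()
    return ("coeffs", digest)


def cached_decomposition(compute, path=None, coeffs=None):
    """Run the zero-argument `compute` only on a cache miss."""
    key = _cache_key(path, coeffs)
    if key not in _BOUNDARY_CACHE:
        _BOUNDARY_CACHE[key] = compute()
    return _BOUNDARY_CACHE[key]
```

Editing or touching the input file changes its mtime, so the next call recomputes instead of returning a stale decomposition.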
The initial-guess path also supports a fully JAX-backed boundary flip and constraint application, which reduces Python-side overhead and keeps the path JAX-friendly for future JIT staging. Control this with:
VMEC_JAX_INIT_GUESS_JAX=1 (default): use the JAX boundary flip path.
VMEC_JAX_INIT_GUESS_JAX=0: fall back to NumPy/Python boundary flips.
Implementation map (performance-critical paths)¶
vmec_jax/vmec_tomnsp.py: VMEC fixaray tables + DFT-based tomnsps.
vmec_jax/init_guess.py: initial guess, axis blending, JAX boundary flip.
vmec_jax/boundary.py: input boundary decomposition + cache.
vmec_jax/static.py: cached grids, phase stacks, and per-solve constants.
Recent profiling snapshot (QA, 3 iterations on CPU)¶
Default loop: ~0.26s total wall time (post-warmup).
Scan loop: ~0.083s total wall time (~3x faster for this short run).
Longer runs benefit more because Python control-flow overhead scales with the iteration count in the non-scan path.
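The gap between the two loops can be reproduced with a toy step function. The Python loop pays one dispatch plus host control flow per iteration. The scan version compiles all iterations into a single XLA loop and dispatches once. step, run_python_loop, and run_scan are hypothetical names, not the vmec_jax controllers.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def step(x):
    # Toy iteration standing in for one solver update.
    return x - 0.1 * jnp.tanh(x)


def run_python_loop(x, n_iter):
    # One dispatch (plus Python overhead) per iteration.
    for _ in range(n_iter):
        x = step(x)
    return x


@partial(jax.jit, static_argnames="n_iter")
def run_scan(x, n_iter):
    # All iterations compiled into one XLA loop; a single dispatch in total.
    def body(carry, _):
        return step(carry), None

    x, _ = lax.scan(body, x, None, length=n_iter)
    return x
```

Both paths compute the same iterates. Only the per-iteration host overhead differs, so the saving grows with n_iter.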
VMEC++ bad-progress restarts (optional)¶
VMEC++ introduces a “bad progress” restart policy that detects large residuals
on refined grids and restarts the time-step controller more aggressively. This
is now available in vmec_jax behind an explicit flag so the VMEC2000 parity
path remains unchanged by default.
The trigger follows the VMEC++ criteria:
iter2 - iter1 > k_preconditioner_update_interval / 2
iter2 > 2 * k_preconditioner_update_interval
fsqr + fsqz > 1e-2 (physical residual on the full grid)
When triggered, the restart path multiplies delt by 1/1.03 (the VMEC++
"bad progress" factor) and resets the cached preconditioner state.
Enable it with:
run_fixed_boundary(..., vmecpp_restart=True)
Note: the VMEC++ restart flag is currently wired to the VMEC2000-control path. When scan is active, it takes effect on fallback segments that execute in the non-scan parity controller.
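This sketch of the trigger and restart assumes that all three criteria must hold at the same time. The list above does not say how they combine. The default k_interval=25 is an illustrative value only. RestartState, bad_progress, and maybe_restart are hypothetical names, not the vmec_jax or VMEC++ code.

```python
from dataclasses import dataclass

BAD_PROGRESS_FACTOR = 1.03  # delt is multiplied by 1/1.03 on a restart


@dataclass
class RestartState:
    delt: float
    preconditioner_cache: object = None  # hypothetical cached preconditioner


def bad_progress(iter1, iter2, fsqr, fsqz, k_interval):
    """VMEC++-style trigger (assumes the three criteria are combined with AND)."""
    return (
        iter2 - iter1 > k_interval / 2
        and iter2 > 2 * k_interval
        and fsqr + fsqz > 1e-2
    )


def maybe_restart(state, iter1, iter2, fsqr, fsqz, k_interval=25):
    """Shrink delt and drop the cached preconditioner when the trigger fires."""
    if bad_progress(iter1, iter2, fsqr, fsqz, k_interval):
        return RestartState(delt=state.delt / BAD_PROGRESS_FACTOR), True
    return state, False
```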
Static precomputation¶
Use VMECStatic to avoid rebuilding:
mode tables,
angle grids,
Fourier basis tensors,
radial grid.
VMEC phase-stack cache¶
The VMEC real-space synthesis path builds full (m,n) phase tables from the
fixaray trig tables. This is correct but expensive to repeat inside the JIT
kernel. VMECStatic now precomputes and caches stacked phase tensors for the
VMEC grid (including dtheta/dzeta variants) and attaches them to the
cached trig tables. The precompute uses NumPy on the host to avoid extra JAX
compilation work. This reduces both runtime and compilation work because the
kernel no longer rebuilds the phase tables from scratch every iteration.
Control this behavior with:
VMEC_JAX_CACHE_VMEC_PHASE=1 (default): precompute phase stacks in build_static for fastest execution.
VMEC_JAX_CACHE_VMEC_PHASE=0: skip the extra cached tensors to save memory.
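Here is a minimal sketch of the host-side precompute. It builds the cos/sin phase stacks once with NumPy and passes them to a jitted synthesis, so the traced kernel only performs the contraction. build_phase_stack and synthesize_r are hypothetical names. The simple full-period grid ignores VMEC's fixaray normalization and grid conventions.

```python
import jax
import jax.numpy as jnp
import numpy as np


def build_phase_stack(xm, xn, ntheta, nzeta, nfp):
    """Host (NumPy) precompute of cos/sin(m*theta - n*nfp*zeta).

    Returns shape (2, K, ntheta, nzeta): [cos stack, sin stack].
    """
    theta = 2 * np.pi * np.arange(ntheta) / ntheta
    zeta = 2 * np.pi * np.arange(nzeta) / (nzeta * nfp)  # one field period
    arg = (xm[:, None, None] * theta[None, :, None]
           - xn[:, None, None] * nfp * zeta[None, None, :])
    return np.stack([np.cos(arg), np.sin(arg)]).astype(np.float32)


@jax.jit
def synthesize_r(rmnc, phase_stack):
    # R(theta, zeta) = sum_k rmnc[k] * cos(m_k theta - n_k nfp zeta)
    return jnp.einsum("k,kij->ij", rmnc, phase_stack[0])
```

The trade-off matches the flag above: the stacks cost 2*K*ntheta*nzeta floats of memory in exchange for not rebuilding them in every iteration.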
Compilation cache¶
JAX can persist compiled executables to disk. vmec_jax enables this
automatically for accelerator-selected runs because cold-start compilation
dominates short fixed-boundary and optimization diagnostics, especially on GPU.
CPU cache use is explicit opt-in with VMEC_JAX_COMPILATION_CACHE=1 because
XLA:CPU AOT cache hits can emit host-feature mismatch errors on some JAX
versions. By default, the cache directory is machine/CPU-feature/Python/JAX
scoped so shared home directories do not reuse incompatible CPU AOT artifacts
across different hosts or runtime versions. Use
VMEC_JAX_COMPILATION_CACHE_DIR=/path/to/cache (or the upstream
JAX_COMPILATION_CACHE_DIR) to choose the cache location, or set
VMEC_JAX_COMPILATION_CACHE=0 to disable it.
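The cache policy described above can be sketched as a pure function. Caching is on by default on accelerators and opt-in on CPU (either =1 or an explicit directory), =0 always disables it, and the default directory carries a machine/runtime fingerprint. resolve_cache and machine_fingerprint are hypothetical, and the fingerprint format here is an assumption, not vmec_jax's actual suffix.

```python
import hashlib
import os
import platform
import sys
from importlib import metadata


def _version(dist):
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return "unknown"


def machine_fingerprint():
    """Hash of host CPU details plus Python/JAX/jaxlib versions (illustrative)."""
    parts = [platform.machine(), platform.processor(), sys.version,
             _version("jax"), _version("jaxlib")]
    return hashlib.sha256("|".join(parts).encode()).hexdigest()[:16]


def resolve_cache(env, accelerator):
    """Return the cache directory to use, or None when caching is disabled."""
    flag = env.get("VMEC_JAX_COMPILATION_CACHE")
    explicit_dir = (env.get("VMEC_JAX_COMPILATION_CACHE_DIR")
                    or env.get("JAX_COMPILATION_CACHE_DIR"))
    if flag == "0":
        return None
    if not accelerator and flag != "1" and not explicit_dir:
        return None  # CPU: opt-in only
    if explicit_dir:
        return explicit_dir
    return os.path.join(os.path.expanduser("~"), ".cache", "vmec_jax",
                        "jax_cache", machine_fingerprint())
```

A resolved directory would then be handed to JAX with jax.config.update("jax_compilation_cache_dir", path).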
If XLA prints a message such as Loading XLA:CPU AOT result followed by a
target-machine feature mismatch, it means an existing persistent-cache entry was
compiled for another CPU. Do not suppress that message blindly: clear the old
cache directory or rerun once with VMEC_JAX_COMPILATION_CACHE=0. Current
vmec_jax defaults avoid enabling the CPU persistent cache unless the user
opts in explicitly or provides a cache directory.
CLI profiling (pre-iteration overhead)¶
To capture a JAX trace for the VMEC2000-style CLI path, set
VMEC_JAX_PROFILE_DIR before invoking vmec. By default the CLI also
emits a Perfetto-compatible trace (perfetto_trace.json.gz); disable that
extra file by setting VMEC_JAX_PROFILE_PERFETTO=0. The trace is written in
TensorBoard/Chrome trace format:
VMEC_JAX_PROFILE_DIR=/tmp/vmec_jax_trace \
vmec examples/data/input.ITERModel --max-iter 3 --no-multigrid --no-use-input-niter --quiet
For tighter windows (e.g., pre-iteration or iter-1 only), set
VMEC_JAX_PROFILE_WINDOW=pre (or iter1 / iterN) and optionally start
a profiler server for XProf inspection:
VMEC_JAX_PROFILE_DIR=/tmp/vmec_jax_trace \
VMEC_JAX_PROFILE_WINDOW=pre \
VMEC_JAX_PROFILE_SERVER=1 VMEC_JAX_PROFILE_SERVER_PORT=9999 \
vmec examples/data/input.ITERModel --max-iter 3 --no-multigrid --quiet
With VMEC_JAX_PROFILE_SERVER=1 you can also capture a tight window using
python -m jax.collect_profile from another terminal (see the JAX profiling
guide for the exact invocation).
Recent traces show that the pre-iteration time is dominated by JIT
compilation/cache misses (pjit cache_miss + backend compile) rather than
the nonlinear iteration itself. This is expected for short runs on CPU.
For repeated runs, the persistent compilation cache can significantly reduce
this overhead once it is warm. It is on by default for accelerator-selected
runs; CPU runs need the explicit opt-in VMEC_JAX_COMPILATION_CACHE=1.
Explicit solver_device="gpu" requests enable the same accelerator cache policy
even when no JAX GPU environment variable was set before import. Set
VMEC_JAX_COMPILATION_CACHE=0 to disable it.
Persistent compilation cache tuning¶
JAX’s persistent cache can be made more aggressive via vmec_jax environment
variables:
VMEC_JAX_CACHE_MIN_COMPILE_TIME_SECS (default: 0)
VMEC_JAX_CACHE_MIN_ENTRY_SIZE_BYTES (default: -1)
VMEC_JAX_COMPILATION_CACHE_MAX_SIZE (optional)
These map to JAX’s persistent cache configuration and allow caching more (or
fewer) compiled executables to reduce repeat-start latency for stable shapes.
Enable cache-miss diagnostics by setting
VMEC_JAX_EXPLAIN_CACHE_MISSES=1; JAX will log a short summary whenever a
cache miss triggers a compilation.
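The following sketch shows how such variables could be translated into JAX configuration options. The JAX option names are real JAX config flags. The variable-to-option mapping and the helper names are assumptions for illustration, not read from vmec_jax.

```python
import os

import jax

# Hypothetical mapping from the vmec_jax variables above to JAX config options.
_ENV_TO_JAX = {
    "VMEC_JAX_CACHE_MIN_COMPILE_TIME_SECS":
        ("jax_persistent_cache_min_compile_time_secs", float),
    "VMEC_JAX_CACHE_MIN_ENTRY_SIZE_BYTES":
        ("jax_persistent_cache_min_entry_size_bytes", int),
    "VMEC_JAX_COMPILATION_CACHE_MAX_SIZE":
        ("jax_compilation_cache_max_size", int),
    "VMEC_JAX_EXPLAIN_CACHE_MISSES":
        ("jax_explain_cache_misses", lambda v: v == "1"),
}


def cache_tuning_options(env=os.environ):
    """Translate the environment variables that are set into {jax_option: value}."""
    options = {}
    for var, (option, convert) in _ENV_TO_JAX.items():
        if var in env:
            options[option] = convert(env[var])
    return options


def apply_options(options):
    # Must run before the first compilation for the cache settings to matter.
    for option, value in options.items():
        jax.config.update(option, value)
```

A minimum compile time of 0 with a minimum entry size of -1 caches essentially every executable. This favours repeat-start latency over disk usage.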
Batched radial smoothing¶
The scan path now batches the radial tridiagonal smoother across the R/Z components (and separately for lambda), so the solver does fewer tridiagonal solves per iteration. This reduces kernel count and Python overhead while preserving the VMEC update math.
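Here is a minimal sketch of the batching. It assumes R and Z each have their own radial coefficient rows, which are stacked along a leading batch axis so that jax.lax.linalg.tridiagonal_solve handles both in one call. smooth_rz_batched is a hypothetical name, and lambda would be a separate call as the text notes.

```python
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve


def smooth_rz_batched(dl, d, du, rhs_r, rhs_z):
    """One batched tridiagonal solve for R and Z.

    dl, d, du: shape (2, ns), radial coefficients for R (row 0) and Z (row 1),
    with dl[:, 0] = 0 and du[:, -1] = 0; rhs_r, rhs_z: shape (ns, nmodes).
    """
    rhs = jnp.stack([rhs_r, rhs_z])        # (2, ns, nmodes)
    x = tridiagonal_solve(dl, d, du, rhs)  # leading axis is a batch dimension
    return x[0], x[1]
```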
Batched VMEC real-space synthesis¶
The VMEC-grid synthesis path now batches base + derivative (dtheta/dzeta)
evaluations into a single stacked einsum call. This reduces kernel count in
the pre-iteration setup (especially the bcovar/realspace pipeline) while
preserving the original algebra and parity outputs.
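The sketch below stacks the base table with its analytic theta and zeta derivatives, so one einsum returns R, dR/dtheta, and dR/dzeta together. It uses d/dtheta cos(m*theta - n*zeta) = -m sin(...) and d/dzeta cos(...) = n sin(...). phase_tables and synthesize_batched are hypothetical names, and VMEC's scaling factors are omitted.

```python
import jax.numpy as jnp
import numpy as np


def phase_tables(xm, xn, theta, zeta):
    """Stack [base, d/dtheta, d/dzeta] tables for cos(m*theta - n*zeta)."""
    arg = (xm[:, None, None] * theta[None, :, None]
           - xn[:, None, None] * zeta[None, None, :])
    cos, sin = np.cos(arg), np.sin(arg)
    return np.stack([
        cos,                           # R
        -xm[:, None, None] * sin,      # dR/dtheta
        xn[:, None, None] * sin,       # dR/dzeta
    ]).astype(np.float32)              # (3, K, ntheta, nzeta)


def synthesize_batched(rmnc, tables):
    # One contraction yields R, R_theta, R_zeta: shape (3, ntheta, nzeta).
    return jnp.einsum("k,skij->sij", rmnc, tables)
```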
Vectorized multigrid conversion¶
Multigrid staging now uses the vectorized signed↔(m,n) conversion helpers from
vmec_parity instead of Python loops. In the current path the signed→(m,n)
conversion uses precomputed dense maps (matmul) to avoid repeated gather-heavy
indexing. This trims host-side overhead during grid transitions, which shows up
prominently in short profiling traces.
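To illustrate a precomputed dense map, take a single m > 0 and use the identity cos(m*t - n*z) = cos(m*t)cos(n*z) + sin(m*t)sin(n*z). The signed coefficients c[-N..N] then fold into n >= 0 blocks via cc[n] = c[n] + c[-n] and ss[n] = c[n] - c[-n]. This simplified convention is an assumption: vmec_jax's actual layout, scaling, and m = 0 handling may differ, and the function names are hypothetical.

```python
import jax.numpy as jnp
import numpy as np


def signed_to_mn_maps(ntor):
    """Dense maps from signed n in [-ntor, ntor] to n >= 0 (cc, ss) blocks.

    Signed index j corresponds to n = j - ntor (fixed m > 0).
    """
    nsigned = 2 * ntor + 1
    to_cc = np.zeros((ntor + 1, nsigned))
    to_ss = np.zeros((ntor + 1, nsigned))
    to_cc[0, ntor] = 1.0  # n = 0 has no sin(n z) partner
    for n in range(1, ntor + 1):
        pos, neg = ntor + n, ntor - n
        to_cc[n, pos], to_cc[n, neg] = 1.0, 1.0
        to_ss[n, pos], to_ss[n, neg] = 1.0, -1.0
    return jnp.asarray(to_cc), jnp.asarray(to_ss)


def signed_to_mn(c_signed, to_cc, to_ss):
    # Two small matmuls with maps built once, instead of gather-heavy indexing.
    return to_cc @ c_signed, to_ss @ c_signed
```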
Multigrid interpolation caches¶
Radial interpolation now caches the (j1,j2,xint) weights and scalxc
profiles for reuse across multigrid stages. This reduces host-side setup costs
when multiple grids are visited in a single solve.
Precomputed (m,n)→signed maps¶
The fixed-boundary update now builds dense mapping matrices once per solve to
convert (m,n>=0) force blocks into signed Fourier updates via matmul. This
reduces scatter-heavy updates inside the iteration loop and keeps the JIT graph
more regular.
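The opposite direction can also be written as one dense matrix. Under the same simplified single-m convention (an assumption), c[+n] = (cc[n] + ss[n]) / 2 and c[-n] = (cc[n] - ss[n]) / 2 for n > 0, and c[0] = cc[0]. The update then becomes a single matmul rather than per-mode scatter-adds. mn_to_signed_map and apply_update are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np


def mn_to_signed_map(ntor):
    """Dense (2*ntor+1, 2*(ntor+1)) matrix mapping [cc; ss] (n >= 0) to signed n.

    Signed index j corresponds to n = j - ntor.
    """
    nb = ntor + 1
    mat = np.zeros((2 * ntor + 1, 2 * nb))
    mat[ntor, 0] = 1.0
    for n in range(1, nb):
        mat[ntor + n, n], mat[ntor + n, nb + n] = 0.5, 0.5    # c[+n]
        mat[ntor - n, n], mat[ntor - n, nb + n] = 0.5, -0.5   # c[-n]
    return jnp.asarray(mat)


@jax.jit
def apply_update(c_signed, cc, ss, to_signed):
    # One matmul replaces per-mode scatter updates such as c.at[idx].add(...).
    return c_signed + to_signed @ jnp.concatenate([cc, ss])
```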
Batched sin conversions¶
The scan update now batches the Z/L (m,n) sin-block conversions into a
single matmul-based mapping, reducing kernel count compared to converting each
field independently.
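The batching trick works whenever two fields share the same mapping matrix: concatenate them along the column axis, do one matmul, and split the result. Here to_signed stands for any precomputed dense map. The function names and shapes are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def convert_separately(to_signed, z_blocks, l_blocks):
    return to_signed @ z_blocks, to_signed @ l_blocks


@jax.jit
def convert_batched(to_signed, z_blocks, l_blocks):
    # Concatenate Z and L column-wise so a single matmul converts both fields.
    ncol = z_blocks.shape[1]
    both = to_signed @ jnp.concatenate([z_blocks, l_blocks], axis=1)
    return both[:, :ncol], both[:, ncol:]
```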
Scatter-free boundary/axis enforcement¶
The fixed-boundary/axis enforcement step now uses concatenation instead of
scatter updates for the edge and axis rows. This trims scatter-heavy kernels in
the scan loop without changing the VMEC constraints.
Axis m=0 masks are now reused from VMECStatic to avoid per-iteration mask
construction.
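In the sketch below, radial rows are the leading axis. The edge row is set to the boundary coefficients, and only the m = 0 entries survive on the axis. The scatter version and the concatenate version produce the same array. The layout, the precomputed axis_m0_mask argument, and the function names are assumptions, and the sketch assumes ns >= 2.

```python
import jax
import jax.numpy as jnp


def enforce_scatter(x, boundary, axis_m0_mask):
    # Reference version: two scatter updates.
    x = x.at[-1].set(boundary)
    return x.at[0].set(jnp.where(axis_m0_mask, x[0], 0.0))


@jax.jit
def enforce_concat(x, boundary, axis_m0_mask):
    # Scatter-free version: rebuild the array from its rows.
    axis_row = jnp.where(axis_m0_mask, x[0], 0.0)[None, :]
    return jnp.concatenate([axis_row, x[1:-1], boundary[None, :]], axis=0)
```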
Lambda gauge masking¶
The (m,n)=(0,0) lambda gauge constraint now uses a boolean mask instead of a scatter update, trimming another small scatter kernel from the iteration loop.
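This sketch assumes a (..., m, signed n) layout with n = 0 stored at column ntor. The mask is built once, outside the loop, and each iteration uses an elementwise select instead of a scatter. make_gauge_mask and apply_gauge are hypothetical names.

```python
import jax.numpy as jnp


def make_gauge_mask(mpol, ntor):
    """Static boolean mask that is False only at (m, n) = (0, 0)."""
    mask = jnp.ones((mpol, 2 * ntor + 1), dtype=bool)
    return mask.at[0, ntor].set(False)  # one-time setup, not per iteration


def apply_gauge(lmns, gauge_mask):
    # Elementwise select; the (mpol, 2*ntor+1) mask broadcasts over radial rows.
    return jnp.where(gauge_mask, lmns, 0.0)
```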
Concatenation-based mode updates¶
The current optimized branch also replaces several remaining scatter-style fixed-boundary hot-path updates with concatenation-based helpers:
zeroing the lambda gauge coefficient column,
replacing or scaling a single m-mode slice in the preconditioned force blocks,
and enforcing axis/edge constraints on the R/Z/L state fields.
This keeps the same constraints and algebra, but reduces the number of
scatter-heavy kernels created inside the iteration loop. On the representative
accelerated reactor-scale QA case
(input.LandremanPaul2021_QA_reactorScale_lowres), the warmed CPU runtime
improved from about 12.00s to about 11.04s on the same host. The same
pass also nudged the representative free-boundary warmed CPU benchmark
(input.cth_like_free_bdy) from about 3.11s to about 3.07s while
keeping the full test suite green.
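A concatenation-based helper for scaling (or, with factor=0, zeroing) a single m-mode slice can be sketched as follows. m is a static argument so that the slice bounds are known at trace time. scale_m_slice is a hypothetical name, and the (..., m, n) layout is an assumption.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="m")
def scale_m_slice(blocks, m, factor):
    """Scale blocks[..., m, :] via slicing + concatenate instead of a scatter."""
    before = blocks[..., :m, :]
    target = blocks[..., m:m + 1, :] * factor
    after = blocks[..., m + 1:, :]
    return jnp.concatenate([before, target, after], axis=-2)
```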
Skip asymmetric updates when lasym=False¶
Several fixed/free-boundary update paths previously computed the asymmetric
Rsin/Zcos/Lcos signed-coefficient updates even for lasym=False cases,
then immediately zeroed them. The current branch now skips those conversions
entirely on the symmetric path.
This is a clean win for the shipped lasym=False workloads:
input.cth_like_free_bdy improved from about 3.07s to about 2.91s on the same warmed CPU benchmark,
input.DIII-D_lasym_false with max_iter=20 was effectively unchanged at about 0.262s (neutral, not worse),
and the representative lasym=True fixed-boundary smoke (input.basic_non_stellsym_pressure) stayed on the same runtime and convergence profile because the asymmetric path is still fully active there.
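One way to skip the work entirely, rather than computing and then zeroing, is to make lasym a static argument and branch in Python. The symmetric trace then never contains the asymmetric conversion, and each lasym value is compiled separately. asym_update is a toy placeholder for the Rsin/Zcos/Lcos conversions, and the names are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


def asym_update(force):
    # Toy stand-in for the Rsin/Zcos/Lcos signed-coefficient conversions.
    return jnp.cumsum(force, axis=-1)


@partial(jax.jit, static_argnames="lasym")
def update(state, force, lasym):
    sym = state["sym"] + force
    if lasym:  # Python branch, resolved at trace time
        asym = state["asym"] + asym_update(force)
    else:
        asym = state["asym"]  # conversion never traced on the symmetric path
    return {"sym": sym, "asym": asym}
```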
Vectorized axis blending¶
Initial-guess axis blending now updates all m=0 Fourier columns in one
vectorized scatter instead of looping over toroidal modes. This reduces
index-heavy overhead during startup.
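The loop-versus-vectorized difference can be illustrated with a simple linear-in-s blend between axis and boundary m = 0 coefficients. The blend formula and the rmnc[ns, m, n] layout are assumptions for illustration, not the vmec_jax initial guess. Both functions return the same array, but the second issues one scatter instead of one per toroidal mode.

```python
import jax.numpy as jnp


def blend_axis_loop(rmnc, raxis, rbdy_m0, s):
    # Reference: one scatter per toroidal mode n (layout rmnc[ns, m, n]).
    for n in range(raxis.shape[0]):
        col = (1.0 - s) * raxis[n] + s * rbdy_m0[n]
        rmnc = rmnc.at[:, 0, n].set(col)
    return rmnc


def blend_axis_vectorized(rmnc, raxis, rbdy_m0, s):
    # All m = 0 columns at once: (ns, 1) * (1, n) broadcasts to (ns, n).
    cols = (1.0 - s)[:, None] * raxis[None, :] + s[:, None] * rbdy_m0[None, :]
    return rmnc.at[:, 0, :].set(cols)
```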
Cached mode scaling¶
VMECStatic now caches the per-mode internal scaling factors
1/(mscale*nscale) so initial-guess construction avoids repeated gathers
from the trig tables.
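This sketch precomputes the 1/(mscale*nscale) table once. It assumes the VMEC-style convention mscale[0] = nscale[0] = 1 and sqrt(2) for nonzero indices, which may not match every vmec_jax detail. mode_scaling is a hypothetical name.

```python
import numpy as np


def mode_scaling(mpol, ntor):
    """Precompute 1 / (mscale[m] * nscale[n]) as an (mpol, ntor + 1) table."""
    mscale = np.where(np.arange(mpol) == 0, 1.0, np.sqrt(2.0))
    nscale = np.where(np.arange(ntor + 1) == 0, 1.0, np.sqrt(2.0))
    return 1.0 / (mscale[:, None] * nscale[None, :])
```

Built once per static setup, the table turns per-mode scaling into one broadcasted multiply instead of repeated gathers from the trig tables.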
Historical free-boundary cache note¶
A May 2026 free-boundary cleanup pass cached more of the static external-boundary sampling setup on the host:
phase-independent second-derivative mode factors,
angular phi grids for the sampled boundary mesh,
even-m masks used by the axis-parity shortcut,
VMEC angular wint weights for the sampled mesh,
and the optional SciPy LU helpers used by the VMEC-like dense solve path.
That keeps the vacuum-step loop focused on state-dependent work instead of
rebuilding static sampling metadata. On the representative
input.cth_like_free_bdy path, the warmed CPU runtime improved from about
8.00s to about 7.88s on the same host while preserving the same final
residuals and free-boundary sampling test coverage.
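Static sampling metadata can be cached on the host with functools.lru_cache, keyed on the resolution parameters, so that vacuum steps reuse it. The uniform wint weights below are a placeholder, because VMEC's actual weights differ. boundary_sampling_setup is a hypothetical name.

```python
from functools import lru_cache

import numpy as np


@lru_cache(maxsize=None)
def boundary_sampling_setup(ntheta, nzeta, nfp, mpol):
    """Static sampling metadata, computed once per (ntheta, nzeta, nfp, mpol)."""
    theta = 2.0 * np.pi * np.arange(ntheta) / ntheta
    phi = 2.0 * np.pi * np.arange(nzeta) / (nzeta * nfp)    # one field period
    even_m = np.arange(mpol) % 2 == 0                        # even-m mask
    wint = np.full((ntheta, nzeta), 1.0 / (ntheta * nzeta))  # placeholder weights
    for arr in (theta, phi, even_m, wint):
        arr.setflags(write=False)  # cached arrays are shared, so make them read-only
    return theta, phi, even_m, wint
```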
Avoid Python objects in jitted functions¶
JAX jit requires its arguments to be arrays or PyTrees of arrays (other Python values must be marked static). vmec-jax makes the key
containers PyTrees:
VMECState
HelicalBasis
Geom
If you build your own containers, follow the same approach.
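Here is a minimal custom container registered with jax.tree_util.register_pytree_node_class. Arrays become traced leaves, and a plain int is carried as static auxiliary data. ToyState and damp are hypothetical and do not mirror the fields of VMECState.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyState:
    """Hypothetical container: arrays are leaves, `ns` is static aux data."""

    def __init__(self, rmnc, zmns, ns):
        self.rmnc, self.zmns, self.ns = rmnc, zmns, ns

    def tree_flatten(self):
        return (self.rmnc, self.zmns), (self.ns,)  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)


@jax.jit
def damp(state):
    # jit traces the two array leaves; `ns` stays a Python int.
    return ToyState(0.5 * state.rmnc, 0.5 * state.zmns, state.ns)
```

Because aux data is part of the cache key, changing ns triggers a recompile, while changing array values does not.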
Memory considerations¶
The current Fourier implementation stores (K, ntheta, nzeta) basis tensors
for cos/sin phases. This is acceptable for low-resolution validation cases, but
will become heavy for larger mpol/ntor.
Planned upgrades (post-parity):
factorized DFTs (theta/phi separable) using precomputed trig/weight tables,
FFT-based angular transforms only if they reproduce VMEC scaling and weights,
chunked evaluation in theta/zeta to reduce peak memory.
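The separable (factorized) DFT idea relies on cos(m*t - n*z) = cos(m*t)cos(n*z) + sin(m*t)sin(n*z). Using this identity, synthesis contracts over n first and then over m, with no need for the full basis tensor. NumPy is used for clarity, and the same algebra works with jax.numpy. The sketch ignores VMEC's weights and scaling and is not current vmec_jax code.

```python
import numpy as np


def synth_full(rmn, m, n, theta, zeta):
    """Reference: explicit (M, N, ntheta, nzeta) basis tensor."""
    arg = (m[:, None, None, None] * theta[None, None, :, None]
           - n[None, :, None, None] * zeta[None, None, None, :])
    return np.einsum("mn,mnij->ij", rmn, np.cos(arg))


def synth_separable(rmn, m, n, theta, zeta):
    """Same result from 1-D trig tables only."""
    cmt, smt = np.cos(np.outer(m, theta)), np.sin(np.outer(m, theta))  # (M, ntheta)
    cnz, snz = np.cos(np.outer(n, zeta)), np.sin(np.outer(n, zeta))    # (N, nzeta)
    a = rmn @ cnz                      # contract over n: (M, nzeta)
    b = rmn @ snz
    return cmt.T @ a + smt.T @ b       # contract over m: (ntheta, nzeta)
```

For M=4, N=7, ntheta=16, nzeta=12, the full basis holds 4*7*16*12 = 5376 values. The four 1-D tables hold 2*(4*16 + 7*12) = 296 values.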