vmec_jax.state

State layout helpers.

This module defines the coefficient container used throughout vmec_jax.

Important for JAX:
  • VMECState must be a PyTree so it can be passed into jax.jit-compiled functions (e.g. eval_coords) and differentiated with jax.grad.

  • Registration must be idempotent: in interactive workflows the module can be imported more than once (or reloaded), and JAX raises a duplicate-registration error if the same type is registered twice.
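A minimal sketch of idempotent PyTree registration, assuming VMECState is a plain dataclass. The two fields (Rcos, Zsin), the helper register_vmec_state, and the toy objective energy are hypothetical stand-ins; the real class and eval_coords are not shown here. The sketch treats the ValueError that jax.tree_util.register_pytree_node raises on a duplicate registration as "already registered".

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class VMECState:
    # Hypothetical fields: two coefficient arrays of shape (ns, K).
    Rcos: jnp.ndarray
    Zsin: jnp.ndarray


def _flatten(state):
    # Children are the array leaves; this sketch carries no static aux data.
    return (state.Rcos, state.Zsin), None


def _unflatten(aux, children):
    return VMECState(*children)


def register_vmec_state():
    """Register VMECState as a PyTree; safe to call more than once (hypothetical helper)."""
    try:
        jax.tree_util.register_pytree_node(VMECState, _flatten, _unflatten)
    except ValueError:
        # JAX raises ValueError on a duplicate registration; treat it as already done.
        pass


register_vmec_state()
register_vmec_state()  # second call is a no-op instead of an error


@jax.jit
def energy(state):
    # Toy scalar objective standing in for a jit-compiled function such as eval_coords.
    return jnp.sum(state.Rcos ** 2) + jnp.sum(state.Zsin ** 2)


s = VMECState(Rcos=jnp.ones((3, 4)), Zsin=2.0 * jnp.ones((3, 4)))
g = jax.grad(energy)(s)  # the gradient is itself a VMECState
```

Because the state is a PyTree, jax.grad returns a VMECState with the same structure, holding the gradient with respect to each field.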

Functions

pack_state(state)

Pack a VMECState into a flat 1-D vector of length ns*K*6.
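The packing order is not specified above. The sketch below assumes VMECState holds six (ns, K) arrays with hypothetical names Rcos, Rsin, Zcos, Zsin, Lcos, Lsin (which would account for the factor of 6) and concatenates them field by field in that order.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass
class VMECState:
    # Hypothetical field names: six coefficient arrays, each of shape (ns, K).
    Rcos: jnp.ndarray
    Rsin: jnp.ndarray
    Zcos: jnp.ndarray
    Zsin: jnp.ndarray
    Lcos: jnp.ndarray
    Lsin: jnp.ndarray


# Assumed packing order; the real module may use a different one.
FIELDS = ("Rcos", "Rsin", "Zcos", "Zsin", "Lcos", "Lsin")


def pack_state(state):
    """Concatenate the six (ns, K) arrays into one vector of length ns*K*6."""
    return jnp.concatenate([jnp.ravel(getattr(state, f)) for f in FIELDS])


ns, K = 3, 4
# Fill field i with the constant i so the block boundaries are easy to see.
state = VMECState(*[jnp.full((ns, K), float(i)) for i in range(6)])
x = pack_state(state)
```

A fixed field order makes packing a pure reshape-and-concatenate, so it stays cheap inside jax.jit and is trivially differentiable.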

unpack_state(x, layout)

Unpack a flat vector into a VMECState.
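A matching sketch of the inverse, under the same assumptions: the vector consists of six contiguous (ns, K) blocks in a fixed field order. StateLayout here is a stand-in NamedTuple with the documented fields (ns, K, lasym); this sketch does not use lasym, and the field names remain hypothetical.

```python
import dataclasses
from typing import NamedTuple

import jax.numpy as jnp


class StateLayout(NamedTuple):
    # Stand-in mirroring the documented constructor StateLayout(ns, K, lasym).
    ns: int
    K: int
    lasym: bool


@dataclasses.dataclass
class VMECState:
    # Hypothetical field names: six (ns, K) coefficient arrays.
    Rcos: jnp.ndarray
    Rsin: jnp.ndarray
    Zcos: jnp.ndarray
    Zsin: jnp.ndarray
    Lcos: jnp.ndarray
    Lsin: jnp.ndarray


def unpack_state(x, layout):
    """Split x into six (ns, K) arrays; inverse of a block-wise, row-major pack."""
    blocks = jnp.reshape(x, (6, layout.ns, layout.K))
    return VMECState(*[blocks[i] for i in range(6)])


layout = StateLayout(ns=3, K=4, lasym=False)
x = jnp.arange(3 * 4 * 6, dtype=jnp.float32)
state = unpack_state(x, layout)
```

Since the blocks were flattened in row-major order, a single reshape to (6, ns, K) recovers them, and re-raveling the fields in the same order reproduces x exactly.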

zeros_state(layout, *[, like])

Create a zero-initialized state with the shapes given by layout.
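A sketch of zeros_state under the same hypothetical six-field layout. The meaning of the keyword-only like argument is not documented here; the sketch assumes it is an optional array whose dtype the zeros should match, and otherwise falls back to JAX's default float dtype.

```python
import dataclasses
from typing import NamedTuple

import jax.numpy as jnp


class StateLayout(NamedTuple):
    # Stand-in mirroring the documented constructor StateLayout(ns, K, lasym).
    ns: int
    K: int
    lasym: bool


@dataclasses.dataclass
class VMECState:
    # Hypothetical field names: six (ns, K) coefficient arrays.
    Rcos: jnp.ndarray
    Rsin: jnp.ndarray
    Zcos: jnp.ndarray
    Zsin: jnp.ndarray
    Lcos: jnp.ndarray
    Lsin: jnp.ndarray


def zeros_state(layout, *, like=None):
    # Assumption: `like` is an optional array whose dtype the zeros should match.
    dtype = None if like is None else jnp.asarray(like).dtype
    shape = (layout.ns, layout.K)
    return VMECState(*[jnp.zeros(shape, dtype=dtype) for _ in range(6)])


layout = StateLayout(ns=2, K=5, lasym=False)
z = zeros_state(layout)
z16 = zeros_state(layout, like=jnp.zeros(1, dtype=jnp.float16))
```

Taking the dtype from an existing array keeps a freshly built state consistent with the precision already in use, for example when initializing a gradient accumulator.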

Classes

StateLayout(ns, K, lasym)