vmec_jax.state
State layout helpers.
This module defines the coefficient container used throughout vmec_jax.
Important for JAX:
VMECState must be a PyTree so it can be passed into jax.jit’d functions (e.g. eval_coords) and differentiated with jax.grad. Registration must be idempotent: in interactive workflows the module can be imported multiple times (or reloaded), and JAX would otherwise raise a duplicate-registration error.
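A minimal sketch of idempotent PyTree registration, assuming a dataclass container with six coefficient arrays of shape (ns, K). The field names and the helper `register_state_pytree` are hypothetical, not vmec_jax's actual definitions. The sketch relies on `jax.tree_util.register_pytree_node` raising `ValueError` when the same class is registered twice and treats that case as "already registered".

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class VMECState:
    # Hypothetical field names: six coefficient arrays, each of shape (ns, K).
    Rcos: jax.Array
    Rsin: jax.Array
    Zcos: jax.Array
    Zsin: jax.Array
    Lcos: jax.Array
    Lsin: jax.Array


_FIELDS = tuple(f.name for f in dataclasses.fields(VMECState))


def _flatten(state):
    # Children are the arrays JAX traces; no static aux data is needed.
    return tuple(getattr(state, name) for name in _FIELDS), None


def _unflatten(aux, children):
    return VMECState(*children)


def register_state_pytree():
    """Register VMECState with JAX; safe to call more than once (hypothetical helper)."""
    try:
        jax.tree_util.register_pytree_node(VMECState, _flatten, _unflatten)
    except ValueError:
        # JAX rejects a second registration of the same class; treat it as done.
        pass


register_state_pytree()
register_state_pytree()  # a repeat call (e.g. re-running a notebook cell) is a no-op


def energy(state):
    # Toy scalar objective, only to show jax.grad over the whole state.
    return sum(jnp.sum(x ** 2) for x in jax.tree_util.tree_leaves(state))


state = VMECState(*(jnp.ones((3, 4)) for _ in range(6)))
grads = jax.jit(jax.grad(energy))(state)  # a VMECState holding 2 * x per field
```

Because the state is a registered PyTree, `jax.grad` returns a gradient with the same container type and field shapes as its input.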
Functions
- Pack a VMECState into a flat vector of length ns*K*6.
- Unpack a flat vector into a VMECState.
- Create a zero-initialized state with the right shapes.
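The three helpers above can be sketched as follows. The names `zeros_state`, `pack_state`, and `unpack_state` are hypothetical (the actual function names are not shown here), and so is the layout: the six (ns, K) fields are assumed to be stacked on a trailing axis before flattening, which gives a vector of length ns*K*6. A `NamedTuple` stands in for VMECState because NamedTuples are PyTrees in JAX without explicit registration.

```python
from typing import NamedTuple

import jax.numpy as jnp

NCOEF = 6  # number of coefficient arrays per state (assumed)


class VMECState(NamedTuple):
    # Stand-in container with hypothetical field names; each field is (ns, K).
    Rcos: jnp.ndarray
    Rsin: jnp.ndarray
    Zcos: jnp.ndarray
    Zsin: jnp.ndarray
    Lcos: jnp.ndarray
    Lsin: jnp.ndarray


def zeros_state(ns: int, K: int) -> VMECState:
    """Zero-initialized state; every field has shape (ns, K)."""
    return VMECState(*(jnp.zeros((ns, K)) for _ in range(NCOEF)))


def pack_state(state: VMECState) -> jnp.ndarray:
    """Flatten to a vector of length ns*K*6 (fields stacked on the last axis)."""
    return jnp.stack(tuple(state), axis=-1).reshape(-1)


def unpack_state(vec: jnp.ndarray, ns: int, K: int) -> VMECState:
    """Inverse of pack_state: reshape to (ns, K, 6) and split the last axis."""
    arr = vec.reshape(ns, K, NCOEF)
    return VMECState(*(arr[..., i] for i in range(NCOEF)))


s = zeros_state(3, 4)
v = pack_state(s)  # shape (72,) since 3 * 4 * 6 = 72
```

A generic alternative is `jax.flatten_util.ravel_pytree`, which also returns an unravel function; it concatenates the raveled leaves field by field, so its ordering differs from the interleaved layout assumed in this sketch.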
Classes