jcm.checkpoint
Model state checkpointing for long, preemptible runs.
Persists Model._final_dycore_state and Model._final_physics_state,
plus an elapsed sim-day count, to a single file using flax’s msgpack
serialization. run_chunked (in jcm.runners) uses these primitives when
cfg.run.checkpoint_path is set: it writes a checkpoint after each chunk
and, at startup, restores from that file if it exists. An integration
interrupted by spot-instance preemption therefore resumes without
redoing completed chunks.
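The resume behaviour described above can be sketched in plain Python. Everything in this sketch is hypothetical: run_chunked_sketch, save_checkpoint, load_checkpoint and step_chunk are illustrative names rather than this module's API. JSON stands in for flax msgpack so the sketch stays stdlib-only, and the write-to-temp-then-rename step is a choice made here, not something the module is documented to do. The stop_after argument simulates a preemption.

```python
import json
import os
import tempfile


def save_checkpoint(path, state, sim_days):
    # Hypothetical stand-in for the module's save function (JSON, not msgpack).
    # Writing to a temp file and renaming is this sketch's choice: a
    # preemption mid-write then cannot leave a truncated checkpoint.
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump({"state": state, "sim_days": sim_days}, f)
    os.replace(tmp, path)


def load_checkpoint(path):
    # Hypothetical stand-in for the module's restore function.
    with open(path) as f:
        payload = json.load(f)
    return payload["state"], payload["sim_days"]


def step_chunk(state, days):
    # Placeholder for integrating the model forward by `days` sim-days.
    return {"t": state["t"] + days}


def run_chunked_sketch(total_days, chunk_days, checkpoint_path=None, stop_after=None):
    """Integrate in chunks, checkpointing after each one.

    If checkpoint_path exists at startup, resume from it. stop_after
    simulates preemption by returning after that many chunks.
    Returns (state, sim_days, chunks_run_in_this_call).
    """
    state, sim_days = {"t": 0}, 0
    if checkpoint_path and os.path.exists(checkpoint_path):
        state, sim_days = load_checkpoint(checkpoint_path)
    chunks_run = 0
    while sim_days < total_days:
        if stop_after is not None and chunks_run == stop_after:
            break  # simulated spot-instance preemption
        days = min(chunk_days, total_days - sim_days)
        state = step_chunk(state, days)
        sim_days += days
        chunks_run += 1
        if checkpoint_path:
            save_checkpoint(checkpoint_path, state, sim_days)
    return state, sim_days, chunks_run


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "ckpt.json")
    # First attempt is "preempted" after 2 of 5 ten-day chunks.
    _, days_before, chunks_before = run_chunked_sketch(50, 10, path, stop_after=2)
    # The restart restores day 20 and runs only the 3 remaining chunks.
    final_state, days_after, chunks_after = run_chunked_sketch(50, 10, path)
```

Storing the sim-day count next to the state is what lets the restarted run know how many chunks remain; without it the loop would have no way to tell a fresh start from a resume.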
The state pytrees are flattened to plain lists of arrays before
serialization because flax’s msgpack codec can’t handle tree_math
structs (e.g. dinosaur’s primitive_equations.State) directly. At load
time the treedef is reconstructed from the destination model’s
bootstrapped templates. A checkpoint is therefore portable only across
runs with a matching dycore, coords, and physics-term composition (so
that leaf order and dtypes line up), which is the intended usage.
Functions
Restore
Persist the model's current dycore + physics state to