jcm.checkpoint

Model state checkpointing for long, preemptible runs.

Persists Model._final_dycore_state and Model._final_physics_state, plus a count of elapsed simulated days, to a single file using flax’s msgpack serialization. run_chunked (in jcm.runners) uses these primitives when cfg.run.checkpoint_path is set: it restores from the checkpoint at startup if the file exists and writes a checkpoint after each chunk. An integration interrupted by spot-instance preemption therefore resumes without redoing completed chunks.
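The resume-and-checkpoint loop can be sketched as follows. This is not jcm's code. `run_chunked_sketch`, `step_chunk`, `total_days`, `chunk_days`, `save` and `load` are hypothetical names. JSON stands in for flax msgpack so the example needs only the standard library. The temp-file-then-rename write is an assumption added for robustness, not something the module is documented to do.

```python
import json
import os
import tempfile

# Sketch only. The real runner is jcm.runners.run_chunked and the real file
# format is flax msgpack; JSON stands in here so the code needs only the stdlib.
# The helper names and parameters below are hypothetical.


def save(path, state, elapsed_days):
    # Assumption: write to a temp file in the same directory, then atomically
    # rename it over the old checkpoint, so a preemption mid-write cannot leave
    # a truncated file behind. The source does not say jcm does this.
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(os.path.abspath(path)))
    with os.fdopen(fd, "w") as f:
        json.dump({"state": state, "elapsed_days": elapsed_days}, f)
    os.replace(tmp, path)


def load(path):
    with open(path) as f:
        blob = json.load(f)
    return blob["state"], blob["elapsed_days"]


def run_chunked_sketch(step_chunk, init_state, total_days, chunk_days,
                       checkpoint_path=None):
    """Integrate total_days in chunks of chunk_days, resuming from a checkpoint."""
    state, elapsed = init_state, 0.0
    if checkpoint_path is not None and os.path.exists(checkpoint_path):
        # Resume: the stored elapsed-day count tells us which chunks are done.
        state, elapsed = load(checkpoint_path)
    chunks_run = 0
    while elapsed < total_days:
        days = min(chunk_days, total_days - elapsed)
        state = step_chunk(state, days)
        elapsed += days
        chunks_run += 1
        if checkpoint_path is not None:
            save(checkpoint_path, state, elapsed)  # checkpoint after every chunk
    return state, elapsed, chunks_run
```

Suppose the process is killed during the fourth 2-day chunk of a 10-day run. The file still holds the state after chunk three, so rerunning the same call integrates only the last two chunks.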

The state pytrees are flattened to plain lists of arrays before serialization, because flax’s msgpack codec cannot handle tree_math structs (e.g. dinosaur’s primitive_equations.State) directly. At load time the treedef is reconstructed from the destination model’s bootstrapped templates. A checkpoint is therefore portable only across runs with a matching dycore, coords, and physics term composition (so that leaf order and dtypes line up), which is the intended usage.
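The flatten-then-rebuild-from-template idea, and why matching composition matters, can be illustrated with a small stand-in for `jax.tree_util`. In this sketch `flatten` and `unflatten_like` are hypothetical helpers, and `array.array` plays the role of a device array, with `.typecode` standing in for its dtype. The real module uses JAX pytrees and flax msgpack. Custom structs such as dataclasses are not handled here.

```python
from array import array

# Hypothetical stand-in for jax.tree_util: dict/tuple/list are containers and
# anything else is a leaf. Dict keys are visited in sorted order, as JAX does.
# array.array plays the role of a device array, with .typecode as its dtype.


def flatten(tree):
    """Return the leaves of `tree` as a plain list (what gets serialized)."""
    if isinstance(tree, dict):
        return [leaf for k in sorted(tree) for leaf in flatten(tree[k])]
    if isinstance(tree, (tuple, list)):
        return [leaf for child in tree for leaf in flatten(child)]
    return [tree]


def unflatten_like(template, leaves):
    """Pour `leaves` back into the structure of `template`, checking each leaf."""
    it = iter(leaves)

    def build(node):
        if isinstance(node, dict):
            return {k: build(node[k]) for k in sorted(node)}
        if isinstance(node, (tuple, list)):
            return type(node)(build(child) for child in node)
        leaf = next(it, None)
        if leaf is None:
            raise ValueError("checkpoint has fewer leaves than the template")
        # Leaf order alone is not enough: dtype and size must also line up.
        if leaf.typecode != node.typecode or len(leaf) != len(node):
            raise ValueError(
                f"leaf mismatch: got {leaf.typecode}[{len(leaf)}], "
                f"expected {node.typecode}[{len(node)}]"
            )
        return leaf

    restored = build(template)
    if next(it, None) is not None:
        raise ValueError("checkpoint has more leaves than the template")
    return restored
```

The serialized form is just `flatten(state)`, a plain list that any array-aware codec can store. All structural information comes back from the template at load time. This is why a run with a different term composition, whose template has different leaves, is rejected rather than silently misassigned.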

Functions

load_checkpoint(model, path)

Restore _final_dycore_state + _final_physics_state from path.

save_checkpoint(model, path, *, elapsed_days)

Persist the model's current dycore + physics state, together with elapsed_days, to path.
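The two functions can be pictured as a pair with the call shape documented above. The following is a minimal sketch. `ToyModel`, `save_checkpoint_sketch` and `load_checkpoint_sketch` are hypothetical. JSON replaces flax msgpack and `array.array` replaces device arrays. Having the load function return the elapsed-day count is this sketch's own choice; the documentation above does not say what load_checkpoint returns.

```python
import json
import os
import tempfile
from array import array

# Sketch of a save/load pair with the same call shape as save_checkpoint and
# load_checkpoint. ToyModel and the *_sketch names are hypothetical; JSON stands
# in for flax msgpack and array.array for device arrays.


class ToyModel:
    """Stand-in for jcm's Model, holding only the two final-state attributes."""

    def __init__(self, dycore_state, physics_state):
        self._final_dycore_state = dycore_state
        self._final_physics_state = physics_state


def _leaves(tree):
    # Flatten dict/tuple/list containers into a list of leaves (sorted dict keys).
    if isinstance(tree, dict):
        return [x for k in sorted(tree) for x in _leaves(tree[k])]
    if isinstance(tree, (tuple, list)):
        return [x for c in tree for x in _leaves(c)]
    return [tree]


def _rebuild(template, it):
    # Refill the template's container structure from an iterator of leaves.
    if isinstance(template, dict):
        return {k: _rebuild(template[k], it) for k in sorted(template)}
    if isinstance(template, (tuple, list)):
        return type(template)(_rebuild(c, it) for c in template)
    return next(it)


_ATTRS = {"dycore": "_final_dycore_state", "physics": "_final_physics_state"}


def save_checkpoint_sketch(model, path, *, elapsed_days):
    # elapsed_days is keyword-only, matching the documented signature.
    payload = {"elapsed_days": elapsed_days}
    for key, attr in _ATTRS.items():
        payload[key] = [[a.typecode, a.tolist()] for a in _leaves(getattr(model, attr))]
    with open(path, "w") as f:
        json.dump(payload, f)


def load_checkpoint_sketch(model, path):
    with open(path) as f:
        payload = json.load(f)
    restored = {}
    for key, attr in _ATTRS.items():
        template = getattr(model, attr)  # bootstrapped state acts as the template
        leaves = [array(tc, vals) for tc, vals in payload[key]]
        if len(leaves) != len(_leaves(template)):
            raise ValueError(f"{key}: leaf count does not match the model")
        restored[attr] = _rebuild(template, iter(leaves))
    for attr, value in restored.items():  # mutate only after both states check out
        setattr(model, attr, value)
    # Returning the day count is this sketch's choice, not a documented behavior.
    return payload["elapsed_days"]


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "state.ckpt.json")
    src = ToyModel({"div": array("d", [1.5, 2.5]), "vor": array("d", [3.0])},
                   (array("f", [0.5]),))
    save_checkpoint_sketch(src, path, elapsed_days=12.0)
    dst = ToyModel({"div": array("d", [0.0, 0.0]), "vor": array("d", [0.0])},
                   (array("f", [0.0]),))
    print(load_checkpoint_sketch(dst, path), dst._final_dycore_state["div"])
```

The load side reuses the destination model's own states as templates. A model whose state has a different number of leaves is therefore rejected before either attribute is overwritten.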