Composable physics¶
Overview¶
A JCM physics package is a ComposablePhysics container holding an
ordered list of PhysicsTerm instances. Each term is a self-contained
parameterisation that reads the prognostic state and a shared
diagnostics dict, returns its tendency, and writes its outputs back
into the dict for downstream terms to consume.
Model.compute_physics_step
└─ ComposablePhysics.compute_tendencies(state, forcing, terrain, prev_carry)
       diagnostics = {**prev_carry}          ← cross-step physics carry seed
       for term in terms:
           tend, diagnostics = term(state, diagnostics, forcing, terrain)
           tendencies += tend
       return tendencies, diagnostics        ← diagnostics is the next step's carry
ComposablePhysics is a flax.nnx.Module, as is every PhysicsTerm, and
per-term parameters are nnx.Param. Composition is differentiable
end-to-end via either nnx.grad directly on the container or jax.grad
over an nnx.split-flattened state.
The container is final — it is not subclassed. Composition happens at
construction time via +, replace(category, term), and
remove(category). _validate_ordering runs each time a composition
is built.
Components¶
PhysicsTerm (jcm/physics/physics_term.py)¶
Base class for one parameterisation. Each subclass declares four
ClassVars of static metadata and implements two methods.
class PhysicsTerm(nnx.Module):
    name: ClassVar[str]                        # unique identifier
    category: ClassVar[str]                    # "radiation", "convection", …
    requires: ClassVar[tuple[str, ...]] = ()   # diagnostics keys read from upstream
    provides: ClassVar[tuple[str, ...]] = ()   # diagnostics keys written

    def cache_coords(self, coords) -> None:
        """Populate coordinate-dependent caches as nnx.Variable. In-place.
        Called once at Model construction time, outside any traced region."""

    def __call__(
        self,
        state: PhysicsState,
        diagnostics: dict[str, Any],
        forcing: ForcingData,
        terrain: TerrainData,
    ) -> tuple[PhysicsTendency, dict[str, Any]]:
        """Return (tendency, updated_diagnostics)."""
Storage convention:
- nnx.Param: tunable parameters. nnx.grad differentiates through them.
- nnx.Variable: coordinate caches and other traced-but-frozen state. Read inside __call__ as live pytree leaves; not differentiated by default.
- Plain Python attributes: static configuration (flags, integer knobs that should not change post-construction).
ComposablePhysics (jcm/physics/composable_physics.py)¶
The single container class. Iterates terms in order, threads the
diagnostics dict through, sums tendencies, and exposes the
composition operators.
class ComposablePhysics(nnx.Module, Physics):
    def __init__(self, terms: list[PhysicsTerm], *, vectorize_columns: bool = False):
        self.terms = terms
        self._validate_ordering()
        self.vectorize_columns = vectorize_columns

    def cache_coords(self, coords):
        for term in self.terms:
            term.cache_coords(coords)

    def compute_tendencies(self, state, forcing, terrain, *, prev_physics_data=None):
        diagnostics = dict(prev_physics_data) if prev_physics_data else {}
        tendencies = PhysicsTendency.zeros(state.temperature.shape)
        for term in self.terms:
            tend, diagnostics = term(state, diagnostics, forcing, terrain)
            tendencies += tend
        return tendencies, diagnostics

    # Composition operators (each returns a fresh container, runs validation):
    def __add__(self, other): ...              # concatenate term lists
    def replace(self, category, new): ...      # swap all terms of category
    def remove(self, category): ...            # drop all terms of category
vectorize_columns=True (used by echam_physics()) wraps each term in
jax.vmap over the horizontal axes so single-column algorithms can be
written as if they acted on one column. The outer container handles the
reshape/un-reshape; terms see (nlev,) or (nlev, ncols) as appropriate.
Pre-built factories¶
Each physics package is a factory that returns a ComposablePhysics
with a validated ordering. The factories live next to their term files:
- jcm/physics/speedy/speedy_terms.py::speedy_physics()
- jcm/physics/echam/echam_terms.py::echam_physics()
- jcm/physics/held_suarez/held_suarez_physics.py::held_suarez_physics()
Held-Suarez stays a hand-written package — composable machinery is optional, not forced.
The diagnostics dict¶
Terms communicate through dict[str, Any] (typically
dict[str, jnp.ndarray] or typed sub-structs of arrays). The dict
flows forward through the term list: every term reads the keys it
needs and returns a new dict with any keys it produces. Terms never
mutate the dict in place.
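The read-then-return-a-new-dict pattern can be sketched without any framework. The block below is a minimal illustration, assuming plain floats in place of JAX arrays; `toy_radiation`, `toy_convection`, and their arithmetic are hypothetical stand-ins, not JCM terms.

```python
# Hypothetical toy terms: real terms are nnx.Module instances operating on
# JAX arrays, but the dict handling follows the same pattern.

def toy_radiation(state, diagnostics):
    """Compute a heating rate and publish it under the 'radiation' key."""
    heating = state["temperature"] / 100.0
    # Build a new dict rather than assigning into the incoming one.
    return heating, {**diagnostics, "radiation": heating}


def toy_convection(state, diagnostics):
    """Read the upstream 'radiation' key and publish 'convection'."""
    mass_flux = 2.0 * diagnostics["radiation"]
    return -mass_flux, {**diagnostics, "convection": mass_flux}


state = {"temperature": 300.0}
carry_in = {"_dt_seconds": 1800.0}   # carry from the previous step

diagnostics = dict(carry_in)
total = 0.0
for term in (toy_radiation, toy_convection):
    tend, diagnostics = term(state, diagnostics)
    total += tend

print(total)                # -3.0
print(sorted(diagnostics))  # ['_dt_seconds', 'convection', 'radiation']
print(carry_in)             # {'_dt_seconds': 1800.0}, untouched
```

Because each term returns a fresh dict, the incoming carry is left untouched, which is what lets the same structure be threaded through a functional loop.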
Two key conventions:
- Public keys (no leading underscore): exposed as user-facing diagnostic output. Flatten directly into model.run().to_xarray() via Physics.data_struct_to_dict. Example: radiation, convection, cloud_fraction.
- Internal keys (leading underscore): cross-step or transient state used by terms internally; filtered out of user-facing output. Examples: _radiation (sub-cycle cache), _date (jax_datetime struct + dt_seconds carried for terms that need it), _dt_seconds (model timestep injected by ComposablePhysics so terms read a single source of truth instead of plumbing it through parameters or date).
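The leading-underscore convention amounts to a simple partition of the dict. The helper below is a hypothetical sketch of that partition only; it is not the JCM implementation, and `split_public_internal` is an illustrative name.

```python
def split_public_internal(diagnostics):
    """Partition a diagnostics dict by the leading-underscore convention."""
    public = {k: v for k, v in diagnostics.items() if not k.startswith("_")}
    internal = {k: v for k, v in diagnostics.items() if k.startswith("_")}
    return public, internal


# Toy values; in the model these would be arrays or sub-structs of arrays.
diagnostics = {
    "radiation": 1.0,
    "cloud_fraction": 0.3,
    "_radiation": "sub-cycle cache",
    "_dt_seconds": 1800.0,
}

public, internal = split_public_internal(diagnostics)
print(sorted(public))    # ['cloud_fraction', 'radiation']
print(sorted(internal))  # ['_dt_seconds', '_radiation']
```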
The diagnostics dict that comes out of compute_tendencies is the
cross-step physics carry — operator-split integration threads it as a
JAX pytree through lax.scan (see operator_split_physics.md).
Validation¶
ComposablePhysics(terms=[...]) runs _validate_ordering at
construction time:
- Single-writer per key. Two terms cannot list the same key in their provides. Catches misconfigurations where one term would silently overwrite another’s output.
- All requires resolved upstream. Each term’s requires must appear in the union of upstream terms’ provides. Catches ordering bugs at Model construction rather than at the first model.run().
Empty requires / provides are fine — terms that read only the
prognostic state and write only tendencies (e.g. a Rayleigh damping
term) declare nothing.
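Both rules can be checked in a single pass over the ordered term list. The following is a minimal sketch under stated assumptions: `TermSpec` and `validate_ordering` are hypothetical names, and the real `_validate_ordering` may differ in error types and messages.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TermSpec:
    """Hypothetical stand-in carrying only the metadata the checks need."""
    name: str
    requires: tuple = ()
    provides: tuple = ()


def validate_ordering(terms):
    """Raise ValueError on an unresolved requirement or a duplicate writer."""
    writer = {}  # key -> name of the upstream term that provides it
    for term in terms:
        # Every required key must already have an upstream writer.
        for key in term.requires:
            if key not in writer:
                raise ValueError(
                    f"{term.name} requires '{key}' but no upstream term provides it"
                )
        # At most one term may provide a given key.
        for key in term.provides:
            if key in writer:
                raise ValueError(
                    f"'{key}' is provided by both {writer[key]} and {term.name}"
                )
            writer[key] = term.name


radiation = TermSpec("radiation", provides=("radiation",))
clouds = TermSpec("clouds", requires=("radiation",), provides=("cloud_fraction",))
damping = TermSpec("rayleigh_damping")  # declares nothing, always valid

validate_ordering([radiation, damping, clouds])  # passes silently

try:
    validate_ordering([clouds, radiation])  # clouds runs before its input exists
except ValueError as err:
    print(err)
```

Recording the writer of each key as the loop advances means "upstream" is enforced by construction: a key only becomes visible to terms that come later in the list.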
required_tracers() is a separate hook each term can implement to
declare non-default tracers (anything beyond specific_humidity).
Model collects specs from every term at build time and seeds the
initial state’s tracer dict.
Composition operators¶
Three operators on the container, each returning a fresh
ComposablePhysics after re-running validation:
physics_a + physics_b # concatenate term lists
physics.replace("convection", new) # swap all 'convection' terms for `new`
physics.remove("clouds") # drop every term whose category is 'clouds'
replace(category, new_term) collapses all existing terms in that
category into the single replacement, inserted at the position of the
first one — so you can swap an entire process category in one call:
# SPEEDY with RRTMGP radiation
physics = speedy_physics().replace("radiation", RRTMGPRadiation())
# ECHAM with a custom convection scheme
physics = echam_physics().replace("convection", MyConvection())
# Strip clouds entirely, then add Rayleigh damping
physics = echam_physics().remove("clouds") + RayleighDamping()
When users compose from scratch (term_a + term_b + term_c), they
accept responsibility for ordering correctness — _validate_ordering
catches dependency bugs but cannot enforce semantic order beyond the
requires / provides graph.
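The positional semantics of replace and remove can be illustrated on plain lists. This is a sketch under assumptions: `Term`, `replace`, and `remove` here are hypothetical free functions, validation is omitted, and the behaviour when no term matches the category (the list is returned unchanged) is an assumption rather than documented JCM behaviour.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Term:
    """Hypothetical stand-in with only a name and a category."""
    name: str
    category: str


def replace(terms, category, new_term):
    """Collapse every term of `category` into `new_term`, placed where the first one sat."""
    out, inserted = [], False
    for term in terms:
        if term.category != category:
            out.append(term)
        elif not inserted:
            out.append(new_term)   # the first match marks the insertion point
            inserted = True
        # later matches are dropped
    return out


def remove(terms, category):
    """Drop every term whose category matches."""
    return [t for t in terms if t.category != category]


terms = [
    Term("shortwave", "radiation"),
    Term("conv", "convection"),
    Term("longwave", "radiation"),
    Term("vdiff", "vertical_diffusion"),
]

swapped = replace(terms, "radiation", Term("rrtmgp", "radiation"))
print([t.name for t in swapped])                      # ['rrtmgp', 'conv', 'vdiff']
print([t.name for t in remove(terms, "radiation")])   # ['conv', 'vdiff']
print(len(terms))                                     # 4, the original list is untouched
```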
Cross-package compatibility¶
SPEEDY and ECHAM terms historically used different intermediate data representations. The diagnostics dict is the bridge:
- SPEEDY wrappers store SPEEDY sub-structs under keys like "_shortwave_rad", "_convection".
- ECHAM wrappers store ECHAM sub-structs under keys like "radiation", "convection".
Mixing terms across packages works cleanly when the replacement covers an entire process category — the new term reads only public diagnostics and prognostic state, and produces what downstream terms need from scratch. Sharing internal sub-structs across packages requires either a translation term or independent recomputation; in practice the entire-category replacement pattern is what most use cases want.
Differentiability¶
End-to-end gradient flow is preserved through the diagnostics dict
(all values are JAX arrays or pytrees of arrays), per-term
nnx.Param storage, and the composability operators (which produce
fresh ComposablePhysics modules with the same nnx graph
properties).
Pattern 1 — direct nnx.grad (most ergonomic):
physics = speedy_physics()
physics.cache_coords(coords)   # run ONCE, outside the traced region

def loss_fn(physics):
    model = Model(coords=coords, terrain=terrain, physics=physics)
    return compute_loss(model.run(total_time=...))

grads = nnx.grad(loss_fn)(physics)
# grads.terms[i].params is the gradient w.r.t. term i's parameters
Pattern 2 — pure JAX via split/merge (interop with existing jax.grad code):
graphdef, state = nnx.split(physics)

def loss_fn(state):
    physics = nnx.merge(graphdef, state)
    model = Model(coords=coords, terrain=terrain, physics=physics)
    return compute_loss(model.run(total_time=...))

grads = jax.grad(loss_fn)(state)
Pattern 3 — per-scheme optimisation:
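# Optimise only convection parameters, freeze the rest
convection_filter = nnx.PathContains("convection")
grads = nnx.grad(loss_fn, wrt=convection_filter)(physics)
# Note: the wrt= keyword belongs to older Flax releases. Recent releases
# express the same filter as
#   nnx.grad(loss_fn, argnums=nnx.DiffState(0, convection_filter))(physics)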
# Or address a single term's params directly
opt = optax.adam(1e-3)
opt_state = opt.init(physics.terms[3].params)
Path-based filtering is the largest ergonomic improvement over a
monolithic Parameters struct: optimising one scheme in isolation
needs no surgery.
cache_coords lifecycle¶
cache_coords(coords) is called once at Model.__init__ time, before
any traced region:
self.physics = physics
self.physics.cache_coords(self.coords)
Inside cache_coords, a term stores precomputed data (sigma layer
midpoints, basis functions, lookup tables) as nnx.Variable
attributes. During model.run() those variables are read as traced
values but never written.
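The populate-once, read-many lifecycle can be sketched with a plain Python class. This is an illustration only: `ToyDampingTerm` is hypothetical, it takes a tuple of half-level sigma values instead of a JCM coords object, and it stores the cache as an ordinary attribute where a real term would use nnx.Variable.

```python
class ToyDampingTerm:
    """Hypothetical term; real JCM terms store caches as nnx.Variable."""

    def __init__(self, rate=0.5):
        self.rate = rate        # static configuration
        self.sigma_mid = None   # coordinate cache, filled by cache_coords

    def cache_coords(self, sigma_half):
        """Precompute layer midpoints from half-level sigma values, in place."""
        self.sigma_mid = tuple(
            0.5 * (lo + hi) for lo, hi in zip(sigma_half[:-1], sigma_half[1:])
        )

    def __call__(self, temperature):
        # The cache is only read here; nothing is written during a step.
        return [-self.rate * s * t for s, t in zip(self.sigma_mid, temperature)]


term = ToyDampingTerm()
term.cache_coords((0.0, 0.5, 1.0))   # once, at model construction
tendency = term([200.0, 300.0])      # repeatedly, inside the time loop

print(term.sigma_mid)  # (0.25, 0.75)
print(tendency)        # [-25.0, -112.5]
```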
The default behaviour of nnx.grad is to ignore nnx.Variable and
only differentiate nnx.Param — so coordinate caches do not appear in
the gradient. Callers that do want to differentiate w.r.t. the
vertical-level placement (rare, but supported for learnable level
schemes) can broaden the path filter.
Process-parallel coupling¶
Within a ComposablePhysics step, terms run process-parallel:
every term reads the same input prognostic state, and tendencies
are summed. Terms may read each other’s diagnostic outputs
(through the dict), but they do not see each other’s tendency
contribution applied to the prognostic state until the next dynamics
step.
This is order-independent at the prognostic-state level —
A + B + C and B + A + C produce the same total tendency from the
same state. It differs from ECHAM6’s sequential coupling, where each
scheme reads the state with prior schemes’ tendencies already added
(via the tte += ... pattern on shared accumulators in
mo_scan_buffer.f90). Sequential coupling is more accurate for
tightly-coupled process pairs but gives up the order-independence
that makes replace() / remove() semantically clean.
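A scalar toy problem makes the difference concrete. The sketch below assumes a single temperature value, a unit timestep, and two hypothetical processes (`heating`, `relaxation`); it is not JCM code.

```python
def heating(T):
    """Constant heating, independent of the state."""
    return 2.0


def relaxation(T):
    """Relax T toward 280 at rate 0.5 per unit time."""
    return -0.5 * (T - 280.0)


def process_parallel(T, terms):
    # Every term sees the same input state; tendencies are summed.
    return sum(term(T) for term in terms)


def sequential(T, terms):
    # Each term sees the state already updated by earlier terms (unit timestep).
    T_work = T
    for term in terms:
        T_work = T_work + term(T_work)
    return T_work - T


T0 = 300.0
print(process_parallel(T0, [heating, relaxation]))  # -8.0
print(process_parallel(T0, [relaxation, heating]))  # -8.0
print(sequential(T0, [heating, relaxation]))        # -9.0
print(sequential(T0, [relaxation, heating]))        # -8.0
```

Swapping the two terms leaves the process-parallel total unchanged, while the sequential result depends on which process runs first.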
For terms that genuinely need sequential coupling (e.g. a tightly coupled CLUBB+MG2 pair as in E3SM), the recommended pattern is a process group term that runs the inner sub-cycle internally and presents a single tendency externally — keeping the outer container process-parallel.
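One way to picture such a process group is a wrapper that sub-steps its inner processes sequentially and reports only the net change. The sketch below is a scalar toy under assumptions: `ProcessGroup`, `toy_turbulence`, `toy_microphysics`, the unit outer timestep, and the "process_group" key are all hypothetical.

```python
def toy_turbulence(T):
    """Hypothetical inner process: constant heating."""
    return 2.0


def toy_microphysics(T):
    """Hypothetical inner process: relaxation toward 280."""
    return -0.5 * (T - 280.0)


class ProcessGroup:
    """Sequential coupling inside, a single tendency outside."""

    def __init__(self, inner_terms, n_substeps=1):
        self.inner_terms = list(inner_terms)
        self.n_substeps = n_substeps

    def __call__(self, T, diagnostics):
        T_work = T
        dt_sub = 1.0 / self.n_substeps   # outer timestep taken as 1
        for _ in range(self.n_substeps):
            for term in self.inner_terms:
                # Each inner term sees the state updated by the previous one.
                T_work = T_work + dt_sub * term(T_work)
        tendency = T_work - T            # net change over the outer step
        return tendency, {**diagnostics, "process_group": tendency}


group = ProcessGroup([toy_turbulence, toy_microphysics], n_substeps=2)
tendency, diagnostics = group(300.0, {})
print(tendency)  # -7.4375
```

From the container's point of view the group is just another term: it reads the shared input state and contributes one tendency to the sum, so replace() and remove() on the outer list keep their meaning.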
Plugin contract¶
A third-party scheme is a single-file drop-in. See
writing_a_physics_scheme.md for a
walkthrough; the contract is:
# my_package/my_scheme.py
from typing import ClassVar
from flax import nnx
from jcm.physics.physics_term import PhysicsTerm

# MySchemeParameters and MyData are the scheme's own containers,
# defined elsewhere in the plugin package.
class MyScheme(PhysicsTerm):
    name: ClassVar[str] = "my_scheme"
    category: ClassVar[str] = "convection"
    requires: ClassVar[tuple[str, ...]] = ("pressure_full",)
    provides: ClassVar[tuple[str, ...]] = ("my_scheme",)

    def __init__(self, params=None):
        # Compare against None explicitly: truthiness of array-valued
        # parameters is ambiguous.
        if params is None:
            params = MySchemeParameters.default()
        self.params = nnx.Param(params)

    def __call__(self, state, diagnostics, forcing, terrain):
        ...
        return tendency, {**diagnostics, "my_scheme": MyData(...)}

# user code
from jcm.physics.echam.echam_terms import echam_physics
from my_package.my_scheme import MyScheme

physics = echam_physics().replace("convection", MyScheme())
ComposablePhysics validates the new term list at construction time:
every requires has an upstream provides; no key is provided
twice; required_tracers() are seeded into the initial state by
Model. No edits to the model orchestrator, no edits to the
package’s parameters aggregator, no edits to a monolithic data
struct.
Directory layout¶
Physics is organised by physical process, with files named after the scheme rather than the model they were ported from. New ports of the same scheme drop in beside the existing one without an extra per-model subfolder:
jcm/physics/
├── physics_term.py # PhysicsTerm base class
├── composable_physics.py # ComposablePhysics container
├── radiation/
│ ├── grey_two_stream/ # fast grey two-stream package
│ ├── rrtmgp.py # RRTMGP wrapper
│ ├── nn_emulator.py # NN radiation emulator
│ ├── speedy_shortwave.py
│ └── speedy_longwave.py
├── convection/
│ ├── tiedtke_nordeng/ # Tiedtke-Nordeng mass flux
│ └── speedy_convection.py
├── clouds/
│ ├── sundqvist.py # Sundqvist diagnostic cloud fraction
│ ├── echam_1m.py # ECHAM 1-moment microphysics
│ ├── speedy_humidity.py
│ └── speedy_condensation.py
├── vertical_diffusion/
│ ├── tte_tke/ # TTE-TKE closure
│ └── speedy_vdiff.py
├── gravity_waves/{hines,sso,simple}/
├── aerosol/macv2_sp.py # Stevens MACv2-SP simple plumes
├── chemistry/simple_chemistry.py
├── surface/ # SPEEDY and ECHAM surface schemes
├── speedy/ # SPEEDY infrastructure (params, coords)
└── echam/ # ECHAM infrastructure (params, coords)
Model-specific infrastructure (parameter containers, coordinate
caches, data structs) lives under speedy/ and echam/. Everything
else is named after the scheme so an ECHAM port and a CAM port of
the same parameterisation sit side-by-side.