Composable physics

Overview

A JCM physics package is a ComposablePhysics container holding an ordered list of PhysicsTerm instances. Each term is a self-contained parameterisation that reads the prognostic state and a shared diagnostics dict, returns its tendency, and writes its outputs back into the dict for downstream terms to consume.

Model.compute_physics_step
   └─ ComposablePhysics.compute_tendencies(state, forcing, terrain, prev_physics_data=prev_carry)
        diagnostics = {**prev_carry}    ← cross-step physics carry seed
        for term in terms:
            tend, diagnostics = term(state, diagnostics, forcing, terrain)
            tendencies += tend
        return tendencies, diagnostics  ← diagnostics is the next step's carry
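The loop above can be sketched in plain Python with toy terms. The functions, keys, and scalar "tendencies" below are hypothetical stand-ins for PhysicsTerm instances and PhysicsTendency. The sketch shows only how the diagnostics dict is seeded from the previous step's carry, threaded through the terms, and returned as the next carry.

```python
from typing import Any, Callable, Optional

# Hypothetical stand-in for a PhysicsTerm: (state, diagnostics) -> (tendency, diagnostics)
Term = Callable[[float, dict], tuple]


def radiation(state: float, diag: dict[str, Any]) -> tuple[float, dict[str, Any]]:
    heating = -0.1 * state
    # Return a new dict; the incoming one is never mutated.
    return heating, {**diag, "radiation": heating}


def convection(state: float, diag: dict[str, Any]) -> tuple[float, dict[str, Any]]:
    # Reads an upstream diagnostic plus an internal key carried from the previous step.
    calls = diag.get("_n_calls", 0) + 1
    tend = 0.5 * diag["radiation"]
    return tend, {**diag, "convection": tend, "_n_calls": calls}


def compute_tendencies(terms: list[Term], state: float,
                       prev_carry: Optional[dict[str, Any]] = None):
    diagnostics = dict(prev_carry) if prev_carry else {}  # seed from previous step
    total = 0.0
    for term in terms:
        tend, diagnostics = term(state, diagnostics)
        total += tend  # every term sees the same input `state`
    return total, diagnostics  # diagnostics becomes the next step's carry


terms = [radiation, convection]
tend1, carry = compute_tendencies(terms, 10.0)
tend2, carry = compute_tendencies(terms, 10.0, prev_carry=carry)
```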

ComposablePhysics is a flax.nnx.Module, every PhysicsTerm is also a flax.nnx.Module, and per-term parameters are stored as nnx.Param. Composition is differentiable end to end, either with nnx.grad applied directly to the container or with jax.grad over the state flattened by nnx.split.

The container is final — it is not subclassed. Composition happens at construction time via +, replace(category, term), and remove(category). _validate_ordering runs each time a composition is built.

Components

PhysicsTerm (jcm/physics/physics_term.py)

Base class for one parameterisation. Each subclass declares four ClassVars of static metadata and implements two methods.

class PhysicsTerm(nnx.Module):
    name:     ClassVar[str]                  # unique identifier
    category: ClassVar[str]                  # "radiation", "convection", …
    requires: ClassVar[tuple[str, ...]] = () # diagnostics keys read from upstream
    provides: ClassVar[tuple[str, ...]] = () # diagnostics keys written

    def cache_coords(self, coords) -> None:
        """Populate coordinate-dependent caches as nnx.Variable. In-place.
        Called once at Model construction time, outside any traced region."""

    def __call__(
        self,
        state: PhysicsState,
        diagnostics: dict[str, Any],
        forcing: ForcingData,
        terrain: TerrainData,
    ) -> tuple[PhysicsTendency, dict[str, Any]]:
        """Return (tendency, updated_diagnostics)."""

Storage convention:

  • nnx.Param — tunable parameters. nnx.grad differentiates through them.

  • nnx.Variable — coordinate caches and other traced-but-frozen state. Read inside __call__ as live pytree leaves; not differentiated by default.

  • Plain Python attributes — static configuration (flags, integer knobs that should not change post-construction).

ComposablePhysics (jcm/physics/composable_physics.py)

The single container class. Iterates terms in order, threads the diagnostics dict through, sums tendencies, and exposes the composition operators.

class ComposablePhysics(nnx.Module, Physics):
    def __init__(self, terms: list[PhysicsTerm], *, vectorize_columns: bool = False):
        self.terms = terms
        self._validate_ordering()
        self.vectorize_columns = vectorize_columns

    def cache_coords(self, coords):
        for term in self.terms:
            term.cache_coords(coords)

    def compute_tendencies(self, state, forcing, terrain, *, prev_physics_data=None):
        diagnostics = dict(prev_physics_data) if prev_physics_data else {}
        tendencies = PhysicsTendency.zeros(state.temperature.shape)
        for term in self.terms:
            tend, diagnostics = term(state, diagnostics, forcing, terrain)
            tendencies += tend
        return tendencies, diagnostics

    # Composition operators (each returns a fresh container, runs validation):
    def __add__(self, other):           ...   # concatenate term lists
    def replace(self, category, new):   ...   # swap all terms of category
    def remove(self, category):         ...   # drop all terms of category

vectorize_columns=True (used by echam_physics()) wraps each term in jax.vmap over the horizontal axes so single-column algorithms can be written as if they acted on one column. The outer container handles the reshape/un-reshape; terms see (nlev,) or (nlev, ncols) as appropriate.
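The column-vectorization idea can be sketched with jax.vmap alone. The helper name and the (nlev, nlat, nlon) layout are assumptions for illustration, not the container's actual code. A single-column function written for (nlev,) profiles is lifted over the flattened horizontal columns, and the result is reshaped back.

```python
import jax
import jax.numpy as jnp


def column_damping(temperature_col: jnp.ndarray) -> jnp.ndarray:
    """Single-column toy algorithm written for a (nlev,) profile."""
    # Relax each level toward the column mean.
    return -0.1 * (temperature_col - jnp.mean(temperature_col))


def vectorize_over_columns(column_fn):
    """Hypothetical helper: lift a (nlev,) -> (nlev,) function to (nlev, nlat, nlon)."""

    def wrapped(field: jnp.ndarray) -> jnp.ndarray:
        nlev, nlat, nlon = field.shape
        columns = field.reshape(nlev, nlat * nlon)                # (nlev, ncols)
        out = jax.vmap(column_fn, in_axes=1, out_axes=1)(columns)  # one call per column
        return out.reshape(nlev, nlat, nlon)                      # undo the reshape

    return wrapped


temperature = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
tend = vectorize_over_columns(column_damping)(temperature)
```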

Pre-built factories

Each physics package is a factory that returns a ComposablePhysics with a validated ordering. The factories live next to their term files:

  • jcm/physics/speedy/speedy_terms.py::speedy_physics()

  • jcm/physics/echam/echam_terms.py::echam_physics()

  • jcm/physics/held_suarez/held_suarez_physics.py::held_suarez_physics()

Held-Suarez stays a hand-written package — composable machinery is optional, not forced.

The diagnostics dict

Terms communicate through dict[str, Any] (typically dict[str, jnp.ndarray] or typed sub-structs of arrays). The dict flows forward through the term list: each term reads the keys it needs and returns a new dict containing any keys it produces. Terms never mutate the dict in place.

Two key conventions:

  • Public keys (no leading underscore): exposed as user-facing diagnostic output. They flatten directly into model.run().to_xarray() via Physics.data_struct_to_dict. Examples: radiation, convection, cloud_fraction.

  • Internal keys (leading underscore): cross-step or transient state used by terms internally; filtered out of user-facing output. Examples: _radiation (sub-cycle cache), _date (jax_datetime struct + dt_seconds carried for terms that need it), _dt_seconds (model timestep injected by ComposablePhysics so terms read a single source of truth instead of plumbing it through parameters or date).
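The naming convention amounts to a prefix test. This hypothetical helper shows only the filtering step on a flat dict; the real flattening goes through Physics.data_struct_to_dict, whose behaviour is not reproduced here.

```python
from typing import Any


def public_diagnostics(diagnostics: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical filter: keep only public (non-underscore) keys for output."""
    return {k: v for k, v in diagnostics.items() if not k.startswith("_")}


diagnostics = {
    "radiation": {"net_flux": 240.0},
    "cloud_fraction": 0.4,
    "_radiation": {"cached_heating": -1.2},  # internal sub-cycle cache
    "_dt_seconds": 1800.0,                    # internal timestep
}
output = public_diagnostics(diagnostics)
```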

The diagnostics dict that comes out of compute_tendencies is the cross-step physics carry — operator-split integration threads it as a JAX pytree through lax.scan (see operator_split_physics.md).
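Because the carry is an ordinary dict of arrays, it can be threaded through jax.lax.scan directly. The step function and keys in this toy sketch are hypothetical. It also illustrates a constraint that scan imposes: the carry must have the same keys, shapes, and dtypes on every iteration, so the initial carry is seeded with every key the step produces instead of starting from an empty dict.

```python
import jax
import jax.numpy as jnp


def physics_step(carry, temperature):
    """Toy step: consumes the previous carry and returns the next one."""
    heating = -0.1 * temperature + 0.5 * carry["_heating"]
    next_carry = {"_heating": heating, "radiation": 2.0 * heating}
    return next_carry, heating  # (carry, per-step output)


# Seed every key the step produces; scan rejects a carry whose structure changes.
init_carry = {"_heating": jnp.zeros(3), "radiation": jnp.zeros(3)}
temperatures = jnp.ones((4, 3))  # 4 steps, 3 grid points
final_carry, heating_history = jax.lax.scan(physics_step, init_carry, temperatures)
```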

Validation

ComposablePhysics(terms=[...]) runs _validate_ordering at construction time:

  1. Single-writer per key. Two terms cannot list the same key in their provides. Catches misconfigurations where one term would silently overwrite another’s output.

  2. All requires resolved upstream. Each term’s requires must appear in the union of upstream terms’ provides. Catches ordering bugs at Model construction rather than at the first model.run().

Empty requires / provides are fine — terms that read only the prognostic state and write only tendencies (e.g. a Rayleigh damping term) declare nothing.
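Both rules can be checked in one forward pass over the term list. This is an illustrative re-implementation, not the code of _validate_ordering; the term classes are hypothetical and carry only the ClassVar metadata.

```python
from typing import ClassVar


class Term:
    name: ClassVar[str] = "term"
    requires: ClassVar[tuple[str, ...]] = ()
    provides: ClassVar[tuple[str, ...]] = ()


def validate_ordering(terms: list[Term]) -> None:
    """Raise ValueError on a duplicate writer or an unresolved requirement."""
    writer: dict[str, str] = {}  # key -> name of the term that provides it
    for term in terms:
        # Rule 2: every required key must already be provided upstream.
        missing = [k for k in term.requires if k not in writer]
        if missing:
            raise ValueError(f"{term.name} requires {missing}, not provided upstream")
        # Rule 1: a single writer per key.
        for key in term.provides:
            if key in writer:
                raise ValueError(f"{key!r} provided by both {writer[key]} and {term.name}")
            writer[key] = term.name


class Radiation(Term):
    name = "radiation"
    provides = ("radiation",)


class Convection(Term):
    name = "convection"
    requires = ("radiation",)
    provides = ("convection",)


validate_ordering([Radiation(), Convection()])  # valid order: no exception
```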

required_tracers() is a separate hook each term can implement to declare non-default tracers (anything beyond specific_humidity). Model collects specs from every term at build time and seeds the initial state’s tracer dict.

Composition operators

Three operators on the container, each returning a fresh ComposablePhysics after re-running validation:

physics_a + physics_b                 # concatenate term lists
physics.replace("convection", new)    # swap all 'convection' terms for `new`
physics.remove("clouds")              # drop every term whose category is 'clouds'

replace(category, new_term) collapses all existing terms in that category into the single replacement, inserted at the position of the first one — so you can swap an entire process category in one call:

# SPEEDY with RRTMGP radiation
physics = speedy_physics().replace("radiation", RRTMGPRadiation())

# ECHAM with a custom convection scheme
physics = echam_physics().replace("convection", MyConvection())

# Strip clouds entirely, then add Rayleigh damping
physics = echam_physics().remove("clouds") + RayleighDamping()
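The list semantics of replace and remove can be sketched over (name, category) pairs. These functions are hypothetical stand-ins on plain lists; the real operators also build a fresh ComposablePhysics and re-run validation. The sketch returns the list unchanged when no term has the category; how the real operator handles that case is not specified here.

```python
from typing import NamedTuple


class T(NamedTuple):  # hypothetical stand-in for a PhysicsTerm
    name: str
    category: str


def remove(terms: list[T], category: str) -> list[T]:
    """Drop every term whose category matches."""
    return [t for t in terms if t.category != category]


def replace(terms: list[T], category: str, new: T) -> list[T]:
    """Collapse every term of `category` into `new`, placed where the first one was."""
    out: list[T] = []
    inserted = False
    for t in terms:
        if t.category != category:
            out.append(t)
        elif not inserted:
            out.append(new)
            inserted = True
    return out


terms = [
    T("shortwave", "radiation"),
    T("convection", "convection"),
    T("longwave", "radiation"),
    T("vdiff", "vertical_diffusion"),
]
swapped = replace(terms, "radiation", T("rrtmgp", "radiation"))
stripped = remove(terms, "radiation")
```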

When users compose from scratch (term_a + term_b + term_c), they accept responsibility for ordering correctness — _validate_ordering catches dependency bugs but cannot enforce semantic order beyond the requires / provides graph.

Cross-package compatibility

SPEEDY and ECHAM terms historically used different intermediate data representations. The diagnostics dict is the bridge:

  • SPEEDY wrappers store SPEEDY sub-structs under keys like "_shortwave_rad", "_convection".

  • ECHAM wrappers store ECHAM sub-structs under keys like "radiation", "convection".

Mixing terms across packages works cleanly when the replacement covers an entire process category — the new term reads only public diagnostics and prognostic state, and produces what downstream terms need from scratch. Sharing internal sub-structs across packages requires either a translation term or independent recomputation; in practice the entire-category replacement pattern is what most use cases want.

Differentiability

End-to-end gradient flow is preserved through the diagnostics dict (all values are JAX arrays or pytrees of arrays), per-term nnx.Param storage, and the composition operators (which produce fresh ComposablePhysics modules with the same nnx graph properties).
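The dict itself is no obstacle to autodiff, as a pure-JAX toy shows. The two hypothetical terms below communicate only through a diagnostics dict, and jax.grad of the summed tendency reaches the upstream parameter `a` both directly and through the downstream term's read of "a_out".

```python
import jax
import jax.numpy as jnp


def term_a(params, state, diag):
    out = params["a"] * state
    return out, {**diag, "a_out": out}


def term_b(params, state, diag):
    # Depends on term_a only through the diagnostics dict.
    out = params["b"] * diag["a_out"]
    return out, {**diag, "b_out": out}


def total_tendency(params, state):
    diag, total = {}, jnp.zeros_like(state)
    for term in (term_a, term_b):
        tend, diag = term(params, state, diag)
        total = total + tend
    return jnp.sum(total)


params = {"a": jnp.array(2.0), "b": jnp.array(3.0)}
state = jnp.array([1.0, 2.0])
grads = jax.grad(total_tendency)(params, state)
# loss = (a + a*b) * sum(state), so d/da = (1 + b) * 3 and d/db = a * 3
```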

Pattern 1 — direct nnx.grad (most ergonomic):

physics = speedy_physics()
physics.cache_coords(coords)   # run ONCE, outside the traced region

def loss_fn(physics):
    model = Model(coords=coords, terrain=terrain, physics=physics)
    return compute_loss(model.run(total_time=...))

grads = nnx.grad(loss_fn)(physics)
# grads.terms[i].params is the gradient w.r.t. term i's parameters

Pattern 2 — pure JAX via split/merge (interop with existing jax.grad code):

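# Split Params from the rest (coordinate caches and other Variables) so that
# jax.grad differentiates the same leaves that nnx.grad does by default.
graphdef, params, rest = nnx.split(physics, nnx.Param, ...)

def loss_fn(params):
    physics = nnx.merge(graphdef, params, rest)
    model = Model(coords=coords, terrain=terrain, physics=physics)
    return compute_loss(model.run(total_time=...))

grads = jax.grad(loss_fn)(params)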

Pattern 3 — per-scheme optimisation:

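# Optimise only convection parameters, freeze the rest.
# nnx.grad has no `wrt` argument; restrict it with nnx.DiffState instead.
# Term paths look like ("terms", <index>, ...), so select indices by category.
conv_idx = {i for i, t in enumerate(physics.terms) if t.category == "convection"}

def in_convection(path, value):
    return len(path) > 1 and path[0] == "terms" and path[1] in conv_idx

convection_filter = nnx.All(nnx.Param, in_convection)
grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, convection_filter))(physics)

# Or address a single term's params directly
opt = optax.adam(1e-3)
opt_state = opt.init(physics.terms[3].params.value)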

Path-based filtering is the largest ergonomic improvement over a monolithic Parameters struct: optimising one scheme in isolation needs no surgery.

cache_coords lifecycle

cache_coords(coords) is called once at Model.__init__ time, before any traced region:

self.physics = physics
self.physics.cache_coords(self.coords)

Inside cache_coords, a term stores precomputed data (sigma layer midpoints, basis functions, lookup tables) as nnx.Variable attributes. During model.run() those variables are read as traced values but never written.

The default behaviour of nnx.grad is to ignore nnx.Variable and only differentiate nnx.Param — so coordinate caches do not appear in the gradient. Callers that do want to differentiate w.r.t. the vertical-level placement (rare, but supported for learnable level schemes) can broaden the path filter.

Process-parallel coupling

Within a ComposablePhysics step, terms run process-parallel: every term reads the same input prognostic state, and tendencies are summed. Terms may read each other’s diagnostic outputs (through the dict), but they do not see each other’s tendency contribution applied to the prognostic state until the next dynamics step.

This is order-independent at the prognostic-state level: A + B + C and B + A + C produce the same total tendency from the same state (up to floating-point summation order, and provided both orderings pass validation). It differs from ECHAM6's sequential coupling, where each scheme reads the state with prior schemes' tendencies already added (via the tte += ... pattern on shared accumulators in mo_scan_buffer.f90). Sequential coupling is more accurate for tightly coupled process pairs but gives up the order-independence that makes replace() / remove() semantically clean.

For terms that genuinely need sequential coupling (e.g. a tightly coupled CLUBB+MG2 pair as in E3SM), the recommended pattern is a process group term that runs the inner sub-cycle internally and presents a single tendency externally — keeping the outer container process-parallel.
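The difference can be seen with two hypothetical toy processes acting on a scalar state. The parallel sum is identical in either order. Applying each tendency before the next process runs makes the result order-dependent. A process group term runs the sequential loop internally and reports the net change as one tendency.

```python
def heating(t: float) -> float:
    return 0.5 * (300.0 - t)  # toy relaxation toward 300


def cooling(t: float) -> float:
    return -0.1 * t  # toy state-proportional loss


def parallel(terms, t: float, dt: float) -> float:
    """Process-parallel: every term sees the same input state."""
    return t + dt * sum(term(t) for term in terms)


def sequential(terms, t: float, dt: float) -> float:
    """Sequential: each term sees the state updated by the ones before it."""
    for term in terms:
        t = t + dt * term(t)
    return t


def process_group(terms, t: float, dt: float) -> float:
    """Group term: sequential inside, a single net tendency outside."""
    return (sequential(terms, t, dt) - t) / dt


t0, dt = 280.0, 1.0
```

With these values, both parallel orderings give 262, while the sequential orderings give 261 and 276; the group term wrapping [heating, cooling] reports a net tendency of -19.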

Plugin contract

A third-party scheme is a single-file drop-in. See writing_a_physics_scheme.md for a walkthrough; the contract is:

# my_package/my_scheme.py
from typing import ClassVar
from flax import nnx
from jcm.physics.physics_term import PhysicsTerm

class MyScheme(PhysicsTerm):
    name:     ClassVar[str] = "my_scheme"
    category: ClassVar[str] = "convection"
    requires: ClassVar[tuple[str, ...]] = ("pressure_full",)
    provides: ClassVar[tuple[str, ...]] = ("my_scheme",)

    def __init__(self, params=None):
        self.params = nnx.Param(params or MySchemeParameters.default())

    def __call__(self, state, diagnostics, forcing, terrain):
        ...
        return tendency, {**diagnostics, "my_scheme": MyData(...)}

# user code (a separate file)
from jcm.physics.echam.echam_terms import echam_physics
from my_package.my_scheme import MyScheme

physics = echam_physics().replace("convection", MyScheme())

ComposablePhysics validates the new term list at construction time: every requires has an upstream provides; no key is provided twice; required_tracers() are seeded into the initial state by Model. No edits to the model orchestrator, no edits to the package’s parameters aggregator, no edits to a monolithic data struct.

Directory layout

Physics is organised by physical process, with files named after the scheme rather than the model they were ported from. New ports of the same scheme drop in beside the existing one without an extra per-model subfolder:

jcm/physics/
├── physics_term.py             # PhysicsTerm base class
├── composable_physics.py       # ComposablePhysics container
├── radiation/
│   ├── grey_two_stream/        # fast grey two-stream package
│   ├── rrtmgp.py               # RRTMGP wrapper
│   ├── nn_emulator.py          # NN radiation emulator
│   ├── speedy_shortwave.py
│   └── speedy_longwave.py
├── convection/
│   ├── tiedtke_nordeng/        # Tiedtke-Nordeng mass flux
│   └── speedy_convection.py
├── clouds/
│   ├── sundqvist.py            # Sundqvist diagnostic cloud fraction
│   ├── echam_1m.py             # ECHAM 1-moment microphysics
│   ├── speedy_humidity.py
│   └── speedy_condensation.py
├── vertical_diffusion/
│   ├── tte_tke/                # TTE-TKE closure
│   └── speedy_vdiff.py
├── gravity_waves/{hines,sso,simple}/
├── aerosol/macv2_sp.py         # Stevens MACv2-SP simple plumes
├── chemistry/simple_chemistry.py
├── surface/                    # SPEEDY and ECHAM surface schemes
├── speedy/                     # SPEEDY infrastructure (params, coords)
└── echam/                      # ECHAM infrastructure (params, coords)

Model-specific infrastructure (parameter containers, coordinate caches, data structs) lives under speedy/ and echam/. Everything else is named after the scheme so an ECHAM port and a CAM port of the same parameterisation sit side-by-side.