Writing a Physics Scheme¶
How to add your own parameterisation to JAX-GCM. The composable physics
infrastructure is meant to make this a one-file drop-in: you write a
PhysicsTerm subclass, declare what diagnostics you read and write, and
plug it into an existing factory. No edits to the model orchestrator, no
edits to the package’s Parameters aggregator, no edits to a monolithic
PhysicsData struct.
This document walks through the contract and shows a complete minimal
example. For the why behind the design, see
composable_physics.md.
The contract¶
A scheme is a PhysicsTerm subclass. The base class lives at
jcm/physics/physics_term.py and is just a thin
flax.nnx.Module wrapper:

    from typing import ClassVar

    from flax import nnx

    from jcm.physics.physics_term import PhysicsTerm

    class MyScheme(PhysicsTerm):
        name: ClassVar[str] = "my_scheme"
        category: ClassVar[str] = "convection"
        requires: ClassVar[tuple[str, ...]] = ("pressure_full",)
        provides: ClassVar[tuple[str, ...]] = ("my_scheme",)

        def __init__(self, params: MyParameters | None = None):
            self.params = nnx.Param(params or MyParameters.default())

        def __call__(self, state, diagnostics, forcing, terrain):
            ...
            return tendency, {**diagnostics, "my_scheme": MyData(...)}

The four ClassVars drive composition-time validation. The two methods
do the real work.
Static metadata¶
| Field | Meaning |
|---|---|
| name | Human-readable identifier. Shown in error messages and used as a fallback diagnostic key. Keep it unique within a composition. |
| category | The slot this term occupies (…). |
| requires | Diagnostic keys that some upstream term must already have written. Validated at construction time; see Validation below. |
| provides | Diagnostic keys this term writes into the dict it returns. Validation flags two terms claiming the same key. |
The keys in requires / provides are the public-facing keys in the
diagnostics dict (pressure_full, air_density, convection,
radiation, …). Anything starting with _ is internal state (radiation
caching, the date struct) and should not appear in either field.
__init__¶
Hold tunable parameters as nnx.Param. That is what makes them
gradient-trainable through flax.nnx:

    def __init__(self, params: MyParameters | None = None):
        self.params = nnx.Param(params or MyParameters.default())

MyParameters is your own scheme’s parameters struct — typically a
@tree_math.struct dataclass. It is not the package-level Parameters
aggregator; each term owns only the parameters its scheme cares about.
If the term needs precomputed coordinate-dependent caches (e.g.
sigma layers, basis functions), put them in nnx.Variable attributes
inside an overridden cache_coords(coords). That hook is called once
at Model construction time and is the right place for non-trainable
arrays:

    def cache_coords(self, coords):
        self.sigma = nnx.Variable(jnp.asarray(coords.vertical.layer_centers))

__call__¶

    def __call__(self, state, diagnostics, forcing, terrain):
        return tendency, updated_diagnostics

Inputs:
- state: a PhysicsState (winds, temperature, humidity, geopotential, surface pressure, tracers).
- diagnostics: a dict[str, jnp.ndarray | pytree] containing everything upstream terms have written. Read by key; never mutate it, and always return a new dict.
- forcing: a ForcingData already sliced to the current step (Model._get_step_fn_factory calls ForcingData.select(date)).
- terrain: a TerrainData with orog, fmask, gsl, vlt, …
Outputs:
- A PhysicsTendency (winds, temperature, humidity, tracers). Most schemes use PhysicsTendency.zero(...) as a base and .copy(...) to fill in just the fields they touch.
- A new diagnostics dict, typically {**diagnostics, "my_key": data}.
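The "return a new dict" rule can be illustrated with plain Python. This is a minimal sketch with placeholder values; the helper add_key and the key names are hypothetical and not part of JAX-GCM.

```python
def add_key(diagnostics: dict, key: str, data) -> dict:
    """Return a new dict with one extra entry; the input dict is left untouched."""
    return {**diagnostics, key: data}


# Placeholder upstream diagnostics (values are illustrative only).
upstream = {"pressure_full": [1000.0, 850.0, 500.0]}
downstream = add_key(upstream, "my_key", 42.0)

print(sorted(upstream))    # keys visible to terms that ran earlier
print(sorted(downstream))  # keys visible to terms that run later
```

The unpacking makes a shallow copy, so existing entries are shared rather than duplicated; only the mapping itself is new.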
The function must be pure: no side effects, no Python if/else on
JAX-traced values, no dict mutation. Use jax.lax.cond /
jnp.where for conditional logic.
Tracers¶
If your scheme reads or writes a non-default tracer (anything beyond
specific_humidity), declare it via required_tracers(). Model
collects specs from every term at build time and seeds the initial
state’s tracer dict accordingly:

    from jcm.physics.physics_term import TracerSpec

    class MyScheme(PhysicsTerm):
        @classmethod
        def required_tracers(cls):
            return (
                TracerSpec(name="qc", units="kg/kg", initial_value=0.0),
                TracerSpec(
                    name="qnc", units="1/kg", initial_value=0.0,
                    nondimensionalize=False,
                ),
            )

nondimensionalize=False is for tracers that aren’t mass mixing ratios
(e.g. number concentrations).
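To make the collection step concrete, here is a stand-alone sketch of how specs from several terms might be gathered into an initial tracer dict. Everything in it is an assumption for illustration: TracerSpecSketch only mirrors the fields shown above (with nondimensionalize assumed to default to True), the two term classes are invented, and the "first spec per name wins" policy is a guess rather than the documented behaviour of Model.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TracerSpecSketch:
    # Hypothetical stand-in for TracerSpec; the defaults are assumed.
    name: str
    units: str
    initial_value: float = 0.0
    nondimensionalize: bool = True


class CloudTerm:  # invented term
    @classmethod
    def required_tracers(cls):
        return (TracerSpecSketch("qc", "kg/kg"),)


class DropletTerm:  # invented term
    @classmethod
    def required_tracers(cls):
        return (
            TracerSpecSketch("qc", "kg/kg"),
            TracerSpecSketch("qnc", "1/kg", nondimensionalize=False),
        )


def collect_tracers(terms):
    """Gather specs from every term, keeping the first spec seen per name."""
    specs = {}
    for term in terms:
        for spec in term.required_tracers():
            specs.setdefault(spec.name, spec)
    return specs


specs = collect_tracers([CloudTerm, DropletTerm])
initial_tracers = {name: spec.initial_value for name, spec in specs.items()}
```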
Column vectorisation¶
If the scheme is a single-column algorithm and the host package uses
vectorize_columns=True (as echam_physics() does), ComposablePhysics
wraps each term in a jax.vmap over the horizontal axes. Inside the
term you can write the calculation as if it acted on one column.
Most ECHAM terms are written this way; MoistAirColumnState is a
counter-example that runs on the full 3-D grid because it is upstream
of the per-column vmap.
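The wrapping described above can be approximated with two nested jax.vmap calls. This is a sketch under stated assumptions: the array layout is taken to be (levels, lon, lat), and column_anomaly is an invented single-column function; how ComposablePhysics actually builds the vmap is not shown in this document.

```python
import jax
import jax.numpy as jnp


def column_anomaly(temperature_column):
    # Single-column algorithm: receives a 1-D array of shape (levels,).
    return temperature_column - jnp.mean(temperature_column)


# The inner vmap maps over axis 1 of a (levels, lon) slice; the outer vmap
# maps over axis 2 of the full (levels, lon, lat) field.
over_lon = jax.vmap(column_anomaly, in_axes=1, out_axes=1)
over_lon_lat = jax.vmap(over_lon, in_axes=2, out_axes=2)

temperature = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
result = over_lon_lat(temperature)

# Same result written directly against the full 3-D grid.
reference = temperature - temperature.mean(axis=0, keepdims=True)
```

The single-column function never sees the horizontal axes, which is what lets a term be written as if it acted on one column.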
Validation¶
ComposablePhysics(terms=[...]) runs _validate_ordering at
construction time:
- Single-writer per key. If two terms list the same key in their provides, the second would overwrite the first, which is almost always a misconfiguration. Validation raises with the offending categories.
- All requires resolved upstream. Each term’s requires must appear in the union of upstream terms’ provides. Errors point at the first unsatisfied dependency.
Most plugin authors hit this when they forget that the moist-air
diagnostics (pressure_full, pressure_half, air_density,
layer_thickness, …) are produced by the MoistAirColumnState term —
that has to be the first term in any column-physics composition.
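The two checks can be sketched in a few lines of plain Python. This is not the library's _validate_ordering; TermMeta, validate_ordering, the error wording, and the "column_state" category are hypothetical stand-ins that only mimic the rules described above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TermMeta:
    # Stand-in for the static metadata of a term.
    name: str
    category: str
    requires: tuple = ()
    provides: tuple = ()


def validate_ordering(terms):
    """Raise ValueError on an unsatisfied requirement or a duplicate writer."""
    writers = {}  # diagnostic key -> term that provides it
    for term in terms:
        for key in term.requires:
            if key not in writers:
                raise ValueError(
                    f"{term.name!r} requires {key!r}, "
                    "but no upstream term provides it"
                )
        for key in term.provides:
            if key in writers:
                raise ValueError(
                    f"{key!r} is provided by both "
                    f"{writers[key].category!r} and {term.category!r}"
                )
            writers[key] = term


moist_air = TermMeta(
    "moist_air", "column_state", provides=("pressure_full", "air_density")
)
my_scheme = TermMeta(
    "my_scheme", "convection",
    requires=("pressure_full",), provides=("my_scheme",),
)

validate_ordering([moist_air, my_scheme])  # correct order: no error

try:
    validate_ordering([my_scheme, moist_air])  # provider listed too late
except ValueError as err:
    message = str(err)
```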
Minimal complete example¶
A short RayleighDamping term that adds a height-dependent linear
drag on horizontal winds, with a tunable timescale. Single file,
no other repo edits required.

    # my_package/rayleigh_damping.py
    from typing import ClassVar

    import jax.numpy as jnp
    import tree_math
    from flax import nnx

    from jcm.physics.physics_term import PhysicsTerm
    from jcm.physics_interface import PhysicsTendency

    @tree_math.struct
    class RayleighDampingParameters:
        timescale_seconds: float
        sigma_top: float

        @classmethod
        def default(cls):
            return cls(timescale_seconds=1.0 * 86400.0, sigma_top=0.1)

    class RayleighDamping(PhysicsTerm):
        name: ClassVar[str] = "rayleigh_damping"
        category: ClassVar[str] = "gravity_waves"
        requires: ClassVar[tuple[str, ...]] = ()
        provides: ClassVar[tuple[str, ...]] = ()

        def __init__(self, params: RayleighDampingParameters | None = None):
            self.params = nnx.Param(
                params or RayleighDampingParameters.default(),
            )

        def cache_coords(self, coords):
            # Non-trainable sigma levels, cached once at Model construction.
            sigma = jnp.asarray(coords.vertical.layer_centers)
            self.sigma = nnx.Variable(sigma)

        def __call__(self, state, diagnostics, forcing, terrain):
            p = self.params.value
            # Linear ramp: 0 where sigma >= sigma_top, rising towards
            # 1/timescale as sigma -> 0 (the model top).
            weight = jnp.clip(
                (p.sigma_top - self.sigma.value) / p.sigma_top, 0.0, 1.0,
            )
            # Add two trailing axes so the per-level rate broadcasts over
            # the horizontal dimensions (level-first layout).
            rate = (weight / p.timescale_seconds)[:, None, None]
            u_t = -rate * state.u_wind
            v_t = -rate * state.v_wind
            tend = PhysicsTendency.zero(state).copy(u_wind=u_t, v_wind=v_t)
            # No diagnostics are written, so the dict is passed through as-is.
            return tend, diagnostics

The requires / provides are empty because the term reads only
the prognostic state and writes only tendencies — it does not consume
or produce any diagnostic keys.
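For a feel of the numbers, the ramp can be evaluated by hand in plain Python using the default parameters from the example. The sigma levels below are illustrative and not taken from any JAX-GCM grid.

```python
SECONDS_PER_DAY = 86400.0
timescale_seconds = 1.0 * SECONDS_PER_DAY  # default from the example
sigma_top = 0.1                            # default from the example


def damping_rate(sigma):
    """Scalar version of the ramp used in RayleighDamping.__call__."""
    weight = min(max((sigma_top - sigma) / sigma_top, 0.0), 1.0)
    return weight / timescale_seconds


sigma_levels = [0.025, 0.075, 0.15, 0.5, 0.95]  # illustrative layer centres
for sigma in sigma_levels:
    rate = damping_rate(sigma)
    if rate == 0.0:
        efold = "no damping"
    else:
        efold = f"{1.0 / rate / SECONDS_PER_DAY:.2f} days"
    print(f"sigma={sigma:.3f}  e-folding time: {efold}")
```

With these values only the two levels above sigma_top are damped, with e-folding times of about 1.33 and 4 days; every level at or below sigma_top gets a zero rate.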
Composing your term in¶
Once the term exists, the standard composition operators on
ComposablePhysics let you splice it in without editing anyone else’s
code:

    from jcm.physics.echam.echam_terms import echam_physics
    from my_package.rayleigh_damping import RayleighDamping

    # Append: add the term at the end of the default ECHAM stack.
    physics = echam_physics() + RayleighDamping()

    # Replace: swap out the existing gravity_waves terms.
    physics = echam_physics().replace("gravity_waves", RayleighDamping())

    # Remove: drop a category, then build whatever you want on top.
    physics = echam_physics().remove("clouds") + RayleighDamping()

replace(category, new_term) collapses all existing terms in that
category into the single replacement, inserted at the position of the
first one. remove(category) drops every term with that category.
+ concatenates two compositions. Each operation runs the validation
checks again on the resulting term list.
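The semantics of the three operators can be mimicked on a plain tuple of terms. The sketch below is a toy model, not ComposablePhysics itself: Term and Composition are invented names, validation is omitted, and leaving the stack unchanged when replace finds no matching category is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Term:
    name: str
    category: str


@dataclass(frozen=True)
class Composition:
    terms: tuple

    def __add__(self, other):
        # Accept either another composition or a single term.
        extra = other.terms if isinstance(other, Composition) else (other,)
        return Composition(self.terms + extra)

    def remove(self, category):
        # Drop every term with the given category.
        kept = tuple(t for t in self.terms if t.category != category)
        return Composition(kept)

    def replace(self, category, new_term):
        # Collapse all matching terms into new_term at the first match.
        out, inserted = [], False
        for term in self.terms:
            if term.category != category:
                out.append(term)
            elif not inserted:
                out.append(new_term)
                inserted = True
        return Composition(tuple(out))


stack = Composition((
    Term("rad", "radiation"),
    Term("gw_a", "gravity_waves"),
    Term("cld", "clouds"),
    Term("gw_b", "gravity_waves"),
))
damping = Term("rayleigh_damping", "gravity_waves")

replaced = [t.name for t in stack.replace("gravity_waves", damping).terms]
removed = [t.name for t in (stack.remove("clouds") + damping).terms]
```

Each operator returns a new composition and leaves the original untouched, so intermediate stacks can be reused.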
For SPEEDY:

    from jcm.physics.speedy.speedy_terms import speedy_physics

    physics = speedy_physics() + RayleighDamping()

The same drop-in works because the term contract is package-agnostic.
Where the term lives in the tree¶
The repo organises files by physical process, named after the scheme. Add your new file under the appropriate process directory:

    jcm/physics/
    ├── convection/<scheme>.py
    ├── radiation/<scheme>.py
    ├── clouds/<scheme>.py
    ├── surface/<scheme>/
    ├── vertical_diffusion/<scheme>/
    ├── gravity_waves/<scheme>/
    ├── aerosol/<scheme>.py
    ├── chemistry/<scheme>.py

Third-party plugins can live anywhere on the import path. They do not
need to live inside jcm/.
Verifying¶
Two cheap checks before you commit:
- JAX_PLATFORMS=cpu pytest -n 12 -m "not slow": the regular fast suite. Composing your term into echam_physics() or speedy_physics() and running an existing smoke test is a good way to make sure it survives a step.
- A short gradient check via jax.value_and_grad to confirm flax.nnx.Param is wired up correctly. If the gradient w.r.t. your parameters is None or all-zero, the nnx.Param decoration is missing somewhere.
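A gradient check along these lines might look like the sketch below. It is a simplified stand-in: the parameters live in a plain dict rather than an nnx.Param, and damping_tendency and loss are invented for the example, so it shows the shape of the check rather than the exact call against a real term.

```python
import jax
import jax.numpy as jnp


def damping_tendency(params, u_wind):
    # Toy linear drag with a tunable timescale.
    return -u_wind / params["timescale_seconds"]


def loss(params, u_wind):
    # Any scalar function of the tendency will do for a wiring check.
    return jnp.sum(damping_tendency(params, u_wind) ** 2)


params = {"timescale_seconds": jnp.asarray(86400.0)}
u_wind = jnp.array([10.0, 20.0, 30.0])

value, grads = jax.value_and_grad(loss)(params, u_wind)

# Inspect every gradient leaf: each should be finite, and at least one
# should be non-zero if the parameter actually influences the output.
grad_leaves = jax.tree_util.tree_leaves(grads)
all_finite = all(bool(jnp.all(jnp.isfinite(g))) for g in grad_leaves)
any_nonzero = any(bool(jnp.any(g != 0.0)) for g in grad_leaves)
```

Here the gradient with respect to timescale_seconds is negative, since lengthening the timescale weakens the drag and lowers the loss.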
Anti-patterns¶
- Don’t subclass ComposablePhysics. The container is final; composition happens through the +, replace, and remove operators.
- Don’t mutate diagnostics in place. Always return {**diagnostics, "my_key": ...}.
- Don’t read the model timestep from a parameter struct. Use diagnostics["_date"].dt_seconds. The model dt may change across configurations and should never be smuggled through Parameters.
- Don’t add a leading underscore to a provides key unless the data is genuinely internal (caches, transient state). Public keys flatten directly into model.run().to_xarray() output.
- Don’t import from jcm/physics/echam/ to make a SPEEDY-side term work. If you need shared infrastructure, put it under a scheme-neutral location like jcm/physics/diagnostics/.