Getting Started

Installation

To use JAX-GCM, first install it:

$ pip install jcm

or for the development version:

$ git clone https://github.com/climate-analytics-lab/jax-gcm.git
$ cd jax-gcm
$ pip install -e .

Requirements

  • Python ≥ 3.11

  • JAX

  • Dinosaur (dynamical core)

  • XArray (for I/O and data handling)

See requirements.txt for the complete list of dependencies.

Quick Start Examples

Aquaplanet Simulation

An aquaplanet simulation is the simplest configuration - a water-covered planet with no orography and constant (zonally symmetric) forcing. This is ideal for learning the model and testing new physics:

from jcm.model import Model
from jcm.physics.speedy.speedy_coords import get_speedy_coords

# Create a model with default aquaplanet configuration
model = Model(
   coords=get_speedy_coords(),  # T31 spectral resolution with 8 vertical levels
   time_step=30.0  # minutes
)

# Run a 120-day simulation
predictions = model.run(
   save_interval=10.0,  # save every 10 days
   total_time=120.0     # total simulation time in days
)

# Convert output to xarray Dataset for analysis
ds = predictions.to_xarray()
print(ds)

This creates a T31 spectral resolution model (96x48 grid points) with 8 vertical levels using the SPEEDY physics package. The default forcing includes zonally symmetric sea surface temperatures and no land.

Realistic Simulation

For a more realistic simulation with orography and time-varying boundary conditions, you can load data from files:

from jcm.model import Model
from jcm.terrain import TerrainData
from jcm.forcing import ForcingData
from importlib import resources

coords = get_speedy_coords()  # T31 spectral resolution with 8 vertical levels

# Load realistic orography and land-sea mask, interpolated to T31 grid
data_dir = resources.files("jcm.data.bc.t30.clim")
terrain_file = data_dir / "terrain.nc"
terrain = TerrainData.from_file(terrain_file, coords=coords)

# Load realistic forcing data (SST, sea ice, soil moisture, etc.) interpolated to T31 grid
forcing_file = data_dir / "forcing.nc"
forcing = ForcingData.from_file(forcing_file, coords=coords)

# Create model with realistic configuration
model = Model(
   coords,
   time_step=30.0,
   terrain=terrain
)

# Run simulation
predictions = model.run(
   forcing=forcing,
   save_interval=5.0,   # save every 5 days
   total_time=30.0      # 30-day simulation
)

# Convert to xarray and save
ds = predictions.to_xarray()
ds.to_netcdf("output.nc")

Customizing the Model

You can customize various aspects of the model:

Resolution: Change the horizontal and vertical resolution

from jcm.terrain import TerrainData
from jcm.physics.speedy.speedy_coords import get_speedy_coords

# Higher resolution: T85 (256x128 grid)
coords = get_speedy_coords(spectral_truncation=85)
terrain = TerrainData.aquaplanet(coords=coords)

model = Model(
   coords=coords,
   time_step=20.0,  # smaller timestep for stability
   terrain=terrain
)

Physics: Use different physics packages or configurations

from jcm.physics.speedy.speedy_physics import SpeedyPhysics
from jcm.physics.speedy.params import Parameters
from jcm.physics.speedy.speedy_coords import get_speedy_coords

# Customize physics parameters
params = Parameters.default()
params = params.replace(...)  # modify parameters as needed

physics = SpeedyPhysics(parameters=params)

model = Model(
   coords=get_speedy_coords(),
   time_step=30.0,
   physics=physics
)

Initial Conditions: Start from a specific state

from jcm.physics_interface import PhysicsState

# Create or load initial state
# initial_state = PhysicsState(...)

predictions = model.run(
    initial_state=initial_state,
    save_interval=1.0,
    total_time=10.0
)

Multi-Device Parallelization

JCM supports multi-device parallelization using JAX’s SPMD (Single Program Multiple Data) sharding. This allows you to split computation across multiple GPUs or TPUs for faster execution, especially useful for higher resolution simulations.

If you don’t specify spmd_mesh, JCM runs on a single device by default. This is the recommended approach for smaller resolutions (T31, T42) or when you only have a single GPU/TPU available.

Basic Concepts

SPMD Mesh: Defines how to partition data across devices. The mesh has three dimensions corresponding to (x, y, z) or (longitude, latitude, vertical).

Sharding Strategy: Typically, for SPEEDY Physics simulations, you want to shard the longitude dimension first since it usually has the most grid points. For Physics implementations with more layers (e.g. 32 or 64 layers) however, you may find that sharding the dycore in the vertical dimension to be most effective. Future implementations may allow for more flexible sharding strategies.

Enabling Parallelization

To enable multi-device parallelization, simply pass an spmd_mesh when creating your Model:

import jax
from jcm.model import Model
from jcm.physics.speedy.speedy_coords import get_speedy_coords

# Check available devices
print(f"Available devices: {jax.devices()}")
print(f"Number of devices: {len(jax.devices())}")

# Define a mesh to split longitude across 4 devices
# Mesh shape (4, 1, 1) means:
#   - Split longitude dimension across 4 devices
#   - Don't split latitude (1)
#   - Don't split vertical (1)
# Otherwise, create and run model as usual
model = Model(coords=get_speedy_coords(), spmd_mesh=(4, 1, 1))
predictions = model.run(save_interval=5.0, total_time=30.0)

Mesh Configuration Guidelines

The product of mesh dimensions must equal the number of available devices:

  • (4, 1, 1): Split longitude across 4 devices

  • (2, 2, 1): Split longitude (2) and latitude (2) across 4 devices total

  • (8, 1, 1): Split longitude across 8 devices (for higher resolutions)

Rules of thumb:

  1. Product of mesh dimensions = number of devices

  2. Longitude (x) usually has most grid points → split first

  3. Higher resolutions (T85+) benefit more from sharding

Analyzing Output

The model output is a Predictions object containing the model state trajectory. Convert it to xarray for analysis:

import matplotlib.pyplot as plt

# Convert to xarray Dataset
ds = predictions.to_xarray()

# Print variables
print(ds.data_vars)

# Plot surface temperature evolution
ds['temperature'].isel(level=7).mean(dim='lon').plot()
plt.title('Zonal Mean Surface Temperature')
plt.show()

# Calculate global mean quantities
global_mean_temp = ds['temperature'].weighted(
    ds['lat'].pipe(lambda x: np.cos(np.deg2rad(x)))
).mean(dim=['lon', 'lat'])

Next Steps

  • See Architecture & Design to understand the model architecture

  • See API for detailed API documentation

  • Check example notebooks in the notebooks/ directory of the GitHub repo

  • Read Developer Guide for contribution guidelines