Getting Started¶
Installation¶
To use JAX-GCM, first install it:
$ pip install jcm
or for the development version:
$ git clone https://github.com/climate-analytics-lab/jax-gcm.git
$ cd jax-gcm
$ pip install -e .
Requirements¶
Python ≥ 3.11
JAX
Dinosaur (dynamical core)
XArray (for I/O and data handling)
See requirements.txt for the complete list of dependencies.
Quick Start Examples¶
Aquaplanet Simulation¶
An aquaplanet simulation is the simplest configuration - a water-covered planet with no orography and constant (zonally symmetric) forcing. This is ideal for learning the model and testing new physics:
from jcm.model import Model
from jcm.physics.speedy.speedy_coords import get_speedy_coords
# Create a model with default aquaplanet configuration
model = Model(
coords=get_speedy_coords(), # T31 spectral resolution with 8 vertical levels
time_step=30.0 # minutes
)
# Run a 120-day simulation
predictions = model.run(
save_interval=10.0, # save every 10 days
total_time=120.0 # total simulation time in days
)
# Convert output to xarray Dataset for analysis
ds = predictions.to_xarray()
print(ds)
This creates a T31 spectral resolution model (96x48 grid points) with 8 vertical levels using the SPEEDY physics package. The default forcing includes zonally symmetric sea surface temperatures and no land.
Realistic Simulation¶
For a more realistic simulation with orography and time-varying boundary conditions, you can load data from files:
from jcm.model import Model
from jcm.terrain import TerrainData
from jcm.forcing import ForcingData
from importlib import resources
coords = get_speedy_coords() # T31 spectral resolution with 8 vertical levels
# Load realistic orography and land-sea mask, interpolated to T31 grid
data_dir = resources.files("jcm.data.bc.t30.clim")
terrain_file = data_dir / "terrain.nc"
terrain = TerrainData.from_file(terrain_file, coords=coords)
# Load realistic forcing data (SST, sea ice, soil moisture, etc.) interpolated to T31 grid
forcing_file = data_dir / "forcing.nc"
forcing = ForcingData.from_file(forcing_file, coords=coords)
# Create model with realistic configuration
model = Model(
coords,
time_step=30.0,
terrain=terrain
)
# Run simulation
predictions = model.run(
forcing=forcing,
save_interval=5.0, # save every 5 days
total_time=30.0 # 30-day simulation
)
# Convert to xarray and save
ds = predictions.to_xarray()
ds.to_netcdf("output.nc")
Customizing the Model¶
You can customize various aspects of the model:
Resolution: Change the horizontal and vertical resolution
from jcm.terrain import TerrainData
from jcm.physics.speedy.speedy_coords import get_speedy_coords
# Higher resolution: T85 (256x128 grid)
coords = get_speedy_coords(spectral_truncation=85)
terrain = TerrainData.aquaplanet(coords=coords)
model = Model(
coords=coords,
time_step=20.0, # smaller timestep for stability
terrain=terrain
)
Physics: Use different physics packages or configurations
from jcm.physics.speedy.speedy_physics import SpeedyPhysics
from jcm.physics.speedy.params import Parameters
from jcm.physics.speedy.speedy_coords import get_speedy_coords
# Customize physics parameters
params = Parameters.default()
params = params.replace(...) # modify parameters as needed
physics = SpeedyPhysics(parameters=params)
model = Model(
coords=get_speedy_coords(),
time_step=30.0,
physics=physics
)
Initial Conditions: Start from a specific state
from jcm.physics_interface import PhysicsState
# Create or load initial state
# initial_state = PhysicsState(...)
predictions = model.run(
initial_state=initial_state,
save_interval=1.0,
total_time=10.0
)
Multi-Device Parallelization¶
JCM supports multi-device parallelization using JAX’s SPMD (Single Program Multiple Data) sharding. This allows you to split computation across multiple GPUs or TPUs for faster execution, especially useful for higher resolution simulations.
If you don’t specify spmd_mesh when building your coords, JCM runs on a single device by default. This is the recommended approach for smaller resolutions (T31, T42) or when you only have a single GPU/TPU available.
Basic Concepts¶
SPMD Mesh: Defines how to partition data across devices. The mesh has three dimensions corresponding to (x, y, z) or (longitude, latitude, vertical).
Sharding Strategy: Typically, for SPEEDY Physics simulations, you want to shard the longitude dimension first since it usually has the most grid points. For Physics implementations with more layers (e.g. 32 or 64 layers) however, you may find that sharding the dycore in the vertical dimension to be most effective. Future implementations may allow for more flexible sharding strategies.
Enabling Parallelization¶
To enable multi-device parallelization, pass spmd_mesh to the coords helper
(e.g. get_speedy_coords or get_coords) and build the Model with those coords:
import jax
from jcm.model import Model
from jcm.physics.speedy.speedy_coords import get_speedy_coords
# Check available devices
print(f"Available devices: {jax.devices()}")
print(f"Number of devices: {len(jax.devices())}")
# Define a mesh to split longitude across 4 devices
# Mesh shape (4, 1, 1) means:
# - Split longitude dimension across 4 devices
# - Don't split latitude (1)
# - Don't split vertical (1)
coords = get_speedy_coords(spmd_mesh=(4, 1, 1))
model = Model(coords=coords)
predictions = model.run(save_interval=5.0, total_time=30.0)
Mesh Configuration Guidelines¶
The product of mesh dimensions must equal the number of available devices:
(4, 1, 1): Split longitude across 4 devices(2, 2, 1): Split longitude (2) and latitude (2) across 4 devices total(8, 1, 1): Split longitude across 8 devices (for higher resolutions)
Rules of thumb:
Product of mesh dimensions = number of devices
Longitude (x) usually has most grid points → split first
Higher resolutions (T85+) benefit more from sharding
Analyzing Output¶
The model output is a Predictions object containing the model state trajectory. Convert it to xarray for analysis:
import matplotlib.pyplot as plt
# Convert to xarray Dataset
ds = predictions.to_xarray()
# Print variables
print(ds.data_vars)
# Plot surface temperature evolution
ds['temperature'].isel(level=7).mean(dim='lon').plot()
plt.title('Zonal Mean Surface Temperature')
plt.show()
# Calculate global mean quantities
global_mean_temp = ds['temperature'].weighted(
ds['lat'].pipe(lambda x: np.cos(np.deg2rad(x)))
).mean(dim=['lon', 'lat'])
Next Steps¶
See Architecture & Design to understand the model architecture
See API for detailed API documentation
Check example notebooks in the
notebooks/directory of the GitHub repoRead Developer Guide for contribution guidelines