Getting Started =============== .. _installation: Installation ------------ To use JAX-GCM, first install it: .. code-block:: console $ pip install jcm or for the development version: .. code-block:: console $ git clone https://github.com/climate-analytics-lab/jax-gcm.git $ cd jax-gcm $ pip install -e . Requirements ^^^^^^^^^^^^ - Python ≥ 3.11 - JAX - Dinosaur (dynamical core) - XArray (for I/O and data handling) See ``requirements.txt`` for the complete list of dependencies. Quick Start Examples -------------------- Aquaplanet Simulation ^^^^^^^^^^^^^^^^^^^^^ An aquaplanet simulation is the simplest configuration - a water-covered planet with no orography and constant (zonally symmetric) forcing. This is ideal for learning the model and testing new physics: .. code-block:: python from jcm.model import Model from jcm.physics.speedy.speedy_coords import get_speedy_coords # Create a model with default aquaplanet configuration model = Model( coords=get_speedy_coords(), # T31 spectral resolution with 8 vertical levels time_step=30.0 # minutes ) # Run a 120-day simulation predictions = model.run( save_interval=10.0, # save every 10 days total_time=120.0 # total simulation time in days ) # Convert output to xarray Dataset for analysis ds = predictions.to_xarray() print(ds) This creates a T31 spectral resolution model (96x48 grid points) with 8 vertical levels using the SPEEDY physics package. The default forcing includes zonally symmetric sea surface temperatures and no land. Realistic Simulation ^^^^^^^^^^^^^^^^^^^^ For a more realistic simulation with orography and time-varying boundary conditions, you can load data from files: .. code-block:: python from jcm.model import Model from jcm.terrain import TerrainData from jcm.forcing import ForcingData from importlib import resources coords = get_speedy_coords() # T31 spectral resolution with 8 vertical levels # Load realistic orography and land-sea mask, interpolated to T31 grid data_dir = resources.files("jcm.data.bc.t30.clim") terrain_file = data_dir / "terrain.nc" terrain = TerrainData.from_file(terrain_file, coords=coords) # Load realistic forcing data (SST, sea ice, soil moisture, etc.) interpolated to T31 grid forcing_file = data_dir / "forcing.nc" forcing = ForcingData.from_file(forcing_file, coords=coords) # Create model with realistic configuration model = Model( coords, time_step=30.0, terrain=terrain ) # Run simulation predictions = model.run( forcing=forcing, save_interval=5.0, # save every 5 days total_time=30.0 # 30-day simulation ) # Convert to xarray and save ds = predictions.to_xarray() ds.to_netcdf("output.nc") Customizing the Model ^^^^^^^^^^^^^^^^^^^^^ You can customize various aspects of the model: **Resolution**: Change the horizontal and vertical resolution .. code-block:: python from jcm.terrain import TerrainData from jcm.physics.speedy.speedy_coords import get_speedy_coords # Higher resolution: T85 (256x128 grid) coords = get_speedy_coords(spectral_truncation=85) terrain = TerrainData.aquaplanet(coords=coords) model = Model( coords=coords, time_step=20.0, # smaller timestep for stability terrain=terrain ) **Physics**: Use different physics packages or configurations .. code-block:: python from jcm.physics.speedy.speedy_physics import SpeedyPhysics from jcm.physics.speedy.params import Parameters from jcm.physics.speedy.speedy_coords import get_speedy_coords # Customize physics parameters params = Parameters.default() params = params.replace(...) # modify parameters as needed physics = SpeedyPhysics(parameters=params) model = Model( coords=get_speedy_coords(), time_step=30.0, physics=physics ) **Initial Conditions**: Start from a specific state .. code-block:: python from jcm.physics_interface import PhysicsState # Create or load initial state # initial_state = PhysicsState(...) predictions = model.run( initial_state=initial_state, save_interval=1.0, total_time=10.0 ) Multi-Device Parallelization ----------------------------- JCM supports multi-device parallelization using JAX's SPMD (Single Program Multiple Data) sharding. This allows you to split computation across multiple GPUs or TPUs for faster execution, especially useful for higher resolution simulations. If you don't specify ``spmd_mesh`` when building your coords, JCM runs on a single device by default. This is the recommended approach for smaller resolutions (T31, T42) or when you only have a single GPU/TPU available. Basic Concepts ^^^^^^^^^^^^^^ **SPMD Mesh**: Defines how to partition data across devices. The mesh has three dimensions corresponding to ``(x, y, z)`` or ``(longitude, latitude, vertical)``. **Sharding Strategy**: Typically, for SPEEDY Physics simulations, you want to shard the longitude dimension first since it usually has the most grid points. For Physics implementations with more layers (e.g. 32 or 64 layers) however, you may find that sharding the dycore in the vertical dimension to be most effective. Future implementations may allow for more flexible sharding strategies. Enabling Parallelization ^^^^^^^^^^^^^^^^^^^^^^^^ To enable multi-device parallelization, pass ``spmd_mesh`` to the coords helper (e.g. ``get_speedy_coords`` or ``get_coords``) and build the ``Model`` with those coords: .. code-block:: python import jax from jcm.model import Model from jcm.physics.speedy.speedy_coords import get_speedy_coords # Check available devices print(f"Available devices: {jax.devices()}") print(f"Number of devices: {len(jax.devices())}") # Define a mesh to split longitude across 4 devices # Mesh shape (4, 1, 1) means: # - Split longitude dimension across 4 devices # - Don't split latitude (1) # - Don't split vertical (1) coords = get_speedy_coords(spmd_mesh=(4, 1, 1)) model = Model(coords=coords) predictions = model.run(save_interval=5.0, total_time=30.0) Mesh Configuration Guidelines ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The product of mesh dimensions must equal the number of available devices: - ``(4, 1, 1)``: Split longitude across 4 devices - ``(2, 2, 1)``: Split longitude (2) and latitude (2) across 4 devices total - ``(8, 1, 1)``: Split longitude across 8 devices (for higher resolutions) **Rules of thumb:** 1. Product of mesh dimensions = number of devices 2. Longitude (x) usually has most grid points → split first 3. Higher resolutions (T85+) benefit more from sharding Analyzing Output ---------------- The model output is a :py:class:`Predictions` object containing the model state trajectory. Convert it to xarray for analysis: .. code-block:: python import matplotlib.pyplot as plt # Convert to xarray Dataset ds = predictions.to_xarray() # Print variables print(ds.data_vars) # Plot surface temperature evolution ds['temperature'].isel(level=7).mean(dim='lon').plot() plt.title('Zonal Mean Surface Temperature') plt.show() # Calculate global mean quantities global_mean_temp = ds['temperature'].weighted( ds['lat'].pipe(lambda x: np.cos(np.deg2rad(x))) ).mean(dim=['lon', 'lat']) Next Steps ---------- - See :doc:`design` to understand the model architecture - See :doc:`api` for detailed API documentation - Check example notebooks in the ``notebooks/`` directory of the GitHub repo - Read :doc:`developer` for contribution guidelines