# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A predictor that runs multiple graph neural networks on mesh data.

It learns to interpolate between the grid and the mesh nodes, with the loss
and the rollouts ultimately computed at the grid level.

It uses ideas similar to those in Keisler (2022):

Reference:
  https://arxiv.org/pdf/2202.07575.pdf

It assumes data across time and level is stacked, and operates only in a 2D
mesh over latitudes and longitudes.
"""

from typing import Any, Callable, Mapping, Optional

import chex
from graphcast import deep_typed_graph_net
from graphcast import grid_mesh_connectivity
from graphcast import icosahedral_mesh
from graphcast import losses
from graphcast import model_utils
from graphcast import predictor_base
from graphcast import typed_graph
from graphcast import xarray_jax
import jax.numpy as jnp
import jraph
import numpy as np
import xarray

# Generic type for keyword-argument mappings.
Kwargs = Mapping[str, Any]

# A function mapping a jraph graph to another jraph graph.
GNN = Callable[[jraph.GraphsTuple], jraph.GraphsTuple]


# https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5
PRESSURE_LEVELS_ERA5_37 = (
    1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 125, 150, 175, 200, 225, 250, 300,
    350, 400, 450, 500, 550, 600, 650, 700, 750, 775, 800, 825, 850, 875, 900,
    925, 950, 975, 1000)

# https://www.ecmwf.int/en/forecasts/datasets/set-i
PRESSURE_LEVELS_HRES_25 = (
    1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 400, 500, 600,
    700, 800, 850, 900, 925, 950, 1000)

# https://agupubs.onlinelibrary.wiley.com/doi/full/10.1029/2020MS002203
PRESSURE_LEVELS_WEATHERBENCH_13 = (
    50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000)

# Maps the number of pressure levels to the corresponding tuple of levels.
PRESSURE_LEVELS = {
    13: PRESSURE_LEVELS_WEATHERBENCH_13,
    25: PRESSURE_LEVELS_HRES_25,
    37: PRESSURE_LEVELS_ERA5_37,
}

# The list of all possible atmospheric variables.
# Taken from:
# https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#ERA5:datadocumentation-Table9
ALL_ATMOSPHERIC_VARS = (
    "potential_vorticity",
    "specific_rain_water_content",
    "specific_snow_water_content",
    "geopotential",
    "temperature",
    "u_component_of_wind",
    "v_component_of_wind",
    "specific_humidity",
    "vertical_velocity",
    "vorticity",
    "divergence",
    "relative_humidity",
    "ozone_mass_mixing_ratio",
    "specific_cloud_liquid_water_content",
    "specific_cloud_ice_water_content",
    "fraction_of_cloud_cover",
)

TARGET_SURFACE_VARS = (
    "2m_temperature",
    "mean_sea_level_pressure",
    "10m_v_component_of_wind",
    "10m_u_component_of_wind",
    "total_precipitation_6hr",
)
TARGET_SURFACE_NO_PRECIP_VARS = (
    "2m_temperature",
    "mean_sea_level_pressure",
    "10m_v_component_of_wind",
    "10m_u_component_of_wind",
)
TARGET_ATMOSPHERIC_VARS = (
    "temperature",
    "geopotential",
    "u_component_of_wind",
    "v_component_of_wind",
    "vertical_velocity",
    "specific_humidity",
)
TARGET_ATMOSPHERIC_NO_W_VARS = (
    "temperature",
    "geopotential",
    "u_component_of_wind",
    "v_component_of_wind",
    "specific_humidity",
)
EXTERNAL_FORCING_VARS = (
    "toa_incident_solar_radiation",
)
GENERATED_FORCING_VARS = (
    "year_progress_sin",
    "year_progress_cos",
    "day_progress_sin",
    "day_progress_cos",
)
FORCING_VARS = EXTERNAL_FORCING_VARS + GENERATED_FORCING_VARS
STATIC_VARS = (
    "geopotential_at_surface",
    "land_sea_mask",
)


@chex.dataclass(frozen=True, eq=True)
class TaskConfig:
  """Defines inputs and targets on which a model is trained and/or evaluated."""
  # Variables fed to the model (may include forcing and static variables).
  input_variables: tuple[str, ...]
  # Target variables which the model is expected to predict.
  target_variables: tuple[str, ...]
  # Subset of variables treated as forcings.
  forcing_variables: tuple[str, ...]
  # Pressure levels of the atmospheric variables; together with the target
  # variables this determines the model's output size.
  pressure_levels: tuple[int, ...]
  # Length of the input window, as a string such as "12h".
  input_duration: str


# All target/forcing/static variables on the 37 ERA5 pressure levels.
TASK = TaskConfig(
    input_variables=(
        TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS +
        STATIC_VARS),
    target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
    forcing_variables=FORCING_VARS,
    pressure_levels=PRESSURE_LEVELS_ERA5_37,
    input_duration="12h",
)
# Same variables as `TASK`, on the 13 WeatherBench pressure levels.
TASK_13 = TaskConfig(
    input_variables=(
        TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS +
        STATIC_VARS),
    target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
    forcing_variables=FORCING_VARS,
    pressure_levels=PRESSURE_LEVELS_WEATHERBENCH_13,
    input_duration="12h",
)
# Like `TASK_13`, but precipitation is only a target, never an input.
TASK_13_PRECIP_OUT = TaskConfig(
    input_variables=(
        TARGET_SURFACE_NO_PRECIP_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS +
        STATIC_VARS),
    target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
    forcing_variables=FORCING_VARS,
    pressure_levels=PRESSURE_LEVELS_WEATHERBENCH_13,
    input_duration="12h",
)


@chex.dataclass(frozen=True, eq=True)
class ModelConfig:
  """Defines the architecture of the GraphCast neural network.

  Properties:
    resolution: The resolution of the data, in degrees (e.g. 0.25 or 1.0).
    mesh_size: How many refinements to do on the multi-mesh.
    gnn_msg_steps: How many Graph Network message passing steps to do.
    latent_size: How many latent features to include in the various MLPs.
    hidden_layers: How many hidden layers for each MLP.
    radius_query_fraction_edge_length: Scalar that will be multiplied by the
        length of the longest edge of the finest mesh to define the radius of
        connectivity to use in the Grid2Mesh graph. Reasonable values are
        between 0.6 and 1. 0.6 reduces the number of grid points feeding into
        multiple mesh nodes and therefore reduces edge count and memory use, but
        1 gives better predictions.
    mesh2grid_edge_normalization_factor: Allows explicitly controlling edge
        normalization for mesh2grid edges. If None, defaults to max edge length.
        This supports using pre-trained model weights with a different graph
        structure to what it was trained on.
  """
  resolution: float
  mesh_size: int
  latent_size: int
  gnn_msg_steps: int
  hidden_layers: int
  radius_query_fraction_edge_length: float
  mesh2grid_edge_normalization_factor: Optional[float] = None


@chex.dataclass(frozen=True, eq=True)
class CheckPoint:
  """Bundles model parameters with their model/task configs and metadata."""
  params: dict[str, Any]
  model_config: ModelConfig
  task_config: TaskConfig
  description: str
  license: str


class GraphCast(predictor_base.Predictor):
  """GraphCast Predictor.

  The model works on graphs that take into account:
  * Mesh nodes: nodes for the vertices of the mesh.
  * Grid nodes: nodes for the points of the grid.
  * Nodes: When referring to just "nodes", this means the joint set of
    both mesh nodes, concatenated with grid nodes.

  The model works with 3 graphs:
  * Grid2Mesh graph: Graph that contains all nodes. This graph is strictly
    bipartite with edges going from grid nodes to mesh nodes using a
    fixed radius query. The grid2mesh_gnn will operate in this graph. The output
    of this stage will be a latent representation for the mesh nodes, and a
    latent representation for the grid nodes.
  * Mesh graph: Graph that contains mesh nodes only. The mesh_gnn will
    operate in this graph. It will update the latent state of the mesh nodes
    only.
  * Mesh2Grid graph: Graph that contains all nodes. This graph is strictly
    bipartite with edges going from mesh nodes to grid nodes such that each grid
    node is connected to the 3 nodes of the mesh triangular face that contains
    the grid point. The mesh2grid_gnn will operate in this graph. It will
    process the updated latent state of the mesh nodes, and the latent state
    of the grid nodes, to produce the final output for the grid nodes.

  The model is built on top of `TypedGraph`s so the different types of nodes and
  edges can be stored and treated separately.
  """

  def __init__(self, model_config: ModelConfig, task_config: TaskConfig):
    """Initializes the predictor.

    Builds the multimesh hierarchy and the three GNNs (encoder, processor and
    decoder). Everything that depends on the input grid coordinates is
    deferred to the first call (see `_maybe_init`).

    Args:
      model_config: Architecture hyperparameters.
      task_config: Task definition; its target variables and pressure levels
        determine the number of decoder outputs per grid node.
    """
    self._spatial_features_kwargs = dict(
        add_node_positions=False,
        add_node_latitude=True,
        add_node_longitude=True,
        add_relative_positions=True,
        relative_longitude_local_coordinates=True,
        relative_latitude_local_coordinates=True,
    )

    # Specification of the multimesh.
    self._meshes = (
        icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
            splits=model_config.mesh_size))

    # Encoder, which moves data from the grid to the mesh with a single message
    # passing step.
    self._grid2mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
        embed_nodes=True,  # Embed raw features of the grid and mesh nodes.
        embed_edges=True,  # Embed raw features of the grid2mesh edges.
        edge_latent_size=dict(grid2mesh=model_config.latent_size),
        node_latent_size=dict(
            mesh_nodes=model_config.latent_size,
            grid_nodes=model_config.latent_size),
        mlp_hidden_size=model_config.latent_size,
        mlp_num_hidden_layers=model_config.hidden_layers,
        num_message_passing_steps=1,
        use_layer_norm=True,
        include_sent_messages_in_node_update=False,
        activation="swish",
        f32_aggregation=True,
        aggregate_normalization=None,
        name="grid2mesh_gnn",
    )

    # Processor, which performs message passing on the multi-mesh.
    self._mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
        embed_nodes=False,  # Node features already embedded by previous layers.
        embed_edges=True,  # Embed raw features of the multi-mesh edges.
        node_latent_size=dict(mesh_nodes=model_config.latent_size),
        edge_latent_size=dict(mesh=model_config.latent_size),
        mlp_hidden_size=model_config.latent_size,
        mlp_num_hidden_layers=model_config.hidden_layers,
        num_message_passing_steps=model_config.gnn_msg_steps,
        use_layer_norm=True,
        include_sent_messages_in_node_update=False,
        activation="swish",
        f32_aggregation=False,
        name="mesh_gnn",
    )

    # Each surface (non-atmospheric) target contributes one output channel;
    # each atmospheric target contributes one channel per pressure level.
    num_surface_vars = len(
        set(task_config.target_variables) - set(ALL_ATMOSPHERIC_VARS))
    num_atmospheric_vars = len(
        set(task_config.target_variables) & set(ALL_ATMOSPHERIC_VARS))
    num_outputs = (num_surface_vars +
                   len(task_config.pressure_levels) * num_atmospheric_vars)

    # Decoder, which moves data from the mesh back into the grid with a single
    # message passing step.
    self._mesh2grid_gnn = deep_typed_graph_net.DeepTypedGraphNet(
        # Require a specific node dimensionality for the grid node outputs.
        node_output_size=dict(grid_nodes=num_outputs),
        embed_nodes=False,  # Node features already embedded by previous layers.
        embed_edges=True,  # Embed raw features of the mesh2grid edges.
        edge_latent_size=dict(mesh2grid=model_config.latent_size),
        node_latent_size=dict(
            mesh_nodes=model_config.latent_size,
            grid_nodes=model_config.latent_size),
        mlp_hidden_size=model_config.latent_size,
        mlp_num_hidden_layers=model_config.hidden_layers,
        num_message_passing_steps=1,
        use_layer_norm=True,
        include_sent_messages_in_node_update=False,
        activation="swish",
        f32_aggregation=False,
        name="mesh2grid_gnn",
    )

    # Obtain the query radius in absolute units for the unit-sphere for the
    # grid2mesh model, by rescaling the `radius_query_fraction_edge_length`.
    self._query_radius = (_get_max_edge_distance(self._finest_mesh)
                          * model_config.radius_query_fraction_edge_length)
    self._mesh2grid_edge_normalization_factor = (
        model_config.mesh2grid_edge_normalization_factor
    )

    # Other initialization is delayed until the first call (`_maybe_init`)
    # when we get some sample data so we know the lat/lon values.
    self._initialized = False

    # A "_init_mesh_properties":
    # This one could be initialized at init but we delay it for consistency too.
    self._num_mesh_nodes = None  # num_mesh_nodes
    self._mesh_nodes_lat = None  # [num_mesh_nodes]
    self._mesh_nodes_lon = None  # [num_mesh_nodes]

    # A "_init_grid_properties":
    self._grid_lat = None  # [num_lat_points]
    self._grid_lon = None  # [num_lon_points]
    self._num_grid_nodes = None  # num_lat_points * num_lon_points
    self._grid_nodes_lat = None  # [num_grid_nodes]
    self._grid_nodes_lon = None  # [num_grid_nodes]

    # A "_init_{grid2mesh,processor,mesh2grid}_graph"
    self._grid2mesh_graph_structure = None
    self._mesh_graph_structure = None
    self._mesh2grid_graph_structure = None

  @property
  def _finest_mesh(self):
    """The last (finest) mesh of the multimesh hierarchy."""
    return self._meshes[-1]

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               is_training: bool = False,
               ) -> xarray.Dataset:
    """Runs the encoder, processor and decoder to produce a prediction.

    Args:
      inputs: Input variables. On the first call their lat/lon coordinates are
        used to build the graph structures.
      targets_template: Dataset whose structure the output is unstacked into.
      forcings: Forcing variables, stacked together with `inputs` as grid node
        features.
      is_training: Not read by this method; accepted for interface
        compatibility.

    Returns:
      The prediction, in the format of `targets_template`.
    """
    self._maybe_init(inputs)

    # Convert all input data into flat vectors for each of the grid nodes.
    # xarray (batch, time, lat, lon, level, multiple vars, forcings)
    # -> [num_grid_nodes, batch, num_channels]
    grid_node_features = self._inputs_to_grid_node_features(inputs, forcings)

    # Transfer data from the grid to the mesh,
    # [num_mesh_nodes, batch, latent_size], [num_grid_nodes, batch, latent_size]
    (latent_mesh_nodes, latent_grid_nodes
     ) = self._run_grid2mesh_gnn(grid_node_features)

    # Run message passing in the multimesh.
    # [num_mesh_nodes, batch, latent_size]
    updated_latent_mesh_nodes = self._run_mesh_gnn(latent_mesh_nodes)

    # Transfer data from the mesh to the grid.
    # [num_grid_nodes, batch, output_size]
    output_grid_nodes = self._run_mesh2grid_gnn(
        updated_latent_mesh_nodes, latent_grid_nodes)

    # Convert output flat vectors for the grid nodes to the format of the
    # output.
    # [num_grid_nodes, batch, output_size] ->
    # xarray (batch, one time step, lat, lon, level, multiple vars)
    return self._grid_node_outputs_to_prediction(
        output_grid_nodes, targets_template)

  def loss_and_predictions(  # pytype: disable=signature-mismatch  # jax-ndarray
      self,
      inputs: xarray.Dataset,
      targets: xarray.Dataset,
      forcings: xarray.Dataset,
      ) -> tuple[predictor_base.LossAndDiagnostics, xarray.Dataset]:
    """Runs a forward pass and computes the weighted per-level MSE loss.

    Args:
      inputs: Input variables.
      targets: Ground-truth targets; also used as the output template.
      forcings: Forcing variables.

    Returns:
      A `(loss, predictions)` tuple.
    """
    # Forward pass.
    predictions = self(
        inputs, targets_template=targets, forcings=forcings, is_training=True)
    # Compute loss.
    loss = losses.weighted_mse_per_level(
        predictions, targets,
        per_variable_weights={
            # Any variables not specified here are weighted as 1.0.
            # A single-level variable, but an important headline variable
            # and also one which we have struggled to get good performance
            # on at short lead times, so leaving it weighted at 1.0, equal
            # to the multi-level variables:
            "2m_temperature": 1.0,
            # New single-level variables, which we don't weight too highly
            # to avoid hurting performance on other variables.
            "10m_u_component_of_wind": 0.1,
            "10m_v_component_of_wind": 0.1,
            "mean_sea_level_pressure": 0.1,
            "total_precipitation_6hr": 0.1,
        })
    return loss, predictions  # pytype: disable=bad-return-type  # jax-ndarray

  def loss(  # pytype: disable=signature-mismatch  # jax-ndarray
      self,
      inputs: xarray.Dataset,
      targets: xarray.Dataset,
      forcings: xarray.Dataset,
      ) -> predictor_base.LossAndDiagnostics:
    """Returns only the loss part of `loss_and_predictions`."""
    loss, _ = self.loss_and_predictions(inputs, targets, forcings)
    return loss  # pytype: disable=bad-return-type  # jax-ndarray

  def _maybe_init(self, sample_inputs: xarray.Dataset):
    """Inits everything that has a dependency on the input coordinates."""
    # NOTE: this runs only once; later calls reuse the first grid's structures
    # even if their lat/lon coordinates differ.
    if not self._initialized:
      self._init_mesh_properties()
      self._init_grid_properties(
          grid_lat=sample_inputs.lat, grid_lon=sample_inputs.lon)
      self._grid2mesh_graph_structure = self._init_grid2mesh_graph()
      self._mesh_graph_structure = self._init_mesh_graph()
      self._mesh2grid_graph_structure = self._init_mesh2grid_graph()

      self._initialized = True

  def _init_mesh_properties(self):
    """Inits static properties that have to do with mesh nodes."""
    self._num_mesh_nodes = self._finest_mesh.vertices.shape[0]
    # Cartesian vertex positions -> spherical angles -> lat/lon.
    mesh_phi, mesh_theta = model_utils.cartesian_to_spherical(
        self._finest_mesh.vertices[:, 0],
        self._finest_mesh.vertices[:, 1],
        self._finest_mesh.vertices[:, 2])
    (
        mesh_nodes_lat,
        mesh_nodes_lon,
    ) = model_utils.spherical_to_lat_lon(
        phi=mesh_phi, theta=mesh_theta)
    # Convert to f32 to ensure the lat/lon features aren't in f64.
    self._mesh_nodes_lat = mesh_nodes_lat.astype(np.float32)
    self._mesh_nodes_lon = mesh_nodes_lon.astype(np.float32)

  def _init_grid_properties(self, grid_lat: np.ndarray, grid_lon: np.ndarray):
    """Inits static properties that have to do with grid nodes."""
    # NOTE(review): `_maybe_init` passes the Dataset's `lat`/`lon` coordinates
    # here, despite the np.ndarray annotations.
    self._grid_lat = grid_lat.astype(np.float32)
    self._grid_lon = grid_lon.astype(np.float32)
    # Initialize the counters.
    self._num_grid_nodes = grid_lat.shape[0] * grid_lon.shape[0]

    # Initialize lat and lon for the grid (flattened in lat-major order).
    grid_nodes_lon, grid_nodes_lat = np.meshgrid(grid_lon, grid_lat)
    self._grid_nodes_lon = grid_nodes_lon.reshape([-1]).astype(np.float32)
    self._grid_nodes_lat = grid_nodes_lat.reshape([-1]).astype(np.float32)

  def _init_grid2mesh_graph(self) -> typed_graph.TypedGraph:
    """Build Grid2Mesh graph.

    Edges go from grid nodes to finest-mesh nodes, selected by a radius query
    with `self._query_radius`.
    """

    # Create some edges according to distance between mesh and grid nodes.
    assert self._grid_lat is not None and self._grid_lon is not None
    (grid_indices, mesh_indices) = grid_mesh_connectivity.radius_query_indices(
        grid_latitude=self._grid_lat,
        grid_longitude=self._grid_lon,
        mesh=self._finest_mesh,
        radius=self._query_radius)

    # Edges sending info from grid to mesh.
    senders = grid_indices
    receivers = mesh_indices

    # Precompute structural node and edge features according to config options.
    # Structural features are those that depend on the fixed values of the
    # latitude and longitudes of the nodes.
    (senders_node_features, receivers_node_features,
     edge_features) = model_utils.get_bipartite_graph_spatial_features(
         senders_node_lat=self._grid_nodes_lat,
         senders_node_lon=self._grid_nodes_lon,
         receivers_node_lat=self._mesh_nodes_lat,
         receivers_node_lon=self._mesh_nodes_lon,
         senders=senders,
         receivers=receivers,
         edge_normalization_factor=None,
         **self._spatial_features_kwargs,
     )

    n_grid_node = np.array([self._num_grid_nodes])
    n_mesh_node = np.array([self._num_mesh_nodes])
    n_edge = np.array([mesh_indices.shape[0]])
    grid_node_set = typed_graph.NodeSet(
        n_node=n_grid_node, features=senders_node_features)
    mesh_node_set = typed_graph.NodeSet(
        n_node=n_mesh_node, features=receivers_node_features)
    edge_set = typed_graph.EdgeSet(
        n_edge=n_edge,
        indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
        features=edge_features)
    nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
    edges = {
        typed_graph.EdgeSetKey("grid2mesh", ("grid_nodes", "mesh_nodes")):
            edge_set
    }
    grid2mesh_graph = typed_graph.TypedGraph(
        context=typed_graph.Context(n_graph=np.array([1]), features=()),
        nodes=nodes,
        edges=edges)
    return grid2mesh_graph

  def _init_mesh_graph(self) -> typed_graph.TypedGraph:
    """Build Mesh graph.

    Uses the edges of all meshes in the hierarchy merged together, with node
    features computed from the finest mesh's vertices.
    """
    merged_mesh = icosahedral_mesh.merge_meshes(self._meshes)

    # Work simply on the mesh edges.
    senders, receivers = icosahedral_mesh.faces_to_edges(merged_mesh.faces)

    # Precompute structural node and edge features according to config options.
    # Structural features are those that depend on the fixed values of the
    # latitude and longitudes of the nodes.
    assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
    node_features, edge_features = model_utils.get_graph_spatial_features(
        node_lat=self._mesh_nodes_lat,
        node_lon=self._mesh_nodes_lon,
        senders=senders,
        receivers=receivers,
        **self._spatial_features_kwargs,
    )

    n_mesh_node = np.array([self._num_mesh_nodes])
    n_edge = np.array([senders.shape[0]])
    assert n_mesh_node == len(node_features)
    mesh_node_set = typed_graph.NodeSet(
        n_node=n_mesh_node, features=node_features)
    edge_set = typed_graph.EdgeSet(
        n_edge=n_edge,
        indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
        features=edge_features)
    nodes = {"mesh_nodes": mesh_node_set}
    edges = {
        typed_graph.EdgeSetKey("mesh", ("mesh_nodes", "mesh_nodes")): edge_set
    }
    mesh_graph = typed_graph.TypedGraph(
        context=typed_graph.Context(n_graph=np.array([1]), features=()),
        nodes=nodes,
        edges=edges)

    return mesh_graph

  def _init_mesh2grid_graph(self) -> typed_graph.TypedGraph:
    """Build Mesh2Grid graph.

    Edges go from finest-mesh nodes to the grid nodes contained by their mesh
    triangles.
    """

    # Create some edges according to how the grid nodes are contained by
    # mesh triangles.
    # Same precondition as `_init_grid2mesh_graph`: grid properties must exist.
    assert self._grid_lat is not None and self._grid_lon is not None
    (grid_indices,
     mesh_indices) = grid_mesh_connectivity.in_mesh_triangle_indices(
         grid_latitude=self._grid_lat,
         grid_longitude=self._grid_lon,
         mesh=self._finest_mesh)

    # Edges sending info from mesh to grid.
    senders = mesh_indices
    receivers = grid_indices

    # Precompute structural node and edge features according to config options.
    assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
    (senders_node_features, receivers_node_features,
     edge_features) = model_utils.get_bipartite_graph_spatial_features(
         senders_node_lat=self._mesh_nodes_lat,
         senders_node_lon=self._mesh_nodes_lon,
         receivers_node_lat=self._grid_nodes_lat,
         receivers_node_lon=self._grid_nodes_lon,
         senders=senders,
         receivers=receivers,
         edge_normalization_factor=self._mesh2grid_edge_normalization_factor,
         **self._spatial_features_kwargs,
     )

    n_grid_node = np.array([self._num_grid_nodes])
    n_mesh_node = np.array([self._num_mesh_nodes])
    n_edge = np.array([senders.shape[0]])
    grid_node_set = typed_graph.NodeSet(
        n_node=n_grid_node, features=receivers_node_features)
    mesh_node_set = typed_graph.NodeSet(
        n_node=n_mesh_node, features=senders_node_features)
    edge_set = typed_graph.EdgeSet(
        n_edge=n_edge,
        indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
        features=edge_features)
    nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
    edges = {
        typed_graph.EdgeSetKey("mesh2grid", ("mesh_nodes", "grid_nodes")):
            edge_set
    }
    mesh2grid_graph = typed_graph.TypedGraph(
        context=typed_graph.Context(n_graph=np.array([1]), features=()),
        nodes=nodes,
        edges=edges)
    return mesh2grid_graph

  def _run_grid2mesh_gnn(self, grid_node_features: chex.Array,
                         ) -> tuple[chex.Array, chex.Array]:
    """Runs the grid2mesh_gnn, extracting latent mesh and grid nodes."""

    # Concatenate node structural features with input features.
    batch_size = grid_node_features.shape[1]

    grid2mesh_graph = self._grid2mesh_graph_structure
    assert grid2mesh_graph is not None
    grid_nodes = grid2mesh_graph.nodes["grid_nodes"]
    mesh_nodes = grid2mesh_graph.nodes["mesh_nodes"]
    new_grid_nodes = grid_nodes._replace(
        features=jnp.concatenate([
            grid_node_features,
            _add_batch_second_axis(
                grid_nodes.features.astype(grid_node_features.dtype),
                batch_size)
        ],
                                 axis=-1))

    # To make sure capacity of the embedder is identical for the grid nodes and
    # the mesh nodes, we also append some dummy zero input features for the
    # mesh nodes.
    dummy_mesh_node_features = jnp.zeros(
        (self._num_mesh_nodes,) + grid_node_features.shape[1:],
        dtype=grid_node_features.dtype)
    new_mesh_nodes = mesh_nodes._replace(
        features=jnp.concatenate([
            dummy_mesh_node_features,
            _add_batch_second_axis(
                mesh_nodes.features.astype(dummy_mesh_node_features.dtype),
                batch_size)
        ],
                                 axis=-1))

    # Broadcast edge structural features to the required batch size.
    grid2mesh_edges_key = grid2mesh_graph.edge_key_by_name("grid2mesh")
    edges = grid2mesh_graph.edges[grid2mesh_edges_key]

    new_edges = edges._replace(
        features=_add_batch_second_axis(
            edges.features.astype(dummy_mesh_node_features.dtype), batch_size))

    # Use the already-checked local, as the other `_run_*` methods do.
    input_graph = grid2mesh_graph._replace(
        edges={grid2mesh_edges_key: new_edges},
        nodes={
            "grid_nodes": new_grid_nodes,
            "mesh_nodes": new_mesh_nodes
        })

    # Run the GNN.
    grid2mesh_out = self._grid2mesh_gnn(input_graph)
    latent_mesh_nodes = grid2mesh_out.nodes["mesh_nodes"].features
    latent_grid_nodes = grid2mesh_out.nodes["grid_nodes"].features
    return latent_mesh_nodes, latent_grid_nodes

  def _run_mesh_gnn(self, latent_mesh_nodes: chex.Array) -> chex.Array:
    """Runs the mesh_gnn, extracting updated latent mesh nodes."""

    # Add the structural edge features of this graph. Note we don't need
    # to add the structural node features, because these are already part of
    # the latent state, via the original Grid2Mesh gnn, however, we need
    # the edge ones, because it is the first time we are seeing this particular
    # set of edges.
    batch_size = latent_mesh_nodes.shape[1]

    mesh_graph = self._mesh_graph_structure
    assert mesh_graph is not None
    mesh_edges_key = mesh_graph.edge_key_by_name("mesh")
    edges = mesh_graph.edges[mesh_edges_key]

    # We are assuming here that the mesh gnn uses a single set of edge keys
    # named "mesh" for the edges and that it uses a single set of nodes named
    # "mesh_nodes"
    msg = ("The setup currently requires to only have one kind of edge in the"
           " mesh GNN.")
    assert len(mesh_graph.edges) == 1, msg

    new_edges = edges._replace(
        features=_add_batch_second_axis(
            edges.features.astype(latent_mesh_nodes.dtype), batch_size))

    nodes = mesh_graph.nodes["mesh_nodes"]
    nodes = nodes._replace(features=latent_mesh_nodes)

    input_graph = mesh_graph._replace(
        edges={mesh_edges_key: new_edges}, nodes={"mesh_nodes": nodes})

    # Run the GNN.
    return self._mesh_gnn(input_graph).nodes["mesh_nodes"].features

  def _run_mesh2grid_gnn(self,
                         updated_latent_mesh_nodes: chex.Array,
                         latent_grid_nodes: chex.Array,
                         ) -> chex.Array:
    """Runs the mesh2grid_gnn, extracting the output grid nodes."""

    # Add the structural edge features of this graph. Note we don't need
    # to add the structural node features, because these are already part of
    # the latent state, via the original Grid2Mesh gnn, however, we need
    # the edge ones, because it is the first time we are seeing this particular
    # set of edges.
    batch_size = updated_latent_mesh_nodes.shape[1]

    mesh2grid_graph = self._mesh2grid_graph_structure
    assert mesh2grid_graph is not None
    mesh_nodes = mesh2grid_graph.nodes["mesh_nodes"]
    grid_nodes = mesh2grid_graph.nodes["grid_nodes"]
    new_mesh_nodes = mesh_nodes._replace(features=updated_latent_mesh_nodes)
    new_grid_nodes = grid_nodes._replace(features=latent_grid_nodes)
    mesh2grid_key = mesh2grid_graph.edge_key_by_name("mesh2grid")
    edges = mesh2grid_graph.edges[mesh2grid_key]

    new_edges = edges._replace(
        features=_add_batch_second_axis(
            edges.features.astype(latent_grid_nodes.dtype), batch_size))

    input_graph = mesh2grid_graph._replace(
        edges={mesh2grid_key: new_edges},
        nodes={
            "mesh_nodes": new_mesh_nodes,
            "grid_nodes": new_grid_nodes
        })

    # Run the GNN.
    output_graph = self._mesh2grid_gnn(input_graph)
    output_grid_nodes = output_graph.nodes["grid_nodes"].features

    return output_grid_nodes

  def _inputs_to_grid_node_features(
      self,
      inputs: xarray.Dataset,
      forcings: xarray.Dataset,
      ) -> chex.Array:
    """xarrays -> [num_grid_nodes, batch, num_channels]."""

    # xarray `Dataset` (batch, time, lat, lon, level, multiple vars)
    # to xarray `DataArray` (batch, lat, lon, channels)
    stacked_inputs = model_utils.dataset_to_stacked(inputs)
    stacked_forcings = model_utils.dataset_to_stacked(forcings)
    stacked_inputs = xarray.concat(
        [stacked_inputs, stacked_forcings], dim="channels")

    # xarray `DataArray` (batch, lat, lon, channels)
    # to single numpy array with shape [lat_lon_node, batch, channels]
    grid_xarray_lat_lon_leading = model_utils.lat_lon_to_leading_axes(
        stacked_inputs)
    return xarray_jax.unwrap(grid_xarray_lat_lon_leading.data).reshape(
        (-1,) + grid_xarray_lat_lon_leading.data.shape[2:])

  def _grid_node_outputs_to_prediction(
      self,
      grid_node_outputs: chex.Array,
      targets_template: xarray.Dataset,
      ) -> xarray.Dataset:
    """[num_grid_nodes, batch, num_outputs] -> xarray."""

    # numpy array with shape [lat_lon_node, batch, channels]
    # to xarray `DataArray` (batch, lat, lon, channels)
    assert self._grid_lat is not None and self._grid_lon is not None
    grid_shape = (self._grid_lat.shape[0], self._grid_lon.shape[0])
    grid_outputs_lat_lon_leading = grid_node_outputs.reshape(
        grid_shape + grid_node_outputs.shape[1:])
    dims = ("lat", "lon", "batch", "channels")
    grid_xarray_lat_lon_leading = xarray_jax.DataArray(
        data=grid_outputs_lat_lon_leading,
        dims=dims)
    grid_xarray = model_utils.restore_leading_axes(grid_xarray_lat_lon_leading)

    # xarray `DataArray` (batch, lat, lon, channels)
    # to xarray `Dataset` (batch, one time step, lat, lon, level, multiple vars)
    return model_utils.stacked_to_dataset(
        grid_xarray.variable, targets_template)


def _add_batch_second_axis(data, batch_size):
  """Repeats a 2D array along a new second (batch) axis.

  Args:
    data: Array of shape [leading_dim, trailing_dim].
    batch_size: Size of the inserted batch axis.

  Returns:
    Array of shape [leading_dim, batch_size, trailing_dim].
  """
  # data [leading_dim, trailing_dim]
  assert data.ndim == 2
  ones = jnp.ones([batch_size, 1], dtype=data.dtype)
  return data[:, None] * ones  # [leading_dim, batch, trailing_dim]


def _get_max_edge_distance(mesh):
  """Returns the largest straight-line distance between connected vertices."""
  senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces)
  edge_distances = np.linalg.norm(
      mesh.vertices[senders] - mesh.vertices[receivers], axis=-1)
  return edge_distances.max()