# Copyright 2023 DeepMind Technologies Limited. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS-IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Data-structure for storing graphs with typed edges and nodes.""" | |
from typing import NamedTuple, Any, Union, Tuple, Mapping, TypeVar | |
ArrayLike = Union[Any] # np.ndarray, jnp.ndarray, tf.tensor | |
ArrayLikeTree = Union[Any, ArrayLike] # Nest of ArrayLike | |
_T = TypeVar('_T') | |
# All tensors have a "flat_batch_axis", which is similar to the leading | |
# axes of graph_tuples: | |
# * In the case of nodes this is simply a shared node and flat batch axis, with | |
# size corresponding to the total number of nodes in the flattened batch. | |
# * In the case of edges this is simply a shared edge and flat batch axis, with | |
# size corresponding to the total number of edges in the flattened batch. | |
# * In the case of globals this is simply the number of graphs in the flattened | |
# batch. | |
# All shapes may also have any additional leading shape "batch_shape". | |
# Options for building batches are: | |
# * Use a provided "flatten" method that takes a leading `batch_shape` and | |
# it into the flat_batch_axis (this will be useful when using `tf.Dataset` | |
# which supports batching into RaggedTensors, with leading batch shape even | |
# if graphs have different numbers of nodes and edges), so the RaggedBatches | |
# can then be converted into something without ragged dimensions that jax can | |
# use. | |
# * Directly build a "flat batch" using a provided function for batching a list | |
# of graphs (how it is done in `jraph`). | |
class NodeSet(NamedTuple): | |
"""Represents a set of nodes.""" | |
n_node: ArrayLike # [num_flat_graphs] | |
features: ArrayLikeTree # Prev. `nodes`: [num_flat_nodes] + feature_shape | |
class EdgesIndices(NamedTuple): | |
"""Represents indices to nodes adjacent to the edges.""" | |
senders: ArrayLike # [num_flat_edges] | |
receivers: ArrayLike # [num_flat_edges] | |
class EdgeSet(NamedTuple): | |
"""Represents a set of edges.""" | |
n_edge: ArrayLike # [num_flat_graphs] | |
indices: EdgesIndices | |
features: ArrayLikeTree # Prev. `edges`: [num_flat_edges] + feature_shape | |
class Context(NamedTuple): | |
# `n_graph` always contains ones but it is useful to query the leading shape | |
# in case of graphs without any nodes or edges sets. | |
n_graph: ArrayLike # [num_flat_graphs] | |
features: ArrayLikeTree # Prev. `globals`: [num_flat_graphs] + feature_shape | |
class EdgeSetKey(NamedTuple): | |
name: str # Name of the EdgeSet. | |
# Sender node set name and receiver node set name connected by the edge set. | |
node_sets: Tuple[str, str] | |
class TypedGraph(NamedTuple): | |
"""A graph with typed nodes and edges. | |
A typed graph is made of a context, multiple sets of nodes and multiple | |
sets of edges connecting those nodes (as indicated by the EdgeSetKey). | |
""" | |
context: Context | |
nodes: Mapping[str, NodeSet] | |
edges: Mapping[EdgeSetKey, EdgeSet] | |
def edge_key_by_name(self, name: str) -> EdgeSetKey: | |
found_key = [k for k in self.edges.keys() if k.name == name] | |
if len(found_key) != 1: | |
raise KeyError("invalid edge key '{}'. Available edges: [{}]".format( | |
name, ', '.join(x.name for x in self.edges.keys()))) | |
return found_key[0] | |
def edge_by_name(self, name: str) -> EdgeSet: | |
return self.edges[self.edge_key_by_name(name)] | |