diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f572c0..92471ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ Keep it human-readable, your future self will thank you! ### Added - ci: hpc-config, CODEOWNERS (#49) +- feat: New node builder class, CutOutZarrDatasetNodes, to create nodes from 2 datasets. (#30) +- feat: New class, KNNAreaMaskBuilder, to specify Area of Interest (AOI) based on a set of nodes. (#30) +- feat: New node builder classes, LimitedAreaXXXXXNodes, to create nodes within an Area of Interest (AOI). (#30) +- feat: Expanded MultiScaleEdges to support multi-scale connections in limited area graphs. (#30) ### Changed - ci: small fixes and updates pre-commit, downsteam-ci (#49) @@ -20,7 +24,7 @@ Keep it human-readable, your future self will thank you! ### Added -- HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere. +- HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere - Inspection tools: interactive plots, and distribution plots of edge & node attributes. diff --git a/docs/usage/getting_started.rst b/docs/usage/getting_started.rst index b4998e1..4c322a2 100644 --- a/docs/usage/getting_started.rst +++ b/docs/usage/getting_started.rst @@ -54,7 +54,7 @@ following command: .. code:: console - $ anemoi-graphs inspect graph.pt + $ anemoi-graphs inspect graph.pt output_plots This will generate the following graph: diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 9f10c81..17a46f7 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -41,10 +41,14 @@ def generate_graph(self) -> HeteroData: graph, nodes_cfg.get("attributes", {}) ) - for edges_cfg in self.config.edges: - graph = instantiate(edges_cfg.edge_builder, edges_cfg.source_name, edges_cfg.target_name).update_graph( - graph, edges_cfg.get("attributes", {}) - ) + for edges_cfg in self.config.get("edges", {}): + graph = instantiate( + edges_cfg.edge_builder, + edges_cfg.source_name, + edges_cfg.target_name, + source_mask_attr_name=edges_cfg.get("source_mask_attr_name", None), + target_mask_attr_name=edges_cfg.get("target_mask_attr_name", None), + ).update_graph(graph, edges_cfg.get("attributes", {})) return graph diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 7f34390..83ea17e 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -9,15 +9,18 @@ import torch from anemoi.utils.config import DotDict from hydra.utils import instantiate +from scipy.sparse import coo_matrix from sklearn.neighbors import NearestNeighbors from torch_geometric.data import HeteroData from torch_geometric.data.storage import NodeStorage from anemoi.graphs import EARTH_RADIUS -from anemoi.graphs.generate import hexagonal -from anemoi.graphs.generate import icosahedral -from anemoi.graphs.nodes.builder import HexNodes -from anemoi.graphs.nodes.builder import TriNodes +from anemoi.graphs.generate import hex_icosahedron +from anemoi.graphs.generate import tri_icosahedron +from anemoi.graphs.nodes.builders.from_refined_icosahedron import HexNodes +from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaHexNodes +from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaTriNodes +from anemoi.graphs.nodes.builders.from_refined_icosahedron import TriNodes from anemoi.graphs.utils import get_grid_reference_distance LOGGER = logging.getLogger(__name__) @@ -26,9 +29,17 @@ class BaseEdgeBuilder(ABC): """Base class for edge builders.""" - def __init__(self, source_name: str, target_name: str): + def __init__( + self, + source_name: str, + target_name: str, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): self.source_name = source_name self.target_name = target_name + self.source_mask_attr_name = source_mask_attr_name + self.target_mask_attr_name = target_mask_attr_name @property def name(self) -> tuple[str, str, str]: @@ -117,15 +128,48 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) - """ graph = self.register_edges(graph) - if attrs_config is None: - return graph - - graph = self.register_attributes(graph, attrs_config) + if attrs_config is not None: + graph = self.register_attributes(graph, attrs_config) return graph -class KNNEdges(BaseEdgeBuilder): +class NodeMaskingMixin: + """Mixin class for masking source/target nodes when building edges.""" + + def get_node_coordinates( + self, source_nodes: NodeStorage, target_nodes: NodeStorage + ) -> tuple[np.ndarray, np.ndarray]: + """Get the node coordinates.""" + source_coords, target_coords = source_nodes.x.numpy(), target_nodes.x.numpy() + + if self.source_mask_attr_name is not None: + source_coords = source_coords[source_nodes[self.source_mask_attr_name].squeeze()] + + if self.target_mask_attr_name is not None: + target_coords = target_coords[target_nodes[self.target_mask_attr_name].squeeze()] + + return source_coords, target_coords + + def undo_masking(self, adj_matrix, source_nodes: NodeStorage, target_nodes: NodeStorage): + if self.target_mask_attr_name is not None: + target_mask = target_nodes[self.target_mask_attr_name].squeeze() + target_mapper = dict(zip(list(range(len(adj_matrix.row))), np.where(target_mask)[0])) + adj_matrix.row = np.vectorize(target_mapper.get)(adj_matrix.row) + + if self.source_mask_attr_name is not None: + source_mask = source_nodes[self.source_mask_attr_name].squeeze() + source_mapper = dict(zip(list(range(len(adj_matrix.col))), np.where(source_mask)[0])) + adj_matrix.col = np.vectorize(source_mapper.get)(adj_matrix.col) + + if self.source_mask_attr_name is not None or self.target_mask_attr_name is not None: + true_shape = target_nodes.x.shape[0], source_nodes.x.shape[0] + adj_matrix = coo_matrix((adj_matrix.data, (adj_matrix.row, adj_matrix.col)), shape=true_shape) + + return adj_matrix + + +class KNNEdges(BaseEdgeBuilder, NodeMaskingMixin): """Computes KNN based edges and adds them to the graph. Attributes @@ -136,6 +180,10 @@ class KNNEdges(BaseEdgeBuilder): The name of the target nodes. num_nearest_neighbours : int Number of nearest neighbours. + source_mask_attr_name : str | None + The name of the source mask attribute to filter edge connections. + target_mask_attr_name : str | None + The name of the target mask attribute to filter edge connections. Methods ------- @@ -147,20 +195,27 @@ class KNNEdges(BaseEdgeBuilder): Update the graph with the edges. """ - def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int): - super().__init__(source_name, target_name) + def __init__( + self, + source_name: str, + target_name: str, + num_nearest_neighbours: int, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): + super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name) assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer" assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive" self.num_nearest_neighbours = num_nearest_neighbours - def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray) -> np.ndarray: + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage) -> np.ndarray: """Compute the adjacency matrix for the KNN method. Parameters ---------- - source_nodes : np.ndarray + source_nodes : NodeStorage The source nodes. - target_nodes : np.ndarray + target_nodes : NodeStorage The target nodes. Returns @@ -168,6 +223,7 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra np.ndarray The adjacency matrix. """ + source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes) assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder" LOGGER.info( "Using KNN-Edges (with %d nearest neighbours) between %s and %s.", @@ -177,16 +233,20 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra ) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) - nearest_neighbour.fit(source_nodes.x.numpy()) + nearest_neighbour.fit(source_coords) adj_matrix = nearest_neighbour.kneighbors_graph( - target_nodes.x.numpy(), + target_coords, n_neighbors=self.num_nearest_neighbours, mode="distance", ).tocoo() + + # Post-process the adjacency matrix. Add masked nodes. + adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes) + return adj_matrix -class CutOffEdges(BaseEdgeBuilder): +class CutOffEdges(BaseEdgeBuilder, NodeMaskingMixin): """Computes cut-off based edges and adds them to the graph. Attributes @@ -197,6 +257,10 @@ class CutOffEdges(BaseEdgeBuilder): The name of the target nodes. cutoff_factor : float Factor to multiply the grid reference distance to get the cut-off radius. + source_mask_attr_name : str | None + The name of the source mask attribute to filter edge connections. + target_mask_attr_name : str | None + The name of the target mask attribute to filter edge connections. Methods ------- @@ -208,8 +272,15 @@ class CutOffEdges(BaseEdgeBuilder): Update the graph with the edges. """ - def __init__(self, source_name: str, target_name: str, cutoff_factor: float) -> None: - super().__init__(source_name, target_name) + def __init__( + self, + source_name: str, + target_name: str, + cutoff_factor: float, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ) -> None: + super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name) assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float" assert cutoff_factor > 0, "Cutoff factor must be positive" self.cutoff_factor = cutoff_factor @@ -258,6 +329,7 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor np.ndarray The adjacency matrix. """ + source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes) LOGGER.info( "Using CutOff-Edges (with radius = %.1f km) between %s and %s.", self.radius * EARTH_RADIUS, @@ -266,8 +338,12 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor ) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) - nearest_neighbour.fit(source_nodes.x) - adj_matrix = nearest_neighbour.radius_neighbors_graph(target_nodes.x, radius=self.radius).tocoo() + nearest_neighbour.fit(source_coords) + adj_matrix = nearest_neighbour.radius_neighbors_graph(target_coords, radius=self.radius).tocoo() + + # Post-process the adjacency matrix. Add masked nodes. + adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes) + return adj_matrix @@ -294,61 +370,52 @@ class MultiScaleEdges(BaseEdgeBuilder): Update the graph with the edges. """ - def __init__(self, source_name: str, target_name: str, x_hops: int): + VALID_NODES = [TriNodes, HexNodes, LimitedAreaTriNodes, LimitedAreaHexNodes] + + def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs): super().__init__(source_name, target_name) assert source_name == target_name, f"{self.__class__.__name__} requires source and target nodes to be the same." assert isinstance(x_hops, int), "Number of x_hops must be an integer" assert x_hops > 0, "Number of x_hops must be positive" self.x_hops = x_hops + self.node_type = None - def adjacency_from_tri_nodes(self, source_nodes: NodeStorage): - source_nodes["_nx_graph"] = icosahedral.add_edges_to_nx_graph( - source_nodes["_nx_graph"], - resolutions=source_nodes["_resolutions"], + def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage: + nodes["_nx_graph"] = tri_icosahedron.add_edges_to_nx_graph( + nodes["_nx_graph"], + resolutions=nodes["_resolutions"], x_hops=self.x_hops, - ) # HeteroData refuses to accept None - - adjmat = nx.to_scipy_sparse_array( - source_nodes["_nx_graph"], nodelist=list(range(len(source_nodes["_nx_graph"]))), format="coo" + area_mask_builder=nodes.get("_area_mask_builder", None), ) - return adjmat - def adjacency_from_hex_nodes(self, source_nodes: NodeStorage): + return nodes - source_nodes["_nx_graph"] = hexagonal.add_edges_to_nx_graph( - source_nodes["_nx_graph"], - resolutions=source_nodes["_resolutions"], + def add_edges_from_hex_nodes(self, nodes: NodeStorage) -> NodeStorage: + nodes["_nx_graph"] = hex_icosahedron.add_edges_to_nx_graph( + nodes["_nx_graph"], + resolutions=nodes["_resolutions"], x_hops=self.x_hops, ) - adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo") - return adjmat + return nodes def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): - if self.node_type == TriNodes.__name__: - adjmat = self.adjacency_from_tri_nodes(source_nodes) - elif self.node_type == HexNodes.__name__: - adjmat = self.adjacency_from_hex_nodes(source_nodes) + if self.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]: + source_nodes = self.add_edges_from_tri_nodes(source_nodes) + elif self.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]: + source_nodes = self.add_edges_from_hex_nodes(source_nodes) else: raise ValueError(f"Invalid node type {self.node_type}") - adjmat = self.post_process_adjmat(source_nodes, adjmat) - - return adjmat + adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo") - def post_process_adjmat(self, nodes: NodeStorage, adjmat): - graph_sorted = {node_pos: i for i, node_pos in enumerate(nodes["_node_ordering"])} - sort_func = np.vectorize(graph_sorted.get) - adjmat.row = sort_func(adjmat.row) - adjmat.col = sort_func(adjmat.col) return adjmat def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData: - assert ( - graph[self.source_name].node_type == TriNodes.__name__ - or graph[self.source_name].node_type == HexNodes.__name__ - ), f"{self.__class__.__name__} requires {TriNodes.__name__} or {HexNodes.__name__}." - self.node_type = graph[self.source_name].node_type + valid_node_names = [n.__name__ for n in self.VALID_NODES] + assert ( + self.node_type in valid_node_names + ), f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes." return super().update_graph(graph, attrs_config) diff --git a/src/anemoi/graphs/generate/hexagonal.py b/src/anemoi/graphs/generate/hex_icosahedron.py similarity index 63% rename from src/anemoi/graphs/generate/hexagonal.py rename to src/anemoi/graphs/generate/hex_icosahedron.py index 2a9cfe3..3306164 100644 --- a/src/anemoi/graphs/generate/hexagonal.py +++ b/src/anemoi/graphs/generate/hex_icosahedron.py @@ -3,112 +3,98 @@ import h3 import networkx as nx import numpy as np -from sklearn.metrics.pairwise import haversine_distances +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder +from anemoi.graphs.generate.utils import get_coordinates_ordering -def create_hexagonal_nodes( - resolutions: list[int], - area: dict | None = None, + +def create_hex_nodes( + resolution: int, + area_mask_builder: KNNAreaMaskBuilder | None = None, ) -> tuple[nx.Graph, np.ndarray, list[int]]: - """Creates a global mesh from a refined icosahedro. + """Creates a global mesh from a refined icosahedron. This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each refinement level, a hexagon cell (nodes) has 7 child cells (aperture 7). Parameters ---------- - resolutions : list[int] - Levels of mesh resolution to consider. - area : dict - A region, in GeoJSON data format, to be contained by all cells. Defaults to None, which computes the global - mesh. + resolution : int + Level of mesh resolution to consider. + area_mask_builder : KNNAreaMaskBuilder, optional + KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. Returns ------- graph : networkx.Graph - The specified graph (nodes & edges). + The specified graph (only nodes) sorted by latitude and longitude. coords_rad : np.ndarray The node coordinates (not ordered) in radians. node_ordering : list[int] - Order of the nodes in the graph to be sorted by latitude and longitude. + Order of the node coordinates to be sorted by latitude and longitude. """ - graph = nx.Graph() + nodes = get_nodes_at_resolution(resolution) - area_kwargs = {"area": area} + coords_rad = np.deg2rad(np.array([h3.h3_to_geo(node) for node in nodes])) - for resolution in resolutions: - graph = add_nodes_for_resolution(graph, resolution, **area_kwargs) + node_ordering = get_coordinates_ordering(coords_rad) - coords = np.deg2rad(np.array([h3.h3_to_geo(node) for node in graph.nodes])) + if area_mask_builder is not None: + aoi_mask = area_mask_builder.get_mask(coords_rad) + node_ordering = node_ordering[aoi_mask[node_ordering]] - # Sort nodes by latitude and longitude - node_ordering = np.lexsort(coords.T[::-1], axis=0) + graph = create_nx_graph_from_hex_coords(nodes, node_ordering) - return graph, coords, list(node_ordering) + return graph, coords_rad, list(node_ordering) -def add_nodes_for_resolution( - graph: nx.Graph, - resolution: int, - **area_kwargs: dict | None, -) -> nx.Graph: +def create_nx_graph_from_hex_coords(nodes: list[str], node_ordering: np.ndarray) -> nx.Graph: """Add all nodes at a specified refinement level to a graph. Parameters ---------- - graph : networkx.Graph - The graph to add the nodes. - resolution : int - The H3 refinement level. It can be an integer from 0 to 15. - area_kwargs: dict - Additional arguments to pass to the get_nodes_at_resolution function. + nodes : list[str] + The set of H3 indexes (nodes). + node_ordering: np.ndarray + Order of the node coordinates to be sorted by latitude and longitude. Returns ------- graph : networkx.Graph The graph with the added nodes. """ + graph = nx.Graph() - nodes = get_nodes_at_resolution(resolution, **area_kwargs) - - for idx in nodes: - graph.add_node(idx, hcoords_rad=np.deg2rad(h3.h3_to_geo(idx))) + for node_pos in node_ordering: + h3_idx = nodes[node_pos] + graph.add_node(h3_idx, hcoords_rad=np.deg2rad(h3.h3_to_geo(h3_idx))) return graph def get_nodes_at_resolution( resolution: int, - area: dict | None = None, -) -> set[str]: +) -> list[str]: """Get nodes at a specified refinement level over the entire globe. - If area is not None, it will return the nodes within the specified area - Parameters ---------- resolution : int The H3 refinement level. It can be an integer from 0 to 15. - area : dict - An area as GeoJSON dictionary specifying a polygon. Defaults to None. Returns ------- - nodes : set[str] - The set of H3 indexes at the specified resolution level. + nodes : list[str] + The list of H3 indexes at the specified resolution level. """ - nodes = h3.uncompact(h3.get_res0_indexes(), resolution) if area is None else h3.polyfill(area, resolution) - - # TODO: AOI not used in the current implementation. - - return nodes + return list(h3.uncompact(h3.get_res0_indexes(), resolution)) def add_edges_to_nx_graph( graph: nx.Graph, resolutions: list[int], x_hops: int = 1, - depth_children: int = 1, + depth_children: int = 0, ) -> nx.Graph: """Adds the edges to the graph. @@ -134,11 +120,7 @@ def add_edges_to_nx_graph( """ graph = add_neighbour_edges(graph, resolutions, x_hops) - graph = add_edges_to_children( - graph, - resolutions, - depth_children, - ) + graph = add_edges_to_children(graph, resolutions, depth_children) return graph @@ -147,6 +129,7 @@ def add_neighbour_edges( refinement_levels: tuple[int], x_hops: int = 1, ) -> nx.Graph: + """Adds edges between neighbours at the specified refinement levels.""" for resolution in refinement_levels: nodes = select_nodes_from_graph_at_resolution(graph, resolution) @@ -155,8 +138,8 @@ def add_neighbour_edges( for idx_neighbour in h3.k_ring(idx, k=x_hops) & set(nodes): graph = add_edge( graph, - h3.h3_to_center_child(idx, refinement_levels[-1]), - h3.h3_to_center_child(idx_neighbour, refinement_levels[-1]), + h3.h3_to_center_child(idx, max(refinement_levels)), + h3.h3_to_center_child(idx_neighbour, max(refinement_levels)), ) return graph @@ -177,17 +160,19 @@ def add_edges_to_children( set of refinement levels depth_children : Optional[int], optional The number of resolution levels to consider for the connections of children. Defaults to 1, which includes - connections up to the next resolution level, by default None + connections up to the next resolution level, by default None. Returns ------- - graph : nx.Graph - graph with the added edges + nx.Graph + Graph with the added edges. """ if depth_children is None: depth_children = len(refinement_levels) + elif depth_children == 0: + return graph - for i_level, resolution_parent in enumerate(refinement_levels[0:-1]): + for i_level, resolution_parent in enumerate(list(sorted(refinement_levels))[0:-1]): parent_nodes = select_nodes_from_graph_at_resolution(graph, resolution_parent) for parent_idx in parent_nodes: @@ -203,9 +188,11 @@ def add_edges_to_children( return graph -def select_nodes_from_graph_at_resolution(graph: nx.Graph, resolution: int): - parent_nodes = [node for node in graph.nodes if h3.h3_get_resolution(node) == resolution] - return parent_nodes +def select_nodes_from_graph_at_resolution(graph: nx.Graph, resolution: int) -> set[str]: + """Select nodes from a graph at a specified resolution level.""" + nodes_at_lower_resolution = [n for n in h3.compact(graph.nodes) if h3.h3_get_resolution(n) <= resolution] + nodes_at_resolution = h3.uncompact(nodes_at_lower_resolution, resolution) + return nodes_at_resolution def add_edge( @@ -235,10 +222,6 @@ def add_edge( return graph if source_node_h3_idx != target_node_h3_idx: - source_location = np.deg2rad(h3.h3_to_geo(source_node_h3_idx)) - target_location = np.deg2rad(h3.h3_to_geo(target_node_h3_idx)) - graph.add_edge( - source_node_h3_idx, target_node_h3_idx, weight=haversine_distances([source_location, target_location])[0][1] - ) + graph.add_edge(source_node_h3_idx, target_node_h3_idx) return graph diff --git a/src/anemoi/graphs/generate/masks.py b/src/anemoi/graphs/generate/masks.py new file mode 100644 index 0000000..1b804f0 --- /dev/null +++ b/src/anemoi/graphs/generate/masks.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import logging + +import numpy as np +from sklearn.neighbors import NearestNeighbors +from torch_geometric.data import HeteroData + +from anemoi.graphs import EARTH_RADIUS + +LOGGER = logging.getLogger(__name__) + + +class KNNAreaMaskBuilder: + """Class to build a mask based on distance to masked reference nodes using KNN. + + Attributes + ---------- + nearest_neighbour : NearestNeighbors + Nearest neighbour object to compute the KNN. + margin_radius_km : float + Maximum distance to the reference nodes to consider a node as valid, in kilometers. Defaults to 100 km. + reference_node_name : str + Name of the reference nodes in the graph to consider for the Area Mask. + mask_attr_name : str + Name of a node to attribute to mask the reference nodes, if desired. Defaults to consider all reference nodes. + + Methods + ------- + fit_coords(coords_rad: np.ndarray) + Fit the KNN model to the coordinates in radians. + fit(graph: HeteroData) + Fit the KNN model to the reference nodes. + get_mask(coords_rad: np.ndarray) -> np.ndarray + Get the mask for the nodes based on the distance to the reference nodes. + """ + + def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask_attr_name: str | None = None): + assert isinstance(margin_radius_km, (int, float)), "The margin radius must be a number." + assert margin_radius_km > 0, "The margin radius must be positive." + + self.nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) + self.margin_radius_km = margin_radius_km + self.reference_node_name = reference_node_name + self.mask_attr_name = mask_attr_name + + def get_reference_coords(self, graph: HeteroData) -> np.ndarray: + """Retrive coordinates from the reference nodes.""" + assert ( + self.reference_node_name in graph.node_types + ), f'Reference node "{self.reference_node_name}" not found in the graph.' + + coords_rad = graph[self.reference_node_name].x.numpy() + if self.mask_attr_name is not None: + assert ( + self.mask_attr_name in graph[self.reference_node_name].node_attrs() + ), f'Mask attribute "{self.mask_attr_name}" not found in the reference nodes.' + mask = graph[self.reference_node_name][self.mask_attr_name].squeeze() + coords_rad = coords_rad[mask] + + return coords_rad + + def fit_coords(self, coords_rad: np.ndarray): + """Fit the KNN model to the coordinates in radians.""" + self.nearest_neighbour.fit(coords_rad) + + def fit(self, graph: HeteroData): + """Fit the KNN model to the nodes of interest.""" + # Prepare string for logging + reference_mask_str = self.reference_node_name + if self.mask_attr_name is not None: + reference_mask_str += f" ({self.mask_attr_name})" + + # Fit to the reference nodes + coords_rad = self.get_reference_coords(graph) + self.fit_coords(coords_rad) + + LOGGER.info( + 'Fitting %s with %d reference nodes from "%s".', + self.__class__.__name__, + len(coords_rad), + reference_mask_str, + ) + + def get_mask(self, coords_rad: np.ndarray) -> np.ndarray: + """Compute a mask based on the distance to the reference nodes.""" + + neigh_dists, _ = self.nearest_neighbour.kneighbors(coords_rad, n_neighbors=1) + mask = neigh_dists[:, 0] * EARTH_RADIUS <= self.margin_radius_km + return mask diff --git a/src/anemoi/graphs/generate/icosahedral.py b/src/anemoi/graphs/generate/tri_icosahedron.py similarity index 64% rename from src/anemoi/graphs/generate/icosahedral.py rename to src/anemoi/graphs/generate/tri_icosahedron.py index 357676a..8feb780 100644 --- a/src/anemoi/graphs/generate/icosahedral.py +++ b/src/anemoi/graphs/generate/tri_icosahedron.py @@ -5,48 +5,54 @@ import networkx as nx import numpy as np import trimesh -from sklearn.metrics.pairwise import haversine_distances from sklearn.neighbors import BallTree +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad +from anemoi.graphs.generate.utils import get_coordinates_ordering -def create_icosahedral_nodes( - resolutions: list[int], +def create_tri_nodes( + resolution: int, aoi_mask_builder: KNNAreaMaskBuilder | None = None ) -> tuple[nx.DiGraph, np.ndarray, list[int]]: - """Creates a global mesh following AIFS strategy. + """Creates a global mesh from a refined icosahedron. This method relies on the trimesh python library. Parameters ---------- - resolutions : list[int] - Levels of mesh resolution to consider. + resolution : int + Level of mesh resolution to consider. + aoi_mask_builder : KNNAreaMaskBuilder + KNNAreaMaskBuilder with the cloud of points to limit the mesh area, by default None. Returns ------- graph : networkx.Graph - The specified graph (nodes & edges). + The specified graph (only nodes) sorted by latitude and longitude. coords_rad : np.ndarray The node coordinates (not ordered) in radians. node_ordering : list[int] - Order of the nodes in the graph to be sorted by latitude and longitude. + Order of the node coordinates to be sorted by latitude and longitude. """ - sphere = trimesh.creation.icosphere(subdivisions=resolutions[-1], radius=1.0) + sphere = trimesh.creation.icosphere(subdivisions=resolution, radius=1.0) coords_rad = cartesian_to_latlon_rad(sphere.vertices) - node_ordering = get_node_ordering(coords_rad) + node_ordering = get_coordinates_ordering(coords_rad) - # TODO: AOI mask builder is not used in the current implementation. + if aoi_mask_builder is not None: + aoi_mask = aoi_mask_builder.get_mask(coords_rad) + node_ordering = node_ordering[aoi_mask[node_ordering]] - nx_graph = create_icosahedral_nx_graph_from_coords(coords_rad, node_ordering) + # Creates the graph, with the nodes sorted by latitude and longitude. + nx_graph = create_nx_graph_from_tri_coords(coords_rad, node_ordering) return nx_graph, coords_rad, list(node_ordering) -def create_icosahedral_nx_graph_from_coords(coords_rad: np.ndarray, node_ordering: list[int]): - +def create_nx_graph_from_tri_coords(coords_rad: np.ndarray, node_ordering: np.ndarray) -> nx.DiGraph: + """Creates the networkx graph from the coordinates and the node ordering.""" graph = nx.DiGraph() for i, coords in enumerate(coords_rad[node_ordering]): node_id = node_ordering[i] @@ -57,19 +63,11 @@ def create_icosahedral_nx_graph_from_coords(coords_rad: np.ndarray, node_orderin return graph -def get_node_ordering(coords_rad: np.ndarray) -> np.ndarray: - """Get the node ordering to sort the nodes by latitude and longitude.""" - # Get indices to sort points by lon & lat in radians. - index_latitude = np.argsort(coords_rad[:, 1]) - index_longitude = np.argsort(coords_rad[index_latitude][:, 0])[::-1] - node_ordering = np.arange(coords_rad.shape[0])[index_latitude][index_longitude] - return node_ordering - - def add_edges_to_nx_graph( graph: nx.DiGraph, resolutions: list[int], x_hops: int = 1, + area_mask_builder: KNNAreaMaskBuilder | None = None, ) -> nx.DiGraph: """Adds the edges to the graph. @@ -84,6 +82,8 @@ def add_edges_to_nx_graph( Levels of mesh refinement levels to consider. x_hops : int, optional Number of hops between 2 nodes to consider them neighbours, by default 1. + area_mask_builder : KNNAreaMaskBuilder + NearestNeighbors with the cloud of points to limit the mesh area, by default None. Returns ------- @@ -94,10 +94,10 @@ def add_edges_to_nx_graph( sphere = trimesh.creation.icosphere(subdivisions=resolutions[-1], radius=1.0) vertices_rad = cartesian_to_latlon_rad(sphere.vertices) - node_neighbours = get_neighbours_within_hops(sphere, x_hops) + node_neighbours = get_neighbours_within_hops(sphere, x_hops, valid_nodes=list(graph.nodes)) for idx_node, idx_neighbours in node_neighbours.items(): - add_neigbours_edges(graph, vertices_rad, idx_node, idx_neighbours) + add_neigbours_edges(graph, idx_node, idx_neighbours) tree = BallTree(vertices_rad, metric="haversine") @@ -109,20 +109,25 @@ def add_edges_to_nx_graph( # Get the vertices of the isophere r_vertices_rad = cartesian_to_latlon_rad(r_sphere.vertices) - # TODO AOI mask builder is not used in the current implementation. + # Limit area of mesh points. + if area_mask_builder is not None: + aoi_mask = area_mask_builder.get_mask(vertices_rad) + valid_nodes = np.where(aoi_mask)[0] + else: + valid_nodes = None - node_neighbours = get_neighbours_within_hops(r_sphere, x_hops) + node_neighbours = get_neighbours_within_hops(r_sphere, x_hops, valid_nodes=valid_nodes) _, vertex_mapping_index = tree.query(r_vertices_rad, k=1) for idx_node, idx_neighbours in node_neighbours.items(): - add_neigbours_edges( - graph, r_vertices_rad, idx_node, idx_neighbours, vertex_mapping_index=vertex_mapping_index - ) + add_neigbours_edges(graph, idx_node, idx_neighbours, vertex_mapping_index=vertex_mapping_index) return graph -def get_neighbours_within_hops(tri_mesh: trimesh.Trimesh, x_hops: int) -> dict[int, set[int]]: +def get_neighbours_within_hops( + tri_mesh: trimesh.Trimesh, x_hops: int, valid_nodes: list[int] | None = None +) -> dict[int, set[int]]: """Get the neigbour connections in the graph. Parameters @@ -131,6 +136,8 @@ def get_neighbours_within_hops(tri_mesh: trimesh.Trimesh, x_hops: int) -> dict[i The mesh to consider. x_hops : int Number of hops between 2 nodes to consider them neighbours. + valid_nodes : list[int], optional + The list of valid nodes to consider, by default None. Returns ------- @@ -140,10 +147,12 @@ def get_neighbours_within_hops(tri_mesh: trimesh.Trimesh, x_hops: int) -> dict[i """ edges = tri_mesh.edges_unique - valid_nodes = list(range(len(tri_mesh.vertices))) + if valid_nodes is not None: + edges = edges[np.isin(tri_mesh.edges_unique, valid_nodes).all(axis=1)] + else: + valid_nodes = list(range(len(tri_mesh.vertices))) graph = nx.from_edgelist(edges) - # Get a dictionary of the neighbours within 'x_hops' neighbourhood of each node in the graph neighbours = { i: set(nx.ego_graph(graph, i, radius=x_hops, center=False) if i in graph else []) for i in valid_nodes } @@ -153,7 +162,6 @@ def get_neighbours_within_hops(tri_mesh: trimesh.Trimesh, x_hops: int) -> dict[i def add_neigbours_edges( graph: nx.Graph, - vertices: np.ndarray, node_idx: int, neighbour_indices: Iterable[int], self_loops: bool = False, @@ -165,8 +173,6 @@ def add_neigbours_edges( ---------- graph : nx.Graph The graph. - vertices : np.ndarray - A 2D array of shape (num_vertices, 2) with the planar coordinates of the mesh, in radians. node_idx : int The node considered. neighbour_indices : list[int] @@ -180,10 +186,6 @@ def add_neigbours_edges( if not self_loops and node_idx == neighbour_idx: # no self-loops continue - location_node = vertices[node_idx] - location_neighbour = vertices[neighbour_idx] - edge_length = haversine_distances([location_neighbour, location_node])[0][1] - if vertex_mapping_index is not None: # Use the same method to add edge in all spheres node_neighbour = vertex_mapping_index[neighbour_idx][0] @@ -191,6 +193,6 @@ def add_neigbours_edges( else: node, node_neighbour = node_idx, neighbour_idx - # add edge to the graph (if both source and target nodes are in the graph) + # add edge to the graph if node in graph and node_neighbour in graph: - graph.add_edge(node_neighbour, node, weight=edge_length) + graph.add_edge(node_neighbour, node) diff --git a/src/anemoi/graphs/generate/utils.py b/src/anemoi/graphs/generate/utils.py new file mode 100644 index 0000000..df72a2a --- /dev/null +++ b/src/anemoi/graphs/generate/utils.py @@ -0,0 +1,22 @@ +import numpy as np + + +def get_coordinates_ordering(coords: np.ndarray) -> np.ndarray: + """Sort node coordinates by latitude and longitude. + + Parameters + ---------- + coords : np.ndarray of shape (N, 2) + The node coordinates, with the latitude in the first column and the + longitude in the second column. + + Returns + ------- + np.ndarray + The order of the node coordinates to be sorted by latitude and longitude. + """ + # Get indices to sort points by lon & lat in radians. + index_latitude = np.argsort(coords[:, 1]) + index_longitude = np.argsort(coords[index_latitude][:, 0])[::-1] + node_ordering = np.arange(coords.shape[0])[index_latitude][index_longitude] + return node_ordering diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 7b6b149..b5c919a 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,6 +1,23 @@ -from .builder import HexNodes -from .builder import NPZFileNodes -from .builder import TriNodes -from .builder import ZarrDatasetNodes +from .builders.from_file import CutOutZarrDatasetNodes +from .builders.from_file import LimitedAreaNPZFileNodes +from .builders.from_file import NPZFileNodes +from .builders.from_file import ZarrDatasetNodes +from .builders.from_healpix import HEALPixNodes +from .builders.from_healpix import LimitedAreaHEALPixNodes +from .builders.from_refined_icosahedron import HexNodes +from .builders.from_refined_icosahedron import LimitedAreaHexNodes +from .builders.from_refined_icosahedron import LimitedAreaTriNodes +from .builders.from_refined_icosahedron import TriNodes -__all__ = ["ZarrDatasetNodes", "NPZFileNodes", "TriNodes", "HexNodes"] +__all__ = [ + "ZarrDatasetNodes", + "NPZFileNodes", + "TriNodes", + "HexNodes", + "HEALPixNodes", + "LimitedAreaHEALPixNodes", + "CutOutZarrDatasetNodes", + "LimitedAreaNPZFileNodes", + "LimitedAreaTriNodes", + "LimitedAreaHexNodes", +] diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py deleted file mode 100644 index 12c818c..0000000 --- a/src/anemoi/graphs/nodes/builder.py +++ /dev/null @@ -1,331 +0,0 @@ -from __future__ import annotations - -import logging -from abc import ABC -from abc import abstractmethod -from pathlib import Path - -import numpy as np -import torch -from anemoi.datasets import open_dataset -from anemoi.utils.config import DotDict -from hydra.utils import instantiate -from torch_geometric.data import HeteroData - -from anemoi.graphs.generate.hexagonal import create_hexagonal_nodes -from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes - -LOGGER = logging.getLogger(__name__) - - -class BaseNodeBuilder(ABC): - """Base class for node builders. - - The node coordinates are stored in the `x` attribute of the nodes and they are stored in radians. - - Attributes - ---------- - name : str - name of the nodes, key for the nodes in the HeteroData graph object. - """ - - def __init__(self, name: str) -> None: - self.name = name - - def register_nodes(self, graph: HeteroData) -> None: - """Register nodes in the graph. - - Parameters - ---------- - graph : HeteroData - The graph to register the nodes. - """ - graph[self.name].x = self.get_coordinates() - graph[self.name].node_type = type(self).__name__ - return graph - - def register_attributes(self, graph: HeteroData, config: DotDict = None) -> HeteroData: - """Register attributes in the nodes of the graph specified. - - Parameters - ---------- - graph : HeteroData - The graph to register the attributes. - config : DotDict - The configuration of the attributes. - - Returns - ------- - HeteroData - The graph with the registered attributes. - """ - for attr_name, attr_config in config.items(): - graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) - return graph - - @abstractmethod - def get_coordinates(self) -> torch.Tensor: ... - - def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch.Tensor: - """Reshape latitude and longitude coordinates. - - Parameters - ---------- - latitudes : np.ndarray of shape (N, ) - Latitude coordinates, in degrees. - longitudes : np.ndarray of shape (N, ) - Longitude coordinates, in degrees. - - Returns - ------- - torch.Tensor of shape (N, 2) - A 2D tensor with the coordinates, in radians. - """ - coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) - coords = np.deg2rad(coords) - return torch.tensor(coords, dtype=torch.float32) - - def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) -> HeteroData: - """Update the graph with new nodes. - - Parameters - ---------- - graph : HeteroData - Input graph. - attr_config : DotDict - The configuration of the attributes. - - Returns - ------- - HeteroData - The graph with new nodes included. - """ - graph = self.register_nodes(graph) - - if attr_config is None: - return graph - - graph = self.register_attributes(graph, attr_config) - - return graph - - -class ZarrDatasetNodes(BaseNodeBuilder): - """Nodes from Zarr dataset. - - Attributes - ---------- - dataset : zarr.core.Array - The dataset. - - Methods - ------- - register_nodes(graph) - Register the nodes in the graph. - register_attributes(graph, config) - Register the attributes in the nodes of the graph specified. - update_graph(graph, attr_config) - Update the graph with new nodes and attributes. - """ - - def __init__(self, dataset: DotDict, name: str) -> None: - LOGGER.info("Reading the dataset from %s.", dataset) - self.dataset = open_dataset(dataset) - super().__init__(name) - - def get_coordinates(self) -> torch.Tensor: - """Get the coordinates of the nodes. - - Returns - ------- - torch.Tensor of shape (N, 2) - Coordinates of the nodes. - """ - return self.reshape_coords(self.dataset.latitudes, self.dataset.longitudes) - - -class NPZFileNodes(BaseNodeBuilder): - """Nodes from NPZ defined grids. - - Attributes - ---------- - resolution : str - The resolution of the grid. - grid_definition_path : str - Path to the folder containing the grid definition files. - grid_definition : dict[str, np.ndarray] - The grid definition. - - Methods - ------- - register_nodes(graph) - Register the nodes in the graph. - register_attributes(graph, config) - Register the attributes in the nodes of the graph specified. - update_graph(graph, attr_config) - Update the graph with new nodes and attributes. - """ - - def __init__(self, resolution: str, grid_definition_path: str, name: str) -> None: - """Initialize the NPZFileNodes builder. - - The builder suppose the grids are stored in files with the name `grid-{resolution}.npz`. - - Parameters - ---------- - resolution : str - The resolution of the grid. - grid_definition_path : str - Path to the folder containing the grid definition files. - """ - self.resolution = resolution - self.grid_definition_path = grid_definition_path - self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") - super().__init__(name) - - def get_coordinates(self) -> torch.Tensor: - """Get the coordinates of the nodes. - - Returns - ------- - torch.Tensor of shape (N, 2) - Coordinates of the nodes. - """ - coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) - return coords - - -class IcosahedralNodes(BaseNodeBuilder, ABC): - """Processor mesh based on a triangular mesh. - - It is based on the icosahedral mesh, which is a mesh of triangles that covers the sphere. - - Parameters - ---------- - resolution : list[int] | int - Refinement level of the mesh. - name : str - The name of the nodes. - """ - - def __init__( - self, - resolution: int | list[int], - name: str, - ) -> None: - self.resolutions = list(range(resolution + 1)) if isinstance(resolution, int) else resolution - super().__init__(name) - - def get_coordinates(self) -> torch.Tensor: - self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() - return torch.tensor(coords_rad[self.node_ordering], dtype=torch.float32) - - @abstractmethod - def create_nodes(self) -> np.ndarray: ... - - def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: - graph[self.name]["_resolutions"] = self.resolutions - graph[self.name]["_nx_graph"] = self.nx_graph - graph[self.name]["_node_ordering"] = self.node_ordering - return super().register_attributes(graph, config) - - -class TriNodes(IcosahedralNodes): - """Nodes based on iterative refinements of an icosahedron. - - It depends on the trimesh Python library. - - Attributes - ---------- - resolutions : list[int] - Refinement level of the mesh. - name : str - The name of the nodes. - - Methods - ------- - register_nodes(graph) - Register the nodes in the graph. - register_attributes(graph, config) - Register the attributes in the nodes of the graph specified. - update_graph(graph, attr_config) - Update the graph with new nodes and attributes. - """ - - def create_nodes(self) -> np.ndarray: - return create_icosahedral_nodes(resolutions=self.resolutions) - - -class HexNodes(IcosahedralNodes): - """Nodes based on iterative refinements of an icosahedron. - - It depends on the h3 Python library. - - Attributes - ---------- - resolutions : list[int] - Refinement level of the mesh. - name : str - The name of the nodes. - - Methods - ------- - register_nodes(graph) - Register the nodes in the graph. - register_attributes(graph, config) - Register the attributes in the nodes of the graph specified. - update_graph(graph, attr_config) - Update the graph with new nodes and attributes. - """ - - def create_nodes(self) -> np.ndarray: - return create_hexagonal_nodes(self.resolutions) - - -class HEALPixNodes(BaseNodeBuilder): - """Nodes from HEALPix grid. - - HEALPix is an acronym for Hierarchical Equal Area isoLatitude Pixelization of a sphere. - - Attributes - ---------- - resolution : int - The resolution of the grid. - name : str - The name of the nodes. - - Methods - ------- - register_nodes(graph, name) - Register the nodes in the graph. - register_attributes(graph, name, config) - Register the attributes in the nodes of the graph specified. - update_graph(graph, name, attr_config) - Update the graph with new nodes and attributes. - """ - - def __init__(self, resolution: int, name: str) -> None: - """Initialize the HEALPixNodes builder.""" - self.resolution = resolution - super().__init__(name) - - assert isinstance(resolution, int), "Resolution must be an integer." - assert resolution > 0, "Resolution must be positive." - - def get_coordinates(self) -> torch.Tensor: - """Get the coordinates of the nodes. - - Returns - ------- - torch.Tensor of shape (N, 2) - Coordinates of the nodes. - """ - import healpy as hp - - spatial_res_degrees = hp.nside2resol(2**self.resolution, arcmin=True) / 60 - LOGGER.info(f"Creating HEALPix nodes with resolution {spatial_res_degrees:.2} deg.") - - npix = hp.nside2npix(2**self.resolution) - hpxlon, hpxlat = hp.pix2ang(2**self.resolution, range(npix), nest=True, lonlat=True) - - return self.reshape_coords(hpxlat, hpxlon) diff --git a/src/anemoi/graphs/nodes/builders/__init__.py b/src/anemoi/graphs/nodes/builders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/anemoi/graphs/nodes/builders/base.py b/src/anemoi/graphs/nodes/builders/base.py new file mode 100644 index 0000000..b0e670f --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/base.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod + +import numpy as np +import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from torch_geometric.data import HeteroData + + +class BaseNodeBuilder(ABC): + """Base class for node builders. + + The node coordinates are stored in the `x` attribute of the nodes and they are stored in radians. + + Attributes + ---------- + name : str + name of the nodes, key for the nodes in the HeteroData graph object. + aoi_mask_builder : KNNAreaMaskBuilder + The area of interest mask builder, if any. Defaults to None. + """ + + def __init__(self, name: str) -> None: + self.name = name + self.aoi_mask_builder = None + + def register_nodes(self, graph: HeteroData) -> None: + """Register nodes in the graph. + + Parameters + ---------- + graph : HeteroData + The graph to register the nodes. + """ + graph[self.name].x = self.get_coordinates() + graph[self.name].node_type = type(self).__name__ + return graph + + def register_attributes(self, graph: HeteroData, config: DotDict | None = None) -> HeteroData: + """Register attributes in the nodes of the graph specified. + + Parameters + ---------- + graph : HeteroData + The graph to register the attributes. + config : DotDict + The configuration of the attributes. + + Returns + ------- + HeteroData + The graph with the registered attributes. + """ + for attr_name, attr_config in config.items(): + graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) + + return graph + + @abstractmethod + def get_coordinates(self) -> torch.Tensor: ... + + def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch.Tensor: + """Reshape latitude and longitude coordinates. + + Parameters + ---------- + latitudes : np.ndarray of shape (num_nodes, ) + Latitude coordinates, in degrees. + longitudes : np.ndarray of shape (num_nodes, ) + Longitude coordinates, in degrees. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ + coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2)) + coords = np.deg2rad(coords) + return torch.tensor(coords, dtype=torch.float32) + + def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) -> HeteroData: + """Update the graph with new nodes. + + Parameters + ---------- + graph : HeteroData + Input graph. + attr_config : DotDict + The configuration of the attributes. + + Returns + ------- + HeteroData + The graph with new nodes included. + """ + graph = self.register_nodes(graph) + + if attr_config is None: + return graph + + graph = self.register_attributes(graph, attr_config) + + return graph diff --git a/src/anemoi/graphs/nodes/builders/from_file.py b/src/anemoi/graphs/nodes/builders/from_file.py new file mode 100644 index 0000000..7fa7661 --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/from_file.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import logging +from pathlib import Path + +import numpy as np +import torch +from anemoi.datasets import open_dataset +from anemoi.utils.config import DotDict +from torch_geometric.data import HeteroData + +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder + +LOGGER = logging.getLogger(__name__) + + +class ZarrDatasetNodes(BaseNodeBuilder): + """Nodes from Zarr dataset. + + Attributes + ---------- + dataset : zarr.core.Array + The dataset. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. + """ + + def __init__(self, dataset: DotDict, name: str) -> None: + LOGGER.info("Reading the dataset from %s.", dataset) + self.dataset = open_dataset(dataset) + super().__init__(name) + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ + return self.reshape_coords(self.dataset.latitudes, self.dataset.longitudes) + + +class CutOutZarrDatasetNodes(ZarrDatasetNodes): + """Nodes from Zarr dataset.""" + + def __init__( + self, name: str, lam_dataset: str, forcing_dataset: str, thinning: int = 1, adjust: str = "all" + ) -> None: + dataset_config = { + "cutout": [{"dataset": lam_dataset, "thinning": thinning}, {"dataset": forcing_dataset}], + "adjust": adjust, + } + super().__init__(dataset_config, name) + self.n_cutout, self.n_other = self.dataset.grids + + def register_attributes(self, graph: HeteroData, config: DotDict) -> None: + # this is a mask to cutout the LAM area + graph[self.name]["cutout"] = torch.tensor([True] * self.n_cutout + [False] * self.n_other, dtype=bool).reshape( + (-1, 1) + ) + return super().register_attributes(graph, config) + + +class NPZFileNodes(BaseNodeBuilder): + """Nodes from NPZ defined grids. + + Attributes + ---------- + resolution : str + The resolution of the grid. + grid_definition_path : str + Path to the folder containing the grid definition files. + grid_definition : dict[str, np.ndarray] + The grid definition. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. + """ + + def __init__(self, resolution: str, grid_definition_path: str, name: str) -> None: + """Initialize the NPZFileNodes builder. + + The builder suppose the grids are stored in files with the name `grid-{resolution}.npz`. + + Parameters + ---------- + resolution : str + The resolution of the grid. + grid_definition_path : str + Path to the folder containing the grid definition files. + """ + self.resolution = resolution + self.grid_definition_path = grid_definition_path + self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") + super().__init__(name) + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ + coords = self.reshape_coords(self.grid_definition["latitudes"], self.grid_definition["longitudes"]) + return coords + + +class LimitedAreaNPZFileNodes(NPZFileNodes): + """Nodes from NPZ defined grids using an area of interest.""" + + def __init__( + self, + resolution: str, + grid_definition_path: str, + reference_node_name: str, + name: str, + mask_attr_name: str | None = None, + margin_radius_km: float = 100.0, + ) -> None: + + self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + super().__init__(resolution, grid_definition_path, name) + + def register_nodes(self, graph: HeteroData) -> None: + self.aoi_mask_builder.fit(graph) + return super().register_nodes(graph) + + def get_coordinates(self) -> np.ndarray: + coords = super().get_coordinates() + + LOGGER.info( + "Limiting the processor mesh to a radius of %.2f km from the output mesh.", + self.aoi_mask_builder.margin_radius_km, + ) + aoi_mask = self.aoi_mask_builder.get_mask(coords) + + LOGGER.info("Dropping %d nodes from the processor mesh.", len(aoi_mask) - aoi_mask.sum()) + coords = coords[aoi_mask] + + return coords diff --git a/src/anemoi/graphs/nodes/builders/from_healpix.py b/src/anemoi/graphs/nodes/builders/from_healpix.py new file mode 100644 index 0000000..a8ef080 --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/from_healpix.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import logging + +import numpy as np +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder + +LOGGER = logging.getLogger(__name__) + + +class HEALPixNodes(BaseNodeBuilder): + """Nodes from HEALPix grid. + + HEALPix is an acronym for Hierarchical Equal Area isoLatitude Pixelization of a sphere. + + Attributes + ---------- + resolution : int + The resolution of the grid. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. + """ + + def __init__(self, resolution: int, name: str) -> None: + """Initialize the HEALPixNodes builder.""" + self.resolution = resolution + super().__init__(name) + + assert isinstance(resolution, int), "Resolution must be an integer." + assert resolution > 0, "Resolution must be positive." + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + Coordinates of the nodes, in radians. + """ + import healpy as hp + + spatial_res_degrees = hp.nside2resol(2**self.resolution, arcmin=True) / 60 + LOGGER.info(f"Creating HEALPix nodes with resolution {spatial_res_degrees:.2} deg.") + + npix = hp.nside2npix(2**self.resolution) + hpxlon, hpxlat = hp.pix2ang(2**self.resolution, range(npix), nest=True, lonlat=True) + + return self.reshape_coords(hpxlat, hpxlon) + + +class LimitedAreaHEALPixNodes(HEALPixNodes): + """Nodes from HEALPix grid using an area of interest.""" + + def __init__( + self, + resolution: str, + reference_node_name: str, + name: str, + mask_attr_name: str | None = None, + margin_radius_km: float = 100.0, + ) -> None: + + self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + super().__init__(resolution, name) + + def register_nodes(self, graph: HeteroData) -> None: + self.aoi_mask_builder.fit(graph) + return super().register_nodes(graph) + + def get_coordinates(self) -> np.ndarray: + coords = super().get_coordinates() + + LOGGER.info( + 'Limiting the "%s" nodes to a radius of %.2f km from the nodes of interest.', + self.name, + self.aoi_mask_builder.margin_radius_km, + ) + aoi_mask = self.aoi_mask_builder.get_mask(coords) + + LOGGER.info('Masking out %d nodes from "%s".', len(aoi_mask) - aoi_mask.sum(), self.name) + coords = coords[aoi_mask] + + return coords diff --git a/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py new file mode 100644 index 0000000..830d910 --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/from_refined_icosahedron.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import logging +from abc import ABC +from abc import abstractmethod + +import networkx as nx +import numpy as np +import torch +from anemoi.utils.config import DotDict +from torch_geometric.data import HeteroData + +from anemoi.graphs.generate.hex_icosahedron import create_hex_nodes +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder +from anemoi.graphs.generate.tri_icosahedron import create_tri_nodes +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder + +LOGGER = logging.getLogger(__name__) + + +class IcosahedralNodes(BaseNodeBuilder, ABC): + """Nodes based on iterative refinements of an icosahedron. + + Attributes + ---------- + resolution : list[int] | int + Refinement level of the mesh. + name : str + Name of the nodes. + """ + + def __init__( + self, + resolution: int | list[int], + name: str, + ) -> None: + if isinstance(resolution, int): + self.resolutions = list(range(resolution + 1)) + else: + self.resolutions = resolution + + super().__init__(name) + + def get_coordinates(self) -> torch.Tensor: + """Get the coordinates of the nodes. + + Returns + ------- + torch.Tensor of shape (num_nodes, 2) + A 2D tensor with the coordinates, in radians. + """ + self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() + return torch.tensor(coords_rad[self.node_ordering], dtype=torch.float32) + + @abstractmethod + def create_nodes(self) -> tuple[nx.DiGraph, np.ndarray, list[int]]: ... + + def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: + graph[self.name]["_resolutions"] = self.resolutions + graph[self.name]["_nx_graph"] = self.nx_graph + graph[self.name]["_node_ordering"] = self.node_ordering + graph[self.name]["_aoi_mask_builder"] = self.aoi_mask_builder + return super().register_attributes(graph, config) + + +class LimitedAreaIcosahedralNodes(IcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron using an area of interest. + + Attributes + ---------- + aoi_mask_builder : KNNAreaMaskBuilder + The area of interest mask builder. + """ + + def __init__( + self, + resolution: int | list[int], + reference_node_name: str, + name: str, + mask_attr_name: str | None = None, + margin_radius_km: float = 100.0, + ) -> None: + + super().__init__(resolution, name) + + self.aoi_mask_builder = KNNAreaMaskBuilder(reference_node_name, margin_radius_km, mask_attr_name) + + def register_nodes(self, graph: HeteroData) -> None: + self.aoi_mask_builder.fit(graph) + return super().register_nodes(graph) + + +class TriNodes(IcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron. + + It depends on the trimesh Python library. + """ + + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: + return create_tri_nodes(resolution=max(self.resolutions)) + + +class HexNodes(IcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron. + + It depends on the h3 Python library. + """ + + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: + return create_hex_nodes(resolution=max(self.resolutions)) + + +class LimitedAreaTriNodes(LimitedAreaIcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron using an area of interest. + + It depends on the trimesh Python library. + + Parameters + ---------- + aoi_mask_builder: KNNAreaMaskBuilder + The area of interest mask builder. + """ + + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: + return create_tri_nodes(resolution=max(self.resolutions), aoi_mask_builder=self.aoi_mask_builder) + + +class LimitedAreaHexNodes(LimitedAreaIcosahedralNodes): + """Nodes based on iterative refinements of an icosahedron using an area of interest. + + It depends on the h3 Python library. + + Parameters + ---------- + aoi_mask_builder: KNNAreaMaskBuilder + The area of interest mask builder. + """ + + def create_nodes(self) -> tuple[nx.Graph, np.ndarray, list[int]]: + return create_hex_nodes(resolution=max(self.resolutions), aoi_mask_builder=self.aoi_mask_builder) diff --git a/src/anemoi/graphs/nodes/weights.py b/src/anemoi/graphs/nodes/weights.py deleted file mode 100644 index 25419cc..0000000 --- a/src/anemoi/graphs/nodes/weights.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging -from abc import ABC -from abc import abstractmethod -from typing import Optional - -import numpy as np -import torch -from scipy.spatial import SphericalVoronoi -from torch_geometric.data.storage import NodeStorage - -from anemoi.graphs.generate.transforms import to_sphere_xyz -from anemoi.graphs.normalizer import NormalizerMixin - -logger = logging.getLogger(__name__) - - -class BaseWeights(ABC, NormalizerMixin): - """Base class for the weights of the nodes.""" - - def __init__(self, norm: Optional[str] = None): - self.norm = norm - - @abstractmethod - def compute(self, nodes: NodeStorage, *args, **kwargs): ... - - def get_weights(self, *args, **kwargs) -> torch.Tensor: - weights = self.compute(*args, **kwargs) - if weights.ndim == 1: - weights = weights[:, np.newaxis] - norm_weights = self.normalize(weights) - return torch.tensor(norm_weights, dtype=torch.float32) - - -class UniformWeights(BaseWeights): - """Implements a uniform weight for the nodes.""" - - def compute(self, nodes: NodeStorage) -> np.ndarray: - return np.ones(nodes.num_nodes) - - -class AreaWeights(BaseWeights): - """Implements the area of the nodes as the weights.""" - - def __init__(self, norm: str = "unit-max", radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0])): - super().__init__(norm=norm) - - # Weighting of the nodes - self.radius = radius - self.centre = centre - - def compute(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: - latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1] - points = to_sphere_xyz((latitudes, longitudes)) - sv = SphericalVoronoi(points, self.radius, self.centre) - area_weights = sv.calculate_areas() - logger.debug( - "There are %d of weights, which (unscaled) add up a total weight of %.2f.", - len(area_weights), - np.array(area_weights).sum(), - ) - return area_weights diff --git a/src/anemoi/graphs/plotting/displots.py b/src/anemoi/graphs/plotting/displots.py index 22316f5..5fe37ee 100644 --- a/src/anemoi/graphs/plotting/displots.py +++ b/src/anemoi/graphs/plotting/displots.py @@ -5,6 +5,7 @@ from typing import Union import matplotlib.pyplot as plt +import numpy as np import torch from torch_geometric.data import HeteroData from torch_geometric.data.storage import EdgeStorage @@ -83,7 +84,9 @@ def plot_distribution_attributes( # Define the layout _, axs = plt.subplots(num_items, dim_attrs, figsize=(10 * num_items, 10)) - if axs.ndim == 1: + if num_items == dim_attrs == 1: + axs = np.array([[axs]]) + elif axs.ndim == 1: axs = axs.reshape(num_items, dim_attrs) for i, (item_name, item_store) in enumerate(graph_items): diff --git a/tests/conftest.py b/tests/conftest.py index b801614..23208c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,10 +11,12 @@ class MockZarrDataset: """Mock Zarr dataset with latitudes and longitudes attributes.""" - def __init__(self, latitudes, longitudes): + def __init__(self, latitudes, longitudes, grids=None): self.latitudes = latitudes self.longitudes = longitudes self.num_nodes = len(latitudes) + if grids is not None: + self.grids = grids @pytest.fixture @@ -24,6 +26,14 @@ def mock_zarr_dataset() -> MockZarrDataset: return MockZarrDataset(latitudes=coords[:, 0], longitudes=coords[:, 1]) +@pytest.fixture +def mock_zarr_dataset_cutout() -> MockZarrDataset: + """Mock zarr dataset with nodes.""" + coords = 2 * torch.pi * np.array([[lat, lon] for lat in lats for lon in lons]) + grids = int(0.3 * len(coords)), int(0.7 * len(coords)) + return MockZarrDataset(latitudes=coords[:, 0], longitudes=coords[:, 1], grids=grids) + + @pytest.fixture def mock_grids_path(tmp_path) -> tuple[str, int]: """Mock grid_definition_path with files for 3 resolutions.""" @@ -40,6 +50,7 @@ def graph_with_nodes() -> HeteroData: coords = np.array([[lat, lon] for lat in lats for lon in lons]) graph = HeteroData() graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) + graph["test_nodes"].mask = torch.tensor([True] * len(coords)) return graph @@ -49,6 +60,7 @@ def graph_nodes_and_edges() -> HeteroData: coords = np.array([[lat, lon] for lat in lats for lon in lons]) graph = HeteroData() graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) + graph["test_nodes"].mask = torch.tensor([True] * len(coords)) graph[("test_nodes", "to", "test_nodes")].edge_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) return graph diff --git a/tests/generate/test_masks.py b/tests/generate/test_masks.py new file mode 100644 index 0000000..651bdb7 --- /dev/null +++ b/tests/generate/test_masks.py @@ -0,0 +1,48 @@ +import pytest +from sklearn.neighbors import NearestNeighbors +from torch_geometric.data import HeteroData + +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder + + +def test_init(): + """Test KNNAreaMaskBuilder initialization.""" + mask_builder1 = KNNAreaMaskBuilder("nodes") + mask_builder2 = KNNAreaMaskBuilder("nodes", margin_radius_km=120) + mask_builder3 = KNNAreaMaskBuilder("nodes", mask_attr_name="mask") + mask_builder4 = KNNAreaMaskBuilder("nodes", margin_radius_km=120, mask_attr_name="mask") + + assert isinstance(mask_builder1, KNNAreaMaskBuilder) + assert isinstance(mask_builder2, KNNAreaMaskBuilder) + assert isinstance(mask_builder3, KNNAreaMaskBuilder) + assert isinstance(mask_builder4, KNNAreaMaskBuilder) + + assert isinstance(mask_builder1.nearest_neighbour, NearestNeighbors) + assert isinstance(mask_builder2.nearest_neighbour, NearestNeighbors) + assert isinstance(mask_builder3.nearest_neighbour, NearestNeighbors) + assert isinstance(mask_builder4.nearest_neighbour, NearestNeighbors) + + +@pytest.mark.parametrize("margin", [-1, "120", None]) +def test_fail_init_wrong_margin(margin: int): + """Test KNNAreaMaskBuilder initialization with invalid margin.""" + with pytest.raises(AssertionError): + KNNAreaMaskBuilder("nodes", margin_radius_km=margin) + + +@pytest.mark.parametrize("mask", [None, "mask"]) +def test_fit(graph_with_nodes: HeteroData, mask: str): + """Test KNNAreaMaskBuilder fit.""" + mask_builder = KNNAreaMaskBuilder("test_nodes", mask_attr_name=mask) + assert not hasattr(mask_builder.nearest_neighbour, "n_samples_fit_") + + mask_builder.fit(graph_with_nodes) + + assert mask_builder.nearest_neighbour.n_samples_fit_ == graph_with_nodes["test_nodes"].num_nodes + + +def test_fit_fail(graph_with_nodes): + """Test KNNAreaMaskBuilder fit with wrong graph.""" + mask_builder = KNNAreaMaskBuilder("wrong_nodes") + with pytest.raises(AssertionError): + mask_builder.fit(graph_with_nodes) diff --git a/tests/nodes/test_cutout_nodes.py b/tests/nodes/test_cutout_nodes.py new file mode 100644 index 0000000..a18424d --- /dev/null +++ b/tests/nodes/test_cutout_nodes.py @@ -0,0 +1,56 @@ +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.nodes.attributes import AreaWeights +from anemoi.graphs.nodes.attributes import UniformWeights +from anemoi.graphs.nodes.builders import from_file + + +def test_init(mocker, mock_zarr_dataset_cutout): + """Test CutOutZarrDatasetNodes initialization.""" + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) + node_builder = from_file.CutOutZarrDatasetNodes( + forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" + ) + + assert isinstance(node_builder, from_file.BaseNodeBuilder) + assert isinstance(node_builder, from_file.CutOutZarrDatasetNodes) + + +def test_fail_init(): + """Test CutOutZarrDatasetNodes initialization with invalid resolution.""" + with pytest.raises(TypeError): + from_file.CutOutZarrDatasetNodes("global_dataset.zarr", name="test_nodes") + + +def test_register_nodes(mocker, mock_zarr_dataset_cutout): + """Test CutOutZarrDatasetNodes register correctly the nodes.""" + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) + node_builder = from_file.CutOutZarrDatasetNodes( + forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" + ) + graph = HeteroData() + + graph = node_builder.register_nodes(graph) + + assert graph["test_nodes"].x is not None + assert isinstance(graph["test_nodes"].x, torch.Tensor) + assert graph["test_nodes"].x.shape == (node_builder.dataset.num_nodes, 2) + assert graph["test_nodes"].node_type == "CutOutZarrDatasetNodes" + + +@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) +def test_register_attributes(mocker, mock_zarr_dataset_cutout, graph_with_nodes: HeteroData, attr_class): + """Test CutOutZarrDatasetNodes register correctly the weights.""" + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset_cutout) + node_builder = from_file.CutOutZarrDatasetNodes( + forcing_dataset="global.zarr", lam_dataset="lam.zarr", name="test_nodes" + ) + config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} + + graph = node_builder.register_attributes(graph_with_nodes, config) + + assert graph["test_nodes"]["test_attr"] is not None + assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) + assert graph["test_nodes"]["test_attr"].shape[0] == graph["test_nodes"].x.shape[0] diff --git a/tests/nodes/test_healpix.py b/tests/nodes/test_healpix.py index 3c6883c..3293c1c 100644 --- a/tests/nodes/test_healpix.py +++ b/tests/nodes/test_healpix.py @@ -4,8 +4,8 @@ from anemoi.graphs.nodes.attributes import AreaWeights from anemoi.graphs.nodes.attributes import UniformWeights -from anemoi.graphs.nodes.builder import BaseNodeBuilder -from anemoi.graphs.nodes.builder import HEALPixNodes +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder +from anemoi.graphs.nodes.builders.from_healpix import HEALPixNodes @pytest.mark.parametrize("resolution", [2, 5, 7]) diff --git a/tests/nodes/test_hex_nodes.py b/tests/nodes/test_hex_nodes.py index 03e00da..54c2b6b 100644 --- a/tests/nodes/test_hex_nodes.py +++ b/tests/nodes/test_hex_nodes.py @@ -3,7 +3,7 @@ from torch_geometric.data import HeteroData from anemoi.graphs.nodes import HexNodes -from anemoi.graphs.nodes.builder import BaseNodeBuilder +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder @pytest.mark.parametrize("resolution", [0, 2]) diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index 95d09c0..21b767a 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -4,7 +4,7 @@ from anemoi.graphs.nodes.attributes import AreaWeights from anemoi.graphs.nodes.attributes import UniformWeights -from anemoi.graphs.nodes.builder import NPZFileNodes +from anemoi.graphs.nodes.builders.from_file import NPZFileNodes @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) diff --git a/tests/nodes/test_tri_nodes.py b/tests/nodes/test_tri_nodes.py index af1af69..4f522ce 100644 --- a/tests/nodes/test_tri_nodes.py +++ b/tests/nodes/test_tri_nodes.py @@ -3,7 +3,7 @@ from torch_geometric.data import HeteroData from anemoi.graphs.nodes import TriNodes -from anemoi.graphs.nodes.builder import BaseNodeBuilder +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder @pytest.mark.parametrize("resolution", [0, 2]) diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index e7c98cc..90610ac 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -3,30 +3,30 @@ import zarr from torch_geometric.data import HeteroData -from anemoi.graphs.nodes import builder from anemoi.graphs.nodes.attributes import AreaWeights from anemoi.graphs.nodes.attributes import UniformWeights +from anemoi.graphs.nodes.builders import from_file def test_init(mocker, mock_zarr_dataset): """Test ZarrDatasetNodes initialization.""" - mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset) + node_builder = from_file.ZarrDatasetNodes("dataset.zarr", name="test_nodes") - assert isinstance(node_builder, builder.BaseNodeBuilder) - assert isinstance(node_builder, builder.ZarrDatasetNodes) + assert isinstance(node_builder, from_file.BaseNodeBuilder) + assert isinstance(node_builder, from_file.ZarrDatasetNodes) def test_fail_init(): """Test ZarrDatasetNodes initialization with invalid resolution.""" with pytest.raises(zarr.errors.PathNotFoundError): - builder.ZarrDatasetNodes("invalid_path.zarr", name="test_nodes") + from_file.ZarrDatasetNodes("invalid_path.zarr", name="test_nodes") def test_register_nodes(mocker, mock_zarr_dataset): """Test ZarrDatasetNodes register correctly the nodes.""" - mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") + mocker.patch.object(from_file, "open_dataset", return_value=mock_zarr_dataset) + node_builder = from_file.ZarrDatasetNodes("dataset.zarr", name="test_nodes") graph = HeteroData() graph = node_builder.register_nodes(graph) @@ -40,8 +40,8 @@ def test_register_nodes(mocker, mock_zarr_dataset): @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) def test_register_attributes(mocker, graph_with_nodes: HeteroData, attr_class): """Test ZarrDatasetNodes register correctly the weights.""" - mocker.patch.object(builder, "open_dataset", return_value=None) - node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") + mocker.patch.object(from_file, "open_dataset", return_value=None) + node_builder = from_file.ZarrDatasetNodes("dataset.zarr", name="test_nodes") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} graph = node_builder.register_attributes(graph_with_nodes, config)