Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Local Area graphs #30

Merged
merged 168 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from 167 commits
Commits
Show all changes
168 commits
Select commit Hold shift + click to select a range
38f8d15
feat: Initial implementation of global graphs
theissenhelen Jun 24, 2024
9dc2cec
add dependencies
theissenhelen Jun 25, 2024
f1fe18f
add cli command
JPXKQX Jun 25, 2024
b8b558d
Ignore .pt files
JPXKQX Jun 25, 2024
7f6f4bd
run pre-commit
JPXKQX Jun 25, 2024
d5f67fd
docstring + log erros
JPXKQX Jun 26, 2024
b12272d
initial tests
JPXKQX Jun 26, 2024
cce5ea6
feat: initial version of AttributeBuilder
theissenhelen Jun 26, 2024
9ba0391
refactor: separate into node edge attribute builders
theissenhelen Jun 26, 2024
9a47184
feat: edge_length moved to edges/attributes.py
JPXKQX Jun 27, 2024
384adc7
remove __init__
JPXKQX Jun 27, 2024
0bc176c
feat: test edge builders
JPXKQX Jun 27, 2024
d16934b
add blank lines
JPXKQX Jun 27, 2024
0f82ea7
dep: hydra-core
JPXKQX Jun 27, 2024
a9c5ada
bugfix (encoder edge lengths) + refector
JPXKQX Jun 27, 2024
66ef5dc
deps: == to >=
JPXKQX Jun 27, 2024
fcd0b74
rename node builder classes
JPXKQX Jun 27, 2024
b28b0ff
fix: tests
JPXKQX Jun 27, 2024
ec8c9c5
feat: support path and dict for `config` argument
JPXKQX Jun 28, 2024
9b9d805
fix: error
JPXKQX Jun 28, 2024
2b67bf3
refactor: naming
theissenhelen Jul 1, 2024
cdeaa03
fix: pre-commit
theissenhelen Jul 1, 2024
f07434c
feat: builders icosahedral
theissenhelen Jun 28, 2024
52403d7
feat: Add icosahedral graph generation
theissenhelen Jun 28, 2024
2ef63c2
refactor: remove create_shere
theissenhelen Jun 28, 2024
1fe76bb
feat: Icosahedral edge builder
theissenhelen Jun 28, 2024
fde0fe6
feat: hexagonal graph generation
theissenhelen Jun 28, 2024
97ef0da
feat: hexagonal builders
theissenhelen Jun 28, 2024
8dcda40
fix: AOI not implemented yet
theissenhelen Jun 28, 2024
63be6af
fix: abstractmethod and renaming
theissenhelen Jul 1, 2024
b175585
chore: add dependencies
theissenhelen Jul 1, 2024
86c5e35
test: add tests for trimesh
theissenhelen Jul 1, 2024
19461a1
test: add tests for hex (h3)
theissenhelen Jul 1, 2024
39ee3ad
fix: imports
theissenhelen Jul 1, 2024
f00fd72
fix: output type
theissenhelen Jul 1, 2024
75a82c8
refactor: delete unused file
theissenhelen Jul 1, 2024
f45b900
refactor: renaming and positioning
theissenhelen Jul 1, 2024
9f2c052
feat: ensure src and dst always the same
theissenhelen Jul 1, 2024
e410bf5
fix: imports
theissenhelen Jul 1, 2024
ef1c110
fix: edge_name not supported
theissenhelen Jul 1, 2024
2e6830f
test: add tests for TriIcosahedralEdges
theissenhelen Jul 1, 2024
a59f5d1
fix: assert missing for Hexagonal edges
theissenhelen Jul 2, 2024
59bac56
test: hexagonal edges
theissenhelen Jul 2, 2024
bd729c9
fix: avoid same name
theissenhelen Jul 2, 2024
9cce37a
feat: LimitedAreaZarrNodes
theissenhelen Jul 3, 2024
745709f
feat: add KNNMaskBuilder for use with LAM
theissenhelen Jul 3, 2024
bc735cd
feat: add KNNMaskBuilder to TriIcosahedral
theissenhelen Jul 3, 2024
03fbf9f
feat: AreaNPZFileNodes
theissenhelen Jul 3, 2024
3731e82
fix: KNNAreaMaskBuilder working with NPZ
theissenhelen Jul 3, 2024
ed64a7e
fix: imports and naming
theissenhelen Jul 3, 2024
1e2c37a
fix: TriIocsahedral working for area masks
theissenhelen Jul 5, 2024
c07c583
feat: debugging purposes
theissenhelen Jul 5, 2024
980ed8d
refactor: rename tests
theissenhelen Jul 5, 2024
3609681
Global Encoder-Processor-Decoder graph (#9)
JPXKQX Jul 5, 2024
a973c2d
fix: attributes as torch.float32
JPXKQX Jul 6, 2024
af111a6
new test: attributes must be float32
JPXKQX Jul 6, 2024
a8a1620
fix typo
JPXKQX Jul 6, 2024
926c75b
Homogeneize base builders
JPXKQX Jul 6, 2024
f342073
improve test docstrings
JPXKQX Jul 6, 2024
9d9fea8
homogeneize (name as class attribute)
JPXKQX Jul 6, 2024
1b20845
new input config
JPXKQX Jul 8, 2024
4e62431
new default
JPXKQX Jul 8, 2024
6825850
Merge branch 'develop' into fix/global-version
JPXKQX Jul 8, 2024
d25c47b
feat: Initial implementation of global graphs
theissenhelen Jun 24, 2024
8f0415e
add cli command
JPXKQX Jun 25, 2024
22fba0e
Ignore .pt files
JPXKQX Jun 25, 2024
2879c7c
run pre-commit
JPXKQX Jun 25, 2024
1ca0633
docstring + log erros
JPXKQX Jun 26, 2024
277c231
initial tests
JPXKQX Jun 26, 2024
7a4e9cb
feat: initial version of AttributeBuilder
theissenhelen Jun 26, 2024
96a6ed5
refactor: separate into node edge attribute builders
theissenhelen Jun 26, 2024
1c88ee8
feat: edge_length moved to edges/attributes.py
JPXKQX Jun 27, 2024
92988d7
remove __init__
JPXKQX Jun 27, 2024
de07488
bugfix (encoder edge lengths) + refector
JPXKQX Jun 27, 2024
c900bfd
feat: support path and dict for `config` argument
JPXKQX Jun 28, 2024
9768204
fix: error
JPXKQX Jun 28, 2024
74ef4cc
refactor: naming
theissenhelen Jul 1, 2024
91e15b3
fix: pre-commit
theissenhelen Jul 1, 2024
66d3aa1
feat: builders icosahedral
theissenhelen Jun 28, 2024
23b3110
feat: Add icosahedral graph generation
theissenhelen Jun 28, 2024
469b56b
refactor: remove create_shere
theissenhelen Jun 28, 2024
4e864c3
feat: Icosahedral edge builder
theissenhelen Jun 28, 2024
d573905
feat: hexagonal graph generation
theissenhelen Jun 28, 2024
edd9c43
feat: hexagonal builders
theissenhelen Jun 28, 2024
b7002e3
fix: AOI not implemented yet
theissenhelen Jun 28, 2024
1d55ef6
fix: abstractmethod and renaming
theissenhelen Jul 1, 2024
7dcfb98
chore: add dependencies
theissenhelen Jul 1, 2024
a35b76f
test: add tests for trimesh
theissenhelen Jul 1, 2024
bfc3e84
test: add tests for hex (h3)
theissenhelen Jul 1, 2024
fc2f707
fix: imports
theissenhelen Jul 1, 2024
7c3dca3
fix: output type
theissenhelen Jul 1, 2024
d627e58
refactor: delete unused file
theissenhelen Jul 1, 2024
735fbc1
refactor: renaming and positioning
theissenhelen Jul 1, 2024
46a2c07
feat: ensure src and dst always the same
theissenhelen Jul 1, 2024
c9e4fde
fix: imports
theissenhelen Jul 1, 2024
7a8b316
fix: edge_name not supported
theissenhelen Jul 1, 2024
02496ec
test: add tests for TriIcosahedralEdges
theissenhelen Jul 1, 2024
59fcad3
fix: assert missing for Hexagonal edges
theissenhelen Jul 2, 2024
b41f88e
test: hexagonal edges
theissenhelen Jul 2, 2024
5a43185
fix: avoid same name
theissenhelen Jul 2, 2024
a0259f8
fix: imports
theissenhelen Jul 8, 2024
7289e32
fix: conflicts
theissenhelen Jul 8, 2024
9a6decc
Merge branch 'develop' into 6-generate-graphs-from-icosahedral-meshes
theissenhelen Jul 8, 2024
4ca717b
update tests
JPXKQX Jul 9, 2024
fb93b8b
Merge branch 'develop' into 6-generate-graphs-from-icosahedral-meshes
theissenhelen Jul 9, 2024
e44ae0d
Merge branch '6-generate-graphs-from-icosahedral-meshes' into 7-local…
theissenhelen Jul 9, 2024
fe8a8e5
Include xhops to hexagonal edges
JPXKQX Jul 9, 2024
5865b35
Merge pull request #14 from ecmwf/hotfix/config_node_name
JPXKQX Jul 9, 2024
633091f
Merge remote-tracking branch 'origin/develop' into 6-generate-graphs-…
theissenhelen Jul 9, 2024
b2820e2
Merge remote-tracking branch 'origin/develop' into 6-generate-graphs-…
theissenhelen Jul 9, 2024
463911c
docs: update docstrings
theissenhelen Jul 9, 2024
ae5a8a7
Merge remote-tracking branch 'origin/develop' into 7-local-area-model…
theissenhelen Jul 9, 2024
b0a35b8
fix: update attribute name
JPXKQX Jul 10, 2024
4f445a9
refactor: rename multiscale nodes
theissenhelen Jul 11, 2024
fa812eb
refactor: rename icosahedral nodes
theissenhelen Jul 11, 2024
c7bd45c
Merge remote-tracking branch 'origin/6-generate-graphs-from-icosahedr…
theissenhelen Jul 11, 2024
d190758
refactor: LimitedArea prefix
theissenhelen Jul 11, 2024
dca3729
feat: add aoi_mask_builder to edge builder
theissenhelen Jul 11, 2024
86a5ce5
Merge remote-tracking branch 'origin/develop' into 7-local-area-model…
JPXKQX Jul 26, 2024
e73c203
Merge branch 'develop' into 7-local-area-modelling-graphs
theissenhelen Aug 5, 2024
3b26473
Merge branch 'develop' into 7-local-area-modelling-graphs
JPXKQX Aug 6, 2024
9d015a3
docs & default values
JPXKQX Aug 6, 2024
41bd3f5
create LimitedAreaHEALPixNodes
JPXKQX Aug 6, 2024
030168c
fix: import HEALPixNodes
JPXKQX Aug 6, 2024
9493933
fix: avoid runtimeError when deleting a key
JPXKQX Aug 6, 2024
986efe7
fix: set config arg to pathlib.Path
JPXKQX Aug 6, 2024
c9c736d
more logging
JPXKQX Aug 6, 2024
6e571b3
create LimitedAreaIcoshaderalNodes
JPXKQX Aug 6, 2024
81a3419
refactor LimiteAreaIcosahedralNodes class
JPXKQX Aug 6, 2024
8460e06
types & docstrings
JPXKQX Aug 6, 2024
c35c4ff
fix(lam): icosahedral nodes in lam
JPXKQX Aug 7, 2024
60118dd
fix: style
JPXKQX Aug 7, 2024
6c0585e
feat: icosahedral & hexagonal lam multiscale edges for lam
JPXKQX Aug 7, 2024
be72653
fixs(docs): typo
JPXKQX Aug 8, 2024
dcbd4ed
fix: remove redundant code
JPXKQX Aug 9, 2024
e2cb2b4
refactor: remove edge attr computation during graph creation
JPXKQX Aug 9, 2024
74533d4
refactor: split builder.py in several files
JPXKQX Aug 9, 2024
045ab09
fix: rename node builder class
JPXKQX Aug 9, 2024
2137f10
Merge branch 'develop' into 7-local-area-modelling-graphs
JPXKQX Aug 9, 2024
18dc2a3
fix: test imports
JPXKQX Aug 9, 2024
fee2160
Updated CHANGELOG.md
JPXKQX Aug 9, 2024
6a19a62
fix: imports
JPXKQX Aug 9, 2024
b4df296
fix: move masks.py to generate/
JPXKQX Aug 9, 2024
5d54f2f
refactor: update resolutions argument to resolution
JPXKQX Aug 9, 2024
c3dc568
Merge branch 'develop' into 7-local-area-modelling-graphs
JPXKQX Aug 9, 2024
e641d74
Merge branch 'develop' into 7-local-area-modelling-graphs
JPXKQX Aug 15, 2024
b9ae468
Merge 'develop' branch into 7-local-area-modelling-graphs
JPXKQX Aug 19, 2024
ddef154
fix: import annotations (py3.9)
JPXKQX Aug 19, 2024
6acfc32
tests: new tests for CutOutZarrDatasetNodes
JPXKQX Aug 19, 2024
36aa407
test: new test for KNNAreaMaskBuilder
JPXKQX Aug 19, 2024
d6a8a20
feat: mask_attr_name should be optional
JPXKQX Aug 21, 2024
7a53363
refactor: KNNAreaMAskBuilder
JPXKQX Aug 21, 2024
8383496
Merge branch 'develop' into 7-local-area-modelling-graphs
JPXKQX Aug 23, 2024
19e38b0
docs: remove duplication
theissenhelen Sep 2, 2024
5d247de
refactor: remove second return
theissenhelen Sep 2, 2024
ffb95a0
Merge branch 'develop' into 7-local-area-modelling-graphs
JPXKQX Sep 3, 2024
316cf7a
Merge branch '7-local-area-modelling-graphs' of github.com:ecmwf/anem…
JPXKQX Sep 3, 2024
ed6f960
docs: output folder missing
theissenhelen Sep 3, 2024
23a1a41
docs: update changelog
theissenhelen Sep 3, 2024
af1e520
docs: add PR to changelog
theissenhelen Sep 11, 2024
3b65701
docs: added missing PRs in changelog
theissenhelen Sep 11, 2024
524610a
fix: address @mchantry's comments
JPXKQX Sep 12, 2024
24ff1fc
fix: rename according to the refinement
JPXKQX Sep 12, 2024
51af55c
Merge branch '7-local-area-modelling-graphs' of github.com:ecmwf/anem…
JPXKQX Sep 12, 2024
e2601b4
fix: edge case 1 set of nodes with 1 node attribute
JPXKQX Sep 13, 2024
4762f3f
fix: delete repeated code
JPXKQX Sep 16, 2024
fac6f36
[Feature] Support node masking in edge builder (#50)
JPXKQX Sep 17, 2024
b904983
Merge branch 'develop' into 7-local-area-modelling-graphs
JPXKQX Oct 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.3.0...HEAD)

### Added
- New node builder class, CutOutZarrDatasetNodes, to create nodes from 2 datasets. (#30)
- New class, KNNAreaMaskBuilder, to specify Area of Interest (AOI) based on a set of nodes. (#30)
- New node builder classes, LimitedAreaXXXXXNodes, to create nodes within an Area of Interest (AOI). (#30)
- Expanded MultiScaleEdges to support multi-scale connections in limited area graphs. (#30)

## [0.3.0 Anemoi-graphs, minor release](https://github.com/ecmwf/anemoi-graphs/compare/0.2.1...0.3.0) - 2024-09-03

### Added

- HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere.
- HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere

- Inspection tools: interactive plots, and distribution plots of edge & node attributes.

Expand Down
2 changes: 1 addition & 1 deletion docs/usage/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ following command:

.. code:: console

$ anemoi-graphs inspect graph.pt
$ anemoi-graphs inspect graph.pt output_plots

This will generate the following graph:

Expand Down
12 changes: 8 additions & 4 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,14 @@ def generate_graph(self) -> HeteroData:
graph, nodes_cfg.get("attributes", {})
)

for edges_cfg in self.config.edges:
graph = instantiate(edges_cfg.edge_builder, edges_cfg.source_name, edges_cfg.target_name).update_graph(
graph, edges_cfg.get("attributes", {})
)
for edges_cfg in self.config.get("edges", {}):
graph = instantiate(
edges_cfg.edge_builder,
edges_cfg.source_name,
edges_cfg.target_name,
source_mask_attr_name=edges_cfg.get("source_mask_attr_name", None),
target_mask_attr_name=edges_cfg.get("target_mask_attr_name", None),
).update_graph(graph, edges_cfg.get("attributes", {}))

return graph

Expand Down
177 changes: 122 additions & 55 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from scipy.sparse import coo_matrix
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage

from anemoi.graphs import EARTH_RADIUS
from anemoi.graphs.generate import hexagonal
from anemoi.graphs.generate import icosahedral
from anemoi.graphs.nodes.builder import HexNodes
from anemoi.graphs.nodes.builder import TriNodes
from anemoi.graphs.generate import hex_icosahedron
JPXKQX marked this conversation as resolved.
Show resolved Hide resolved
from anemoi.graphs.generate import tri_icosahedron
from anemoi.graphs.nodes.builders.from_refined_icosahedron import HexNodes
from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaHexNodes
from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaTriNodes
from anemoi.graphs.nodes.builders.from_refined_icosahedron import TriNodes
from anemoi.graphs.utils import get_grid_reference_distance

LOGGER = logging.getLogger(__name__)
Expand All @@ -26,9 +29,17 @@
class BaseEdgeBuilder(ABC):
"""Base class for edge builders."""

def __init__(self, source_name: str, target_name: str):
def __init__(
self,
source_name: str,
target_name: str,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
):
self.source_name = source_name
self.target_name = target_name
self.source_mask_attr_name = source_mask_attr_name
self.target_mask_attr_name = target_mask_attr_name

@property
def name(self) -> tuple[str, str, str]:
Expand Down Expand Up @@ -117,15 +128,48 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -
"""
graph = self.register_edges(graph)

if attrs_config is None:
return graph

graph = self.register_attributes(graph, attrs_config)
if attrs_config is not None:
graph = self.register_attributes(graph, attrs_config)

return graph


class KNNEdges(BaseEdgeBuilder):
class NodeMaskingMixin:
"""Mixin class for masking source/target nodes when building edges."""

def get_node_coordinates(
self, source_nodes: NodeStorage, target_nodes: NodeStorage
) -> tuple[np.ndarray, np.ndarray]:
"""Get the node coordinates."""
source_coords, target_coords = source_nodes.x.numpy(), target_nodes.x.numpy()

if self.source_mask_attr_name is not None:
source_coords = source_coords[source_nodes[self.source_mask_attr_name].squeeze()]

if self.target_mask_attr_name is not None:
target_coords = target_coords[target_nodes[self.target_mask_attr_name].squeeze()]

return source_coords, target_coords

def undo_masking(self, adj_matrix, source_nodes: NodeStorage, target_nodes: NodeStorage):
if self.target_mask_attr_name is not None:
target_mask = target_nodes[self.target_mask_attr_name].squeeze()
target_mapper = dict(zip(list(range(len(adj_matrix.row))), np.where(target_mask)[0]))
adj_matrix.row = np.vectorize(target_mapper.get)(adj_matrix.row)

if self.source_mask_attr_name is not None:
source_mask = source_nodes[self.source_mask_attr_name].squeeze()
source_mapper = dict(zip(list(range(len(adj_matrix.col))), np.where(source_mask)[0]))
adj_matrix.col = np.vectorize(source_mapper.get)(adj_matrix.col)

if self.source_mask_attr_name is not None or self.target_mask_attr_name is not None:
true_shape = target_nodes.x.shape[0], source_nodes.x.shape[0]
adj_matrix = coo_matrix((adj_matrix.data, (adj_matrix.row, adj_matrix.col)), shape=true_shape)

return adj_matrix


class KNNEdges(BaseEdgeBuilder, NodeMaskingMixin):
"""Computes KNN based edges and adds them to the graph.

Attributes
Expand All @@ -136,6 +180,10 @@ class KNNEdges(BaseEdgeBuilder):
The name of the target nodes.
num_nearest_neighbours : int
Number of nearest neighbours.
source_mask_attr_name : str | None
The name of the source mask attribute to filter edge connections.
target_mask_attr_name : str | None
The name of the target mask attribute to filter edge connections.

Methods
-------
Expand All @@ -147,22 +195,30 @@ class KNNEdges(BaseEdgeBuilder):
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int):
super().__init__(source_name, target_name)
def __init__(
self,
source_name: str,
target_name: str,
num_nearest_neighbours: int,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
):
super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name)
assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer"
assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive"
self.num_nearest_neighbours = num_nearest_neighbours

def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray):
def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
"""Compute the adjacency matrix for the KNN method.

Parameters
----------
source_nodes : np.ndarray
source_nodes : NodeStorage
The source nodes.
target_nodes : np.ndarray
target_nodes : NodeStorage
The target nodes.
"""
source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes)
assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder"
LOGGER.info(
"Using KNN-Edges (with %d nearest neighbours) between %s and %s.",
Expand All @@ -172,16 +228,20 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra
)

nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)
nearest_neighbour.fit(source_nodes.x.numpy())
nearest_neighbour.fit(source_coords)
adj_matrix = nearest_neighbour.kneighbors_graph(
target_nodes.x.numpy(),
target_coords,
n_neighbors=self.num_nearest_neighbours,
mode="distance",
).tocoo()

# Post-process the adjacency matrix. Add masked nodes.
adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes)

return adj_matrix


class CutOffEdges(BaseEdgeBuilder):
class CutOffEdges(BaseEdgeBuilder, NodeMaskingMixin):
"""Computes cut-off based edges and adds them to the graph.

Attributes
Expand All @@ -192,6 +252,10 @@ class CutOffEdges(BaseEdgeBuilder):
The name of the target nodes.
cutoff_factor : float
Factor to multiply the grid reference distance to get the cut-off radius.
source_mask_attr_name : str | None
The name of the source mask attribute to filter edge connections.
target_mask_attr_name : str | None
The name of the target mask attribute to filter edge connections.

Methods
-------
Expand All @@ -203,8 +267,15 @@ class CutOffEdges(BaseEdgeBuilder):
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, cutoff_factor: float):
super().__init__(source_name, target_name)
def __init__(
self,
source_name: str,
target_name: str,
cutoff_factor: float,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
):
super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name)
assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float"
assert cutoff_factor > 0, "Cutoff factor must be positive"
self.cutoff_factor = cutoff_factor
Expand Down Expand Up @@ -248,6 +319,7 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
target_nodes : NodeStorage
The target nodes.
"""
source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes)
LOGGER.info(
"Using CutOff-Edges (with radius = %.1f km) between %s and %s.",
self.radius * EARTH_RADIUS,
Expand All @@ -256,8 +328,12 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
)

nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)
nearest_neighbour.fit(source_nodes.x)
adj_matrix = nearest_neighbour.radius_neighbors_graph(target_nodes.x, radius=self.radius).tocoo()
nearest_neighbour.fit(source_coords)
adj_matrix = nearest_neighbour.radius_neighbors_graph(target_coords, radius=self.radius).tocoo()

# Post-process the adjacency matrix. Add masked nodes.
adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes)

return adj_matrix


Expand All @@ -284,61 +360,52 @@ class MultiScaleEdges(BaseEdgeBuilder):
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, x_hops: int):
VALID_NODES = [TriNodes, HexNodes, LimitedAreaTriNodes, LimitedAreaHexNodes]

def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs):
super().__init__(source_name, target_name)
assert source_name == target_name, f"{self.__class__.__name__} requires source and target nodes to be the same."
assert isinstance(x_hops, int), "Number of x_hops must be an integer"
assert x_hops > 0, "Number of x_hops must be positive"
self.x_hops = x_hops
self.node_type = None

def adjacency_from_tri_nodes(self, source_nodes: NodeStorage):
source_nodes["_nx_graph"] = icosahedral.add_edges_to_nx_graph(
source_nodes["_nx_graph"],
resolutions=source_nodes["_resolutions"],
def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage:
nodes["_nx_graph"] = tri_icosahedron.add_edges_to_nx_graph(
nodes["_nx_graph"],
resolutions=nodes["_resolutions"],
x_hops=self.x_hops,
) # HeteroData refuses to accept None

adjmat = nx.to_scipy_sparse_array(
source_nodes["_nx_graph"], nodelist=list(range(len(source_nodes["_nx_graph"]))), format="coo"
area_mask_builder=nodes.get("_area_mask_builder", None),
)
return adjmat

def adjacency_from_hex_nodes(self, source_nodes: NodeStorage):
return nodes

source_nodes["_nx_graph"] = hexagonal.add_edges_to_nx_graph(
source_nodes["_nx_graph"],
resolutions=source_nodes["_resolutions"],
def add_edges_from_hex_nodes(self, nodes: NodeStorage) -> NodeStorage:
nodes["_nx_graph"] = hex_icosahedron.add_edges_to_nx_graph(
nodes["_nx_graph"],
resolutions=nodes["_resolutions"],
x_hops=self.x_hops,
)

adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo")
return adjmat
return nodes

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
if self.node_type == TriNodes.__name__:
adjmat = self.adjacency_from_tri_nodes(source_nodes)
elif self.node_type == HexNodes.__name__:
adjmat = self.adjacency_from_hex_nodes(source_nodes)
if self.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]:
source_nodes = self.add_edges_from_tri_nodes(source_nodes)
elif self.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]:
source_nodes = self.add_edges_from_hex_nodes(source_nodes)
else:
raise ValueError(f"Invalid node type {self.node_type}")

adjmat = self.post_process_adjmat(source_nodes, adjmat)

return adjmat
adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo")

def post_process_adjmat(self, nodes: NodeStorage, adjmat):
graph_sorted = {node_pos: i for i, node_pos in enumerate(nodes["_node_ordering"])}
sort_func = np.vectorize(graph_sorted.get)
adjmat.row = sort_func(adjmat.row)
adjmat.col = sort_func(adjmat.col)
return adjmat

def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData:
assert (
graph[self.source_name].node_type == TriNodes.__name__
or graph[self.source_name].node_type == HexNodes.__name__
), f"{self.__class__.__name__} requires {TriNodes.__name__} or {HexNodes.__name__}."

self.node_type = graph[self.source_name].node_type
valid_node_names = [n.__name__ for n in self.VALID_NODES]
assert (
self.node_type in valid_node_names
), f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes."

return super().update_graph(graph, attrs_config)
Loading
Loading