Skip to content

Commit

Permalink
feat: add table with statistics of the attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Aug 15, 2024
1 parent 2906fd1 commit 7be6cea
Showing 1 changed file with 74 additions and 1 deletion.
75 changes: 74 additions & 1 deletion src/anemoi/graphs/descriptor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
from itertools import chain
from pathlib import Path
from typing import Optional
from typing import Union

import torch
Expand Down Expand Up @@ -95,7 +96,53 @@ def get_edge_summary(self) -> list[list]:
)
return edge_summary

def describe(self) -> None:
def get_node_attribute_table(self) -> list[list]:
node_attributes = []
for node_name, node_store in self.graph.node_items():
node_attr_names = node_store.node_attrs()
node_attr_names.remove("x") # Remove the coordinates from statistics table
for node_attr_name in node_attr_names:
node_attributes.append(
[
"Node",
node_name,
node_attr_name,
node_store[node_attr_name].min().item(),
node_store[node_attr_name].mean().item(),
node_store[node_attr_name].max().item(),
node_store[node_attr_name].std().item(),
]
)
return node_attributes

def get_edge_attribute_table(self) -> list[list]:
edge_attributes = []
for (source_name, _, target_name), edge_store in self.graph.edge_items():
edge_attr_names = edge_store.edge_attrs()
edge_attr_names.remove("edge_index") # Remove the edge index from statistics table
for edge_attr_name in edge_attr_names:
edge_attributes.append(
[
"Edge",
f"{source_name}-->{target_name}",
edge_attr_name,
edge_store[edge_attr_name].min().item(),
edge_store[edge_attr_name].mean().item(),
edge_store[edge_attr_name].max().item(),
edge_store[edge_attr_name].std().item(),
]
)

return edge_attributes

def get_attribute_table(self) -> list[list]:
"""Get a table with the attributes of the graph."""
attribute_table = []
attribute_table.extend(self.get_node_attribute_table())
attribute_table.extend(self.get_edge_attribute_table())
return attribute_table

def describe(self, show_attribute_distributions: Optional[bool] = True) -> None:
"""Describe the graph."""
print()
print(f"📦 Path : {self.path}")
Expand Down Expand Up @@ -140,5 +187,31 @@ def describe(self) -> None:
margin=3,
)
)
print()
if show_attribute_distributions:
print()
print("📊 Attribute distributions")
print()
print(
table(
self.get_attribute_table(),
header=[
"Type",
"Source",
"Name",
"Min.",
"Mean",
"Max.",
"Std. dev.",
],
align=["<", "<", ">", ">", ">", ">", ">"],
margin=3,
)
)
print()
print("🔋 Graph ready.")
print()


if __name__ == "__main__":
GraphDescriptor("graph.pt").describe()

0 comments on commit 7be6cea

Please sign in to comment.