From 7be6ceab4ab78de40d60869f42149384abbfa7cc Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 15 Aug 2024 16:37:57 +0000 Subject: [PATCH] feat: add table with statistics of the attributes --- src/anemoi/graphs/descriptor.py | 75 ++++++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git a/src/anemoi/graphs/descriptor.py b/src/anemoi/graphs/descriptor.py index 56e964d..9412f8f 100644 --- a/src/anemoi/graphs/descriptor.py +++ b/src/anemoi/graphs/descriptor.py @@ -1,6 +1,7 @@ import math from itertools import chain from pathlib import Path +from typing import Optional from typing import Union import torch @@ -95,7 +96,53 @@ def get_edge_summary(self) -> list[list]: ) return edge_summary - def describe(self) -> None: + def get_node_attribute_table(self) -> list[list]: + node_attributes = [] + for node_name, node_store in self.graph.node_items(): + node_attr_names = node_store.node_attrs() + node_attr_names.remove("x") # Remove the coordinates from statistics table + for node_attr_name in node_attr_names: + node_attributes.append( + [ + "Node", + node_name, + node_attr_name, + node_store[node_attr_name].min().item(), + node_store[node_attr_name].mean().item(), + node_store[node_attr_name].max().item(), + node_store[node_attr_name].std().item(), + ] + ) + return node_attributes + + def get_edge_attribute_table(self) -> list[list]: + edge_attributes = [] + for (source_name, _, target_name), edge_store in self.graph.edge_items(): + edge_attr_names = edge_store.edge_attrs() + edge_attr_names.remove("edge_index") # Remove the edge index from statistics table + for edge_attr_name in edge_attr_names: + edge_attributes.append( + [ + "Edge", + f"{source_name}-->{target_name}", + edge_attr_name, + edge_store[edge_attr_name].min().item(), + edge_store[edge_attr_name].mean().item(), + edge_store[edge_attr_name].max().item(), + edge_store[edge_attr_name].std().item(), + ] + ) + + return edge_attributes + + def get_attribute_table(self) -> list[list]: + """Get a table with the attributes of the graph.""" + attribute_table = [] + attribute_table.extend(self.get_node_attribute_table()) + attribute_table.extend(self.get_edge_attribute_table()) + return attribute_table + + def describe(self, show_attribute_distributions: Optional[bool] = True) -> None: """Describe the graph.""" print() print(f"📦 Path : {self.path}") @@ -140,5 +187,31 @@ def describe(self) -> None: margin=3, ) ) + print() + if show_attribute_distributions: + print() + print("📊 Attribute distributions") + print() + print( + table( + self.get_attribute_table(), + header=[ + "Type", + "Source", + "Name", + "Min.", + "Mean", + "Max.", + "Std. dev.", + ], + align=["<", "<", ">", ">", ">", ">", ">"], + margin=3, + ) + ) + print() print("🔋 Graph ready.") print() + + +if __name__ == "__main__": + GraphDescriptor("graph.pt").describe()