diff --git a/src/predictions/profiles_mlcorelib/py_native/id_stitcher/cluster_report.py b/src/predictions/profiles_mlcorelib/py_native/id_stitcher/cluster_report.py index 091158b6..b0d26381 100644 --- a/src/predictions/profiles_mlcorelib/py_native/id_stitcher/cluster_report.py +++ b/src/predictions/profiles_mlcorelib/py_native/id_stitcher/cluster_report.py @@ -52,6 +52,7 @@ def __init__( self.counter = 0 self.logger = logger self.color_map = {} + self.cluster_specific_id_types = set() def get_edges_data(self, node_id: str) -> pd.DataFrame: cluster_query_template = """ @@ -80,9 +81,11 @@ def get_edges_data(self, node_id: str) -> pd.DataFrame: def create_graph_with_metadata(self, edges): G = nx.Graph() for _, row in edges.iterrows(): - G.add_edge(row["id1"], row["id2"]) - G.nodes[row["id1"]]["id_type"] = row["id1_type"] - G.nodes[row["id2"]]["id_type"] = row["id2_type"] + node1 = (row["id1"], row["id1_type"]) + node2 = (row["id2"], row["id2_type"]) + G.add_edge(node1, node2) + G.nodes[node1]["id_type"] = row["id1_type"] + G.nodes[node2]["id_type"] = row["id2_type"] return G def compute_graph_metrics( @@ -190,13 +193,18 @@ def _visualize_small_graph(self, G: nx.Graph, file_path: str): for node, attrs in G.nodes(data=True): color = self._get_node_color(attrs["id_type"], degrees[node], max_degree) net.add_node( - node, + f"{node[0]}
{node[1]}", # complex nodes can't be used as ids in pyvis but they can be used in networkx + label=node[0], color=color, - title=f"ID: {node}\nID-Type: {attrs['id_type']}\nDegree: {degrees[node]}", + title=f"ID: {node[0]}\nID-Type: {attrs['id_type']}\nDegree: {degrees[node]}", ) for source, target in G.edges(): - net.add_edge(source, target, color="#888888") + net.add_edge( + f"{source[0]}
{source[1]}", + f"{target[0]}
{target[1]}", + color="#888888", + ) net.set_options( """ @@ -262,18 +270,24 @@ def _visualize_large_graph(self, G: nx.Graph, file_path: str): color = self._get_node_color(attrs["id_type"], degrees[node], max_degree) size = 5 + (degrees[node] / max_degree) * 15 # Smaller node sizes net.add_node( - node, + f"{node[0]}
{node[1]}", # complex nodes can't be used as ids in pyvis but they can be used in networkx + label=node[0], x=int(x), y=int(y), physics=False, # Disable physics for pre-positioned nodes size=size, color=color, - title=f"ID: {node}\nID-Type: {attrs['id_type']}\nDegree: {degrees[node]}", + title=f"ID: {node[0]}\nID-Type: {attrs['id_type']}\nDegree: {degrees[node]}", ) print("Adding edges...") for source, target in G.edges(): - net.add_edge(source, target, color="#88888844", width=0.5) + net.add_edge( + f"{source[0]}
{source[1]}", + f"{target[0]}
{target[1]}", + color="#88888844", + width=0.5, + ) net.set_options( """ @@ -317,10 +331,6 @@ def _visualize_large_graph(self, G: nx.Graph, file_path: str): def _pre_compute_graph_info(self, G: nx.Graph): degrees = dict(G.degree()) max_degree = max(degrees.values()) if degrees else 0 - id_types = set(nx.get_node_attributes(G, "id_type").values()) - - if not self.color_map: - self.color_map = self._generate_color_map(id_types) return degrees, max_degree def _initialise_network(self): @@ -345,6 +355,7 @@ def _add_legend_to_file(self, file_path): """ for id_type, color in self.color_map.items() + if id_type in self.cluster_specific_id_types ] legend_html = f""" @@ -382,6 +393,9 @@ def run(self): "You can explore specific clusters by entering an ID to see how the other ids are all connected and the cluster is formed." ) print("The ID can be either the main ID or any other ID type.") + self.color_map = self._generate_color_map( + self.table_report.analysis_results["node_types"] + ) output_dir = os.path.join(os.getcwd(), "graph_outputs") while True: user_input = self.reader.get_input( @@ -394,6 +408,9 @@ def run(self): metrics, G = self._analyse_cluster(user_input) if metrics is None: continue + self.cluster_specific_id_types = set( + nx.get_node_attributes(G, "id_type").values() + ) cluster_summary = self.get_cluster_summary(metrics) self.counter += 1 if metrics.get("num_nodes", 0) > 1000: