diff --git a/src/predictions/profiles_mlcorelib/py_native/id_stitcher/cluster_report.py b/src/predictions/profiles_mlcorelib/py_native/id_stitcher/cluster_report.py
index 091158b6..b0d26381 100644
--- a/src/predictions/profiles_mlcorelib/py_native/id_stitcher/cluster_report.py
+++ b/src/predictions/profiles_mlcorelib/py_native/id_stitcher/cluster_report.py
@@ -52,6 +52,7 @@ def __init__(
self.counter = 0
self.logger = logger
self.color_map = {}
+ self.cluster_specific_id_types = set()
def get_edges_data(self, node_id: str) -> pd.DataFrame:
cluster_query_template = """
@@ -80,9 +81,11 @@ def get_edges_data(self, node_id: str) -> pd.DataFrame:
def create_graph_with_metadata(self, edges):
G = nx.Graph()
for _, row in edges.iterrows():
- G.add_edge(row["id1"], row["id2"])
- G.nodes[row["id1"]]["id_type"] = row["id1_type"]
- G.nodes[row["id2"]]["id_type"] = row["id2_type"]
+ node1 = (row["id1"], row["id1_type"])
+ node2 = (row["id2"], row["id2_type"])
+ G.add_edge(node1, node2)
+ G.nodes[node1]["id_type"] = row["id1_type"]
+ G.nodes[node2]["id_type"] = row["id2_type"]
return G
def compute_graph_metrics(
@@ -190,13 +193,18 @@ def _visualize_small_graph(self, G: nx.Graph, file_path: str):
for node, attrs in G.nodes(data=True):
color = self._get_node_color(attrs["id_type"], degrees[node], max_degree)
net.add_node(
- node,
+ f"{node[0]}
{node[1]}", # complex nodes can't be used as ids in pyvis but they can be used in networkx
+ label=node[0],
color=color,
- title=f"ID: {node}\nID-Type: {attrs['id_type']}\nDegree: {degrees[node]}",
+ title=f"ID: {node[0]}\nID-Type: {attrs['id_type']}\nDegree: {degrees[node]}",
)
for source, target in G.edges():
- net.add_edge(source, target, color="#888888")
+ net.add_edge(
+ f"{source[0]}
{source[1]}",
+ f"{target[0]}
{target[1]}",
+ color="#888888",
+ )
net.set_options(
"""
@@ -262,18 +270,24 @@ def _visualize_large_graph(self, G: nx.Graph, file_path: str):
color = self._get_node_color(attrs["id_type"], degrees[node], max_degree)
size = 5 + (degrees[node] / max_degree) * 15 # Smaller node sizes
net.add_node(
- node,
+ f"{node[0]}
{node[1]}", # complex nodes can't be used as ids in pyvis but they can be used in networkx
+ label=node[0],
x=int(x),
y=int(y),
physics=False, # Disable physics for pre-positioned nodes
size=size,
color=color,
- title=f"ID: {node}\nID-Type: {attrs['id_type']}\nDegree: {degrees[node]}",
+ title=f"ID: {node[0]}\nID-Type: {attrs['id_type']}\nDegree: {degrees[node]}",
)
print("Adding edges...")
for source, target in G.edges():
- net.add_edge(source, target, color="#88888844", width=0.5)
+ net.add_edge(
+ f"{source[0]}
{source[1]}",
+ f"{target[0]}
{target[1]}",
+ color="#88888844",
+ width=0.5,
+ )
net.set_options(
"""
@@ -317,10 +331,6 @@ def _visualize_large_graph(self, G: nx.Graph, file_path: str):
def _pre_compute_graph_info(self, G: nx.Graph):
degrees = dict(G.degree())
max_degree = max(degrees.values()) if degrees else 0
- id_types = set(nx.get_node_attributes(G, "id_type").values())
-
- if not self.color_map:
- self.color_map = self._generate_color_map(id_types)
return degrees, max_degree
def _initialise_network(self):
@@ -345,6 +355,7 @@ def _add_legend_to_file(self, file_path):
"""
for id_type, color in self.color_map.items()
+ if id_type in self.cluster_specific_id_types
]
legend_html = f"""
@@ -382,6 +393,9 @@ def run(self):
"You can explore specific clusters by entering an ID to see how the other ids are all connected and the cluster is formed."
)
print("The ID can be either the main ID or any other ID type.")
+ self.color_map = self._generate_color_map(
+ self.table_report.analysis_results["node_types"]
+ )
output_dir = os.path.join(os.getcwd(), "graph_outputs")
while True:
user_input = self.reader.get_input(
@@ -394,6 +408,9 @@ def run(self):
metrics, G = self._analyse_cluster(user_input)
if metrics is None:
continue
+ self.cluster_specific_id_types = set(
+ nx.get_node_attributes(G, "id_type").values()
+ )
cluster_summary = self.get_cluster_summary(metrics)
self.counter += 1
if metrics.get("num_nodes", 0) > 1000: