diff --git a/grace/base.py b/grace/base.py index 0f9c6bf..3cc6d03 100644 --- a/grace/base.py +++ b/grace/base.py @@ -19,8 +19,8 @@ class GraphAttrs(str, enum.Enum): NODE_Y = "y" NODE_GROUND_TRUTH = "node_ground_truth" NODE_PREDICTION = "node_prediction" - NODE_FEATURES = "image_patch_features" - NODE_EMBEDDINGS = "laplacian_embeddings" + NODE_IMG_EMBEDDING = "node_image_patch_latent_embedding" + NODE_ENV_EMBEDDING = "laplacian_matrix_multiplied_embedding" NODE_CONFIDENCE = "confidence" EDGE_SOURCE = "source" EDGE_TARGET = "target" diff --git a/grace/io/schema.py b/grace/io/schema.py index 0a4e667..9eab356 100644 --- a/grace/io/schema.py +++ b/grace/io/schema.py @@ -9,7 +9,7 @@ pa.field(GraphAttrs.NODE_Y, pa.float32()), pa.field(GraphAttrs.NODE_GROUND_TRUTH, pa.int64()), pa.field(GraphAttrs.NODE_CONFIDENCE, pa.float32()), - pa.field(GraphAttrs.NODE_FEATURES, pa.list_(pa.float32())), + pa.field(GraphAttrs.NODE_IMG_EMBEDDING, pa.list_(pa.float32())), ], # metadata={"year": "2023"} ) diff --git a/grace/io/store_node_features.py b/grace/io/store_node_features.py index 55a322f..13ae020 100644 --- a/grace/io/store_node_features.py +++ b/grace/io/store_node_features.py @@ -40,8 +40,8 @@ def store_node_features_in_graph( graph = target["graph"] for _, node in graph.nodes(data=True): - node[GraphAttrs.NODE_FEATURES] = node[ - GraphAttrs.NODE_FEATURES + node[GraphAttrs.NODE_IMG_EMBEDDING] = node[ + GraphAttrs.NODE_IMG_EMBEDDING ].numpy() write_graph( diff --git a/grace/models/datasets.py b/grace/models/datasets.py index c7a4b01..b0f3066 100644 --- a/grace/models/datasets.py +++ b/grace/models/datasets.py @@ -154,8 +154,8 @@ def _x(graph: nx.Graph) -> torch.Tensor: [ np.concatenate( [ - graph.nodes[idx][GraphAttrs.NODE_FEATURES], - graph.nodes[idx][GraphAttrs.NODE_EMBEDDINGS], + graph.nodes[idx][GraphAttrs.NODE_IMG_EMBEDDING], + graph.nodes[idx][GraphAttrs.NODE_ENV_EMBEDDING], ], axis=-1, ) diff --git a/grace/models/feature_extractor.py b/grace/models/feature_extractor.py index 9b42581..a682104 100644 --- a/grace/models/feature_extractor.py +++ b/grace/models/feature_extractor.py @@ -131,7 +131,7 @@ def forward( # Run through the feature extractor model: features = self.model(bbox_image) - node_attrs[GraphAttrs.NODE_FEATURES] = features.squeeze() + node_attrs[GraphAttrs.NODE_IMG_EMBEDDING] = features.squeeze() # Back to the original shape: image = image.reshape(image_shape) @@ -192,7 +192,7 @@ def forward( # Run through these shortlisted torch methods: methods = [torch.mean, torch.std, torch.min, torch.max] features = torch.Tensor([m(bbox_image) for m in methods]) - node_attrs[GraphAttrs.NODE_FEATURES] = features.squeeze() + node_attrs[GraphAttrs.NODE_IMG_EMBEDDING] = features.squeeze() # Back to the original shape: image = image.reshape(image_shape) diff --git a/grace/models/graph_laplacian.py b/grace/models/graph_laplacian.py index 11dbba0..2fe4520 100644 --- a/grace/models/graph_laplacian.py +++ b/grace/models/graph_laplacian.py @@ -18,7 +18,7 @@ def calculate_graph_laplacian(self) -> None: def extract_node_features(self) -> npt.NDArray: feature_matrix = np.stack( [ - n[GraphAttrs.NODE_FEATURES] + n[GraphAttrs.NODE_IMG_EMBEDDING] for _, n in self.graph.nodes(data=True) ], axis=0, @@ -35,7 +35,7 @@ def transform_feature_embeddings(self) -> nx.Graph: # Append node attributes to the graph: for node_idx, node in self.graph.nodes(data=True): - node[GraphAttrs.NODE_EMBEDDINGS] = torch.Tensor( + node[GraphAttrs.NODE_ENV_EMBEDDING] = torch.Tensor( embedded_matrix[:, node_idx] ) diff --git a/tests/conftest.py b/tests/conftest.py index 1274a5b..f5956a7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,7 +31,7 @@ def create_nodes( { GraphAttrs.NODE_X: node_coords[:, 0], GraphAttrs.NODE_Y: node_coords[:, 1], - GraphAttrs.NODE_FEATURES: features, + GraphAttrs.NODE_IMG_EMBEDDING: features, GraphAttrs.NODE_GROUND_TRUTH: node_ground_truth, GraphAttrs.NODE_CONFIDENCE: rng.uniform( size=(num_nodes), @@ -104,7 +104,7 @@ def simple_graph_dataframe(default_rng) -> pd.DataFrame: { GraphAttrs.NODE_X: [0.0, 1.0, 2.0], GraphAttrs.NODE_Y: [0.0, 1.0, 0.0], - GraphAttrs.NODE_FEATURES: features, + GraphAttrs.NODE_IMG_EMBEDDING: features, GraphAttrs.NODE_GROUND_TRUTH: [1, 1, 1], GraphAttrs.NODE_CONFIDENCE: [0.9, 0.1, 0.8], } diff --git a/tests/test_extractor.py b/tests/test_extractor.py index 448579c..c6859fb 100644 --- a/tests/test_extractor.py +++ b/tests/test_extractor.py @@ -40,7 +40,7 @@ def test_feature_extractor_forward(self, bbox_size, model, vars): for _, node_attrs in graph_out.nodes.data(): x, y = node_attrs[GraphAttrs.NODE_X], node_attrs[GraphAttrs.NODE_Y] - features = node_attrs[GraphAttrs.NODE_FEATURES] + features = node_attrs[GraphAttrs.NODE_IMG_EMBEDDING] x_low = int(x - bbox_size[0] / 2) x_box = slice(x_low, x_low + bbox_size[0])