Skip to content

Commit

Permalink
Merge pull request #332 from alan-turing-institute/refactor-node-feat…
Browse files Browse the repository at this point in the history
…ures

Refactor node features ✍️
  • Loading branch information
KristinaUlicna committed Oct 26, 2023
2 parents f64ed28 + e616b2d commit 0be61a3
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 14 deletions.
4 changes: 2 additions & 2 deletions grace/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class GraphAttrs(str, enum.Enum):
NODE_Y = "y"
NODE_GROUND_TRUTH = "node_ground_truth"
NODE_PREDICTION = "node_prediction"
NODE_FEATURES = "image_patch_features"
NODE_EMBEDDINGS = "laplacian_embeddings"
NODE_IMG_EMBEDDING = "node_image_patch_latent_embedding"
NODE_ENV_EMBEDDING = "laplacian_matrix_multiplied_embedding"
NODE_CONFIDENCE = "confidence"
EDGE_SOURCE = "source"
EDGE_TARGET = "target"
Expand Down
2 changes: 1 addition & 1 deletion grace/io/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
pa.field(GraphAttrs.NODE_Y, pa.float32()),
pa.field(GraphAttrs.NODE_GROUND_TRUTH, pa.int64()),
pa.field(GraphAttrs.NODE_CONFIDENCE, pa.float32()),
pa.field(GraphAttrs.NODE_FEATURES, pa.list_(pa.float32())),
pa.field(GraphAttrs.NODE_IMG_EMBEDDING, pa.list_(pa.float32())),
],
# metadata={"year": "2023"}
)
Expand Down
4 changes: 2 additions & 2 deletions grace/io/store_node_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def store_node_features_in_graph(
graph = target["graph"]

for _, node in graph.nodes(data=True):
node[GraphAttrs.NODE_FEATURES] = node[
GraphAttrs.NODE_FEATURES
node[GraphAttrs.NODE_IMG_EMBEDDING] = node[
GraphAttrs.NODE_IMG_EMBEDDING
].numpy()

write_graph(
Expand Down
4 changes: 2 additions & 2 deletions grace/models/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def _x(graph: nx.Graph) -> torch.Tensor:
[
np.concatenate(
[
graph.nodes[idx][GraphAttrs.NODE_FEATURES],
graph.nodes[idx][GraphAttrs.NODE_EMBEDDINGS],
graph.nodes[idx][GraphAttrs.NODE_IMG_EMBEDDING],
graph.nodes[idx][GraphAttrs.NODE_ENV_EMBEDDING],
],
axis=-1,
)
Expand Down
4 changes: 2 additions & 2 deletions grace/models/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def forward(

# Run through the feature extractor model:
features = self.model(bbox_image)
node_attrs[GraphAttrs.NODE_FEATURES] = features.squeeze()
node_attrs[GraphAttrs.NODE_IMG_EMBEDDING] = features.squeeze()

# Back to the original shape:
image = image.reshape(image_shape)
Expand Down Expand Up @@ -192,7 +192,7 @@ def forward(
# Run through these shortlisted torch methods:
methods = [torch.mean, torch.std, torch.min, torch.max]
features = torch.Tensor([m(bbox_image) for m in methods])
node_attrs[GraphAttrs.NODE_FEATURES] = features.squeeze()
node_attrs[GraphAttrs.NODE_IMG_EMBEDDING] = features.squeeze()

# Back to the original shape:
image = image.reshape(image_shape)
Expand Down
4 changes: 2 additions & 2 deletions grace/models/graph_laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def calculate_graph_laplacian(self) -> None:
def extract_node_features(self) -> npt.NDArray:
feature_matrix = np.stack(
[
n[GraphAttrs.NODE_FEATURES]
n[GraphAttrs.NODE_IMG_EMBEDDING]
for _, n in self.graph.nodes(data=True)
],
axis=0,
Expand All @@ -35,7 +35,7 @@ def transform_feature_embeddings(self) -> nx.Graph:

# Append node attributes to the graph:
for node_idx, node in self.graph.nodes(data=True):
node[GraphAttrs.NODE_EMBEDDINGS] = torch.Tensor(
node[GraphAttrs.NODE_ENV_EMBEDDING] = torch.Tensor(
embedded_matrix[:, node_idx]
)

Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def create_nodes(
{
GraphAttrs.NODE_X: node_coords[:, 0],
GraphAttrs.NODE_Y: node_coords[:, 1],
GraphAttrs.NODE_FEATURES: features,
GraphAttrs.NODE_IMG_EMBEDDING: features,
GraphAttrs.NODE_GROUND_TRUTH: node_ground_truth,
GraphAttrs.NODE_CONFIDENCE: rng.uniform(
size=(num_nodes),
Expand Down Expand Up @@ -104,7 +104,7 @@ def simple_graph_dataframe(default_rng) -> pd.DataFrame:
{
GraphAttrs.NODE_X: [0.0, 1.0, 2.0],
GraphAttrs.NODE_Y: [0.0, 1.0, 0.0],
GraphAttrs.NODE_FEATURES: features,
GraphAttrs.NODE_IMG_EMBEDDING: features,
GraphAttrs.NODE_GROUND_TRUTH: [1, 1, 1],
GraphAttrs.NODE_CONFIDENCE: [0.9, 0.1, 0.8],
}
Expand Down
2 changes: 1 addition & 1 deletion tests/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_feature_extractor_forward(self, bbox_size, model, vars):

for _, node_attrs in graph_out.nodes.data():
x, y = node_attrs[GraphAttrs.NODE_X], node_attrs[GraphAttrs.NODE_Y]
features = node_attrs[GraphAttrs.NODE_FEATURES]
features = node_attrs[GraphAttrs.NODE_IMG_EMBEDDING]

x_low = int(x - bbox_size[0] / 2)
x_box = slice(x_low, x_low + bbox_size[0])
Expand Down

0 comments on commit 0be61a3

Please sign in to comment.