fix message passing #95

Merged
merged 1 commit into from
Dec 12, 2024

@chhsiao93 chhsiao93 commented Dec 12, 2024

RFC: Fix Edge Feature Processing in InteractionNetwork

Summary

The current implementation of InteractionNetwork does not propagate the processed edge features through the message passing flow. As a result, the residual connection adds the original edge features to themselves, doubling them, instead of adding the processed features to the original ones.

Background

In the current implementation, the message() method processes edge features but these processed features are not properly passed through to the update() method. This causes the update() method to use the original edge features instead of the processed ones, resulting in incorrect residual connections.

Problem

Running the network with the input:

edge_attr = tensor([[1., 1.],
                   [2., 2.]])

produces:

output = tensor([[2., 2.],
                [4., 4.]])

This indicates the output is simply doubling the input due to the residual connection, rather than adding processed features to the original ones.
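The arithmetic behind this symptom can be reproduced in a few lines. This is only a sketch: `edge_fn` here is a hypothetical stand-in (a fixed scaling) for the learned edge function, and the residual is written out by hand rather than taken from the InteractionNetwork code.

```python
import torch


def edge_fn(e: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the learned edge function (assumption).
    return 0.5 * e


edge_attr = torch.tensor([[1., 1.],
                          [2., 2.]])

# Buggy flow: the residual receives the original features, so it adds them to themselves.
buggy = edge_attr + edge_attr

# Intended flow: the residual adds the processed features to the original ones.
fixed = edge_fn(edge_attr) + edge_attr

print(buggy)
print(fixed)
```

With the buggy flow the result equals `2 * edge_attr` regardless of what `edge_fn` computes, which is why the doubled output is a reliable sign that the processed features never reach the residual.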

Proposed Solution

Add a temporary instance variable to store processed edge features between message and update steps:

def message(self,
            x_i: torch.Tensor,
            x_j: torch.Tensor,
            edge_features: torch.Tensor) -> torch.Tensor:
    # Concatenate both endpoint node features with the edge features.
    edge_features = torch.cat([x_i, x_j, edge_features], dim=-1)
    self._edge_features = self.edge_fn(edge_features)  # Store processed features
    return self._edge_features  # Return for node updates

def update(self,
           x_updated: torch.Tensor,
           x: torch.Tensor,
           edge_features: torch.Tensor):
    # edge_features still holds the original (unprocessed) values and is not used here.
    x_updated = torch.cat([x_updated, x], dim=-1)
    x_updated = self.node_fn(x_updated)
    return x_updated, self._edge_features  # Use stored processed features
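To see how the stored features reach the residual end to end, here is a minimal sketch in plain PyTorch. It does not use PyTorch Geometric: `TinyInteraction`, its single `nn.Linear` layers, and the hand-written sum aggregation in `forward()` are all assumptions made for illustration, not the code in this PR.

```python
import torch
from torch import nn


class TinyInteraction(nn.Module):
    """Hypothetical stand-in for InteractionNetwork (not the PR's class, no PyG)."""

    def __init__(self, node_dim: int, edge_dim: int):
        super().__init__()
        # Single linear layers stand in for the real edge/node MLPs (assumption).
        self.edge_fn = nn.Linear(2 * node_dim + edge_dim, edge_dim)
        self.node_fn = nn.Linear(node_dim + edge_dim, node_dim)

    def message(self, x_i, x_j, edge_features):
        # Stash the processed edge features so update() can return them.
        self._edge_features = self.edge_fn(
            torch.cat([x_i, x_j, edge_features], dim=-1))
        return self._edge_features

    def update(self, x_updated, x):
        x_updated = self.node_fn(torch.cat([x_updated, x], dim=-1))
        return x_updated, self._edge_features

    def forward(self, x, edge_index, edge_features):
        src, dst = edge_index
        messages = self.message(x[dst], x[src], edge_features)
        # Sum incoming messages per receiving node (manual aggregation).
        aggregated = x.new_zeros(x.size(0), messages.size(-1))
        aggregated = aggregated.index_add_(0, dst, messages)
        x_new, e_new = self.update(aggregated, x)
        # Residual connections: processed features plus the originals.
        return x_new + x, e_new + edge_features


torch.manual_seed(0)
x = torch.randn(3, 4)
edge_index = torch.tensor([[0, 1], [1, 2]])
edge_attr = torch.tensor([[1., 1.], [2., 2.]])

net = TinyInteraction(node_dim=4, edge_dim=2)
x_out, e_out = net(x, edge_index, edge_attr)
```

Here `e_out - edge_attr` recovers the output of `edge_fn`, so the edge residual now carries processed features rather than a second copy of the input.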

Rationale

  • No need to modify the PyTorch Geometric message passing interface
  • Maintains proper feature processing flow
  • Clean and simple solution that follows PyG patterns
  • Negligible overhead: the attribute only holds a reference to the tensor that message() already returns, and it is overwritten on every forward pass

Testing

New tests verify:

  1. Edge features are properly updated (not just doubled)
  2. Shapes remain correct
  3. Gradients flow properly
  4. Numerical stability is maintained
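The PR's test code is not shown here, so the following is only a sketch of what checks along these four lines might look like. It exercises the edge residual in isolation; the `nn.Linear` used as `edge_fn` and the `checks` dictionary are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)
edge_fn = nn.Linear(2, 2)  # hypothetical stand-in for the edge function
edge_attr = torch.tensor([[1., 1.],
                          [2., 2.]], requires_grad=True)

# Edge residual as intended after the fix: processed + original.
out = edge_fn(edge_attr) + edge_attr

checks = {
    # 1. Edge features are updated, not just doubled.
    "not_doubled": not torch.allclose(out, 2 * edge_attr),
    # 2. Shapes remain correct.
    "shape": out.shape == edge_attr.shape,
    # 4. Output contains no NaN or inf values.
    "finite": bool(torch.isfinite(out).all()),
}

# 3. Gradients reach both the edge function's parameters and the input.
out.sum().backward()
checks["grad_edge_fn"] = all(p.grad is not None for p in edge_fn.parameters())
checks["grad_input"] = edge_attr.grad is not None

print(checks)
```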

Alternatives Considered

  1. Using class member variable initialized in __init__
    • Rejected: Unnecessary persistence between forward passes
  2. Modifying PyG's message passing flow
    • Rejected: Too invasive, breaks standard patterns
  3. Using message dictionaries
    • Rejected: Not compatible with PyG's aggregation

Backward Compatibility

This change maintains the same interface and output shapes, but will produce different numerical results. Models trained with the old implementation will need to be retrained.

@kks32 kks32 merged commit 81690d2 into geoelements:v2 Dec 12, 2024
1 check failed