
Idea for how to define neural-networks operating on graph #35

leifdenby opened this issue May 22, 2024 · 4 comments

leifdenby commented May 22, 2024

I've been thinking about how we could construct the different neural networks which update the graph. The ideas below are very tentative, and I may also have fundamentally misunderstood parts of how this is supposed to work, so all feedback is very welcome.

To my mind there are two types of neural networks we use: 1) MLPs for embedding node/edge features into a latent space of common size across all embeddings, and 2) networks that update node embeddings using message passing.

What I would like to achieve is code that:

  • is closer to the mathematical notation used in the neural-lam publication for the expressions that do the embedding and message-passing, i.e. shows which nodes/edges are being operated on
  • allows me to easily see in one place the complete set of neural networks used in a given architecture
  • is flexible, by making it easy to create new message-passing operations

I think there are basically three steps to this:

  1. Define which embedding networks to create; the number of embedders will depend on whether nodes/edges share the same features or not
  2. Define the message passing operations, i.e. which nodes the message passing communicates between (giving each operation a unique identifier)
  3. Define the order of the message-passing operations

Below is a code example that tries to encapsulate all the information I think is needed to later instantiate the neural-network models that do the actual numerical operations.

import pytorch_lightning as pl


class NewGraphModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        n_mesh_node_features = 3
        n_grid_node_features = 18
        n_edge_features = 3
        n_hidden_features = 64

        # create node and edge feature embedding networks, all will project
        # into an embedding space of size n_hidden_features
        embedders = [
            dict(
                node=dict(component="mesh"),
                n_input_features=n_mesh_node_features,
            ),
            dict(
                node=dict(component="grid"),
                n_input_features=n_grid_node_features,
            ),
            # use the same edge embedding network for all edges, assuming
            # that all edges have the same set of features
            dict(edge=dict(), n_input_features=n_edge_features),
        ]

        # create message-passing networks that update node embeddings,
        # these all assume that the embedding vectors are of the same size
        message_passers = dict(
            g2m=dict(src=dict(component="grid"), dst=dict(component="mesh")),
            m2m_up_1to2=dict(
                src=dict(component="mesh", level=1),
                dst=dict(component="mesh", level=2),
            ),
            m2m_down_2to1=dict(
                src=dict(component="mesh", level=2),
                dst=dict(component="mesh", level=1),
            ),
            m2m_inlevel_1=dict(
                src=dict(component="mesh", level=1),
                dst=dict(component="mesh", level=1),
            ),
            m2m_inlevel_2=dict(
                src=dict(component="mesh", level=2),
                dst=dict(component="mesh", level=2),
            ),
            m2g=dict(src=dict(component="mesh"), dst=dict(component="grid")),
        )

        # define the order in which messages are passed
        # here we do the up/down twice before decoding back to the grid
        message_passing_order = [
            "g2m",
            # m2m pass 1
            "m2m_up_1to2",
            "m2m_inlevel_2",
            "m2m_down_2to1",
            "m2m_inlevel_1",
            # m2m pass 2
            "m2m_up_1to2",
            "m2m_inlevel_2",
            "m2m_down_2to1",
            "m2m_inlevel_1",
            "m2g",
        ]

A few notes on this to explain what is going on:

  • the three sections implement those three steps 1) embedders, 2) message passers, 3) message passing order
  • when defining the embedders or the message-passers, the arguments in the dicts in effect define filters that select the nodes/edges being operated on (see the sketch after this list)
  • I have made the message passing order explicit because this creates a single place to refer to for the ordering of the message-passing operations. This feels a bit like nn.Sequential, but it isn't of course, because the output from one step doesn't simply feed into the next.
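
To make the filter idea concrete, here is a minimal sketch of how the src/dst dicts could be matched against edge-set metadata (everything here is hypothetical, just to illustrate the selection logic):

# hypothetical metadata describing two edge sets (illustrative only)
edge_set_metadata = {
    "g2m": dict(
        src=dict(component="grid"), dst=dict(component="mesh", level=1)
    ),
    "m2m_up_1to2": dict(
        src=dict(component="mesh", level=1),
        dst=dict(component="mesh", level=2),
    ),
}


def _matches(filter_dict, metadata):
    # every key given in the filter must be present with an equal value;
    # an empty filter matches everything
    return all(metadata.get(k) == v for k, v in filter_dict.items())


def select_edge_sets(src=None, dst=None):
    """Return identifiers of the edge sets matched by the src/dst filters"""
    return [
        name
        for name, meta in edge_set_metadata.items()
        if _matches(src or {}, meta["src"]) and _matches(dst or {}, meta["dst"])
    ]


# e.g. all edge sets that end at mesh level 2:
assert select_edge_sets(dst=dict(component="mesh", level=2)) == ["m2m_up_1to2"]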

I hope this isn't total nonsense @joeloskarsson and @sadamov 😆 just trying to get the ball rolling

@joeloskarsson
Copy link
Collaborator

So there are a few things that you bring up here, and some are almost orthogonal.

W.r.t. embedders: I agree that this could be a bit more structured. In some code I've been working with I have a function embedd_all that handles all embedding, one node-set/edge-set at a time, and that gives a good overview.

I struggle a bit to understand the other parts, so would be grateful for some clarifications.

is closer to the mathematical notation used in the neural-lam publication for the expressions that do the embedding and message-passing, i.e. shows which nodes/edges are being operated on

Could you explain how the current implementation looks different from the equations in the paper? I look for example at something like

grid_rep = self.m2g_gnn(
    mesh_rep, grid_rep, m2g_emb_expanded
)  # (B, num_grid_nodes, d_h)

and to me this tells exactly which node sets messages are passed between and which edges to use. Is the idea that you want the logic of the forward pass of the model to be defined in the init function? To me it makes sense to follow the practice of instantiating the network blocks in init and using them in forward. However, something I think we should change is to break things up into separate nn.Modules (e.g. the encoder, processor and decoder parts), so that the forward pass actually sits in a forward function, rather than in something like predict_step.
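
A rough sketch of the split I have in mind (all class and attribute names here are made up, not existing code):

from torch import nn


class GraphForecaster(nn.Module):
    """Sketch: encoder/processor/decoder as separate sub-modules, so the
    graph logic sits in forward() rather than in e.g. predict_step"""

    def __init__(self, encoder, processor, decoder):
        super().__init__()
        self.encoder = encoder  # e.g. a grid -> mesh GNN
        self.processor = processor  # e.g. a stack of m2m GNN layers
        self.decoder = decoder  # e.g. a mesh -> grid GNN

    def forward(self, grid_rep, mesh_rep, graph_emb):
        mesh_rep = self.encoder(grid_rep, mesh_rep, graph_emb["g2m"])
        mesh_rep = self.processor(mesh_rep, graph_emb["m2m"])
        grid_rep = self.decoder(mesh_rep, grid_rep, graph_emb["m2g"])
        return grid_rep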

allows me to easily see in one place the complete set of neural networks used in a given architecture

Have you thought about how this should work with the class hierarchy? I think that is something I fail to see. It makes sense when you write out the NewGraphModel here, but in practice you will never have a class like this where all the GNNs and MLPs are defined in the same init function. As a concrete example: the ARModel does not know if it is working with a hierarchical graph or not. Therefore it will not know if it should create 1 mesh edge embedder or $L$ (= number of levels) such embedders.

Connected to above, my understanding of the NewGraphModel example is that you create lists and dicts that describe all network components and then you use these as a blueprint to instantiate those components. Is that correct? Why would this be less convoluted and more understandable than just instantiating the components directly? Do you have an idea of how this would work with the class hierarchy? Would all subclasses append to e.g. self.embedders and what class is responsible for actually triggering the instantiation based on this blueprint?

is flexible by making easy to create new message passing operations

I interpret this as being able to create new GNN layers, is that correct? I think that is very important and something we should think about. If we keep the current function signature from the forward of the InteractionNetworks, this only comes down to changing the instantiation of the GNN layers. I have done a little bit of this in our probabilistic modelling and it is very easy to swap out different GNN layer classes.
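
For example, something like this sketch (InteractionNet is the existing layer class; the constructor arguments are simplified here, and the alternative layer class is hypothetical):

from torch import nn

from neural_lam.interaction_net import InteractionNet

# Swapping GNN layers only means picking a different class at
# instantiation, as long as all classes share the same forward signature.
GNN_LAYER_CLASSES = {
    "interaction": InteractionNet,
    # "attention": GraphAttentionLayer,  # hypothetical alternative layer
}


def make_processor_layers(layer_type, edge_index, hidden_dim, n_layers):
    # constructor arguments simplified for illustration
    layer_class = GNN_LAYER_CLASSES[layer_type]
    return nn.ModuleList(
        layer_class(edge_index, hidden_dim) for _ in range(n_layers)
    )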

@leifdenby

W.r.t embedders: I agree that this could be a bit more structured. In some code I've been working with I have a function embedd_all that handles all embedding one node-set/edge-set at a time and that gives a good overview.

Ok, could you point me to this? Just so we're on the same page: what I thought would be nice to have is a single point where we, in a sense, register which embedding networks will be constructed. This could be achieved in different ways, for example:

  1. have a single collection of embedding "blueprints" that define all embedding networks to construct. I hope I've understood this term the way you use it, but my use here would be a definition of the number of input and output features, the type of the embedding (number of layers, width, or maybe we just come up with some common set of names to describe this) and an identifier for each (e.g. "graph_nodes")

  2. define a collection for the constructed embedding networks to be stored in; this could simply be a dictionary with each key being the identifier for a given embedding network, and we encourage people to put their embedding networks in this dict

I would prefer option 1, as this would allow us to easily print which embedding networks are instantiated, make it easy to understand how to add more, and enforce that they are constructed identically.

In a container type for these graph-based models this could be implemented with something like:

from torch import nn


class EncodeProcessDecodeGraph(nn.Module):
    def __init__(self):
        super().__init__()
        self._embedding_blueprints = {}
        # nn.ModuleDict so embedder parameters are registered on the module
        self._embedding_networks = nn.ModuleDict()

    def _register_embedder(
        self, identifier, n_features_in, n_features_out, kind
    ):
        self._embedding_blueprints[identifier] = dict(
            n_features_in=n_features_in,
            n_features_out=n_features_out,
            kind=kind,
        )

    def _construct_embedders(self):
        for identifier, blueprint in self._embedding_blueprints.items():
            n_in = blueprint["n_features_in"]
            n_out = blueprint["n_features_out"]
            if blueprint["kind"] == "linear_single":
                self._embedding_networks[identifier] = nn.Linear(n_in, n_out)
            else:
                raise ValueError(f"Unknown kind: {blueprint['kind']}")


class KeislerGraph(EncodeProcessDecodeGraph):
    def __init__(
        self, hidden_dim_size=512, n_grid_features=10, n_edge_features=2
    ):
        super().__init__()

        self._register_embedder(
            identifier="grid_node",
            n_features_in=n_grid_features,
            n_features_out=hidden_dim_size,
            kind="linear_single",
        )
        self._register_embedder(
            identifier="g2m_and_m2g_edge",
            n_features_in=n_edge_features,
            n_features_out=hidden_dim_size,
            kind="linear_single",
        )
        self._register_embedder(
            identifier="m2m_edge",
            n_features_in=n_edge_features,
            n_features_out=hidden_dim_size,
            kind="linear_single",
        )
        
        self._construct_embedders()

I will create a separate comment for the message-passing networks.

@joeloskarsson

The embedd_all function that I mentioned is more about how the embedders are then applied to the input, so quite orthogonal to this. It is in the code for the ensemble model, so it's not on GitHub yet. But I hope to have that pushed today, then I'll link it here.

Overall I think this looks nice. This could act almost as a wrapper for utils.make_mlp, but specific to MLPs that are embedders, so we can better keep track of them and make sure they are constructed in similar ways. There will be some work needed to make this fit into the model class hierarchy (existing or proposed new one), but I think it should be doable.

One thing I am wondering though: could we not just immediately create each embedder, instead of registering them first and then having a separate call to construct them? If we have something like _create_embedder it could both create the embedder and store it in self._embedding_networks immediately. That way we still keep track of all embedders in the same way, but avoid an additional function call that requires some understanding of the whole register->construct setup. I think that would also play more nicely with the model class hierarchy, as there would be no question of which class is responsible for calling self._construct_embedders().
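
Concretely, something like this sketch (assuming the existing utils.make_mlp; the import path and the blueprint passed to it are just for illustration):

from torch import nn

from neural_lam import utils  # assuming the existing utils.make_mlp


class EncodeProcessDecodeGraph(nn.Module):
    def __init__(self):
        super().__init__()
        # all embedders live in one ModuleDict, so we can still print or
        # iterate over them, but each one is created in a single call
        self._embedding_networks = nn.ModuleDict()

    def _create_embedder(self, identifier, n_features_in, n_features_out):
        # create and store the embedder immediately: no separate
        # register -> construct step for subclasses to worry about
        self._embedding_networks[identifier] = utils.make_mlp(
            [n_features_in, n_features_out]
        )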

@joeloskarsson

Here is the embedd_all function:

def embedd_all(self, prev_state, prev_prev_state, forcing):
    """
    Embed all node and edge representations

    prev_state: (B, num_grid_nodes, feature_dim), X_t
    prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1}
    forcing: (B, num_grid_nodes, forcing_dim)

    Returns:
    grid_emb: (B, num_grid_nodes, d_h)
    graph_embedding: dict with entries of shape (B, *, d_h)
    """
    batch_size = prev_state.shape[0]
    grid_features = torch.cat(
        (
            prev_prev_state,
            prev_state,
            forcing,
            self.expand_to_batch(self.grid_static_features, batch_size),
        ),
        dim=-1,
    )  # (B, num_grid_nodes, grid_dim)
    grid_emb = self.grid_prev_embedder(grid_features)
    # (B, num_grid_nodes, d_h)

    # Graph embedding
    graph_emb = {
        "g2m": self.expand_to_batch(
            self.g2m_embedder(self.g2m_features), batch_size
        ),  # (B, M_g2m, d_h)
        "m2g": self.expand_to_batch(
            self.m2g_embedder(self.m2g_features), batch_size
        ),  # (B, M_m2g, d_h)
    }

    if self.hierarchical_graph:
        graph_emb["mesh"] = [
            self.expand_to_batch(emb(node_static_features), batch_size)
            for emb, node_static_features in zip(
                self.mesh_embedders,
                self.mesh_static_features,
            )
        ]  # each (B, num_mesh_nodes[l], d_h)

        if self.embedd_m2m:
            graph_emb["m2m"] = [
                self.expand_to_batch(emb(edge_feat), batch_size)
                for emb, edge_feat in zip(
                    self.m2m_embedders, self.m2m_features
                )
            ]
        else:
            # Need a placeholder otherwise, just use raw features
            graph_emb["m2m"] = list(self.m2m_features)

        graph_emb["mesh_up"] = [
            self.expand_to_batch(emb(edge_feat), batch_size)
            for emb, edge_feat in zip(
                self.mesh_up_embedders, self.mesh_up_features
            )
        ]
        graph_emb["mesh_down"] = [
            self.expand_to_batch(emb(edge_feat), batch_size)
            for emb, edge_feat in zip(
                self.mesh_down_embedders, self.mesh_down_features
            )
        ]
    else:
        graph_emb["mesh"] = self.expand_to_batch(
            self.mesh_embedder(self.mesh_static_features), batch_size
        )  # (B, num_mesh_nodes, d_h)
        graph_emb["m2m"] = self.expand_to_batch(
            self.m2m_embedder(self.m2m_features), batch_size
        )  # (B, M_m2m, d_h)

    return grid_emb, graph_emb

But again, that is a bit orthogonal, since it is about how the embedders are applied to the grid input + graph, rather than how they are created. Perhaps of more interest is how the embedders are created in that class:

# Feature embedders for grid
self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1)
self.grid_prev_embedder = utils.make_mlp(
    [self.grid_dim] + self.mlp_blueprint_end
)  # For states up to t-1
self.grid_current_embedder = utils.make_mlp(
    [grid_current_dim] + self.mlp_blueprint_end
)  # For states including t

# Embedders for mesh
self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end)
self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end)

if self.hierarchical_graph:
    # Print some useful info
    print("Loaded hierarchical graph with structure:")
    level_mesh_sizes = [
        mesh_feat.shape[0] for mesh_feat in self.mesh_static_features
    ]
    self.num_mesh_nodes = level_mesh_sizes[-1]
    num_levels = len(self.mesh_static_features)
    for level_index, level_mesh_size in enumerate(level_mesh_sizes):
        same_level_edges = self.m2m_features[level_index].shape[0]
        print(
            f"level {level_index} - {level_mesh_size} nodes, "
            f"{same_level_edges} same-level edges"
        )
        if level_index < (num_levels - 1):
            up_edges = self.mesh_up_features[level_index].shape[0]
            down_edges = self.mesh_down_features[level_index].shape[0]
            print(f" {level_index}<->{level_index+1}")
            print(f" - {up_edges} up edges, {down_edges} down edges")

    # Embedders
    # Assume all levels have same static feature dimensionality
    mesh_dim = self.mesh_static_features[0].shape[1]
    m2m_dim = self.m2m_features[0].shape[1]
    mesh_up_dim = self.mesh_up_features[0].shape[1]
    mesh_down_dim = self.mesh_down_features[0].shape[1]

    # Separate mesh node embedders for each level
    self.mesh_embedders = torch.nn.ModuleList(
        [
            utils.make_mlp([mesh_dim] + self.mlp_blueprint_end)
            for _ in range(num_levels)
        ]
    )
    self.mesh_up_embedders = torch.nn.ModuleList(
        [
            utils.make_mlp([mesh_up_dim] + self.mlp_blueprint_end)
            for _ in range(num_levels - 1)
        ]
    )
    self.mesh_down_embedders = torch.nn.ModuleList(
        [
            utils.make_mlp([mesh_down_dim] + self.mlp_blueprint_end)
            for _ in range(num_levels - 1)
        ]
    )

    # If not using any processor layers, no need to embed m2m
    self.embedd_m2m = (
        max(
            args.prior_processor_layers,
            args.encoder_processor_layers,
            args.processor_layers,
        )
        > 0
    )
    if self.embedd_m2m:
        self.m2m_embedders = torch.nn.ModuleList(
            [
                utils.make_mlp([m2m_dim] + self.mlp_blueprint_end)
                for _ in range(num_levels)
            ]
        )
else:
    self.num_mesh_nodes, mesh_static_dim = (
        self.mesh_static_features.shape
    )
    print(
        f"Loaded graph with {self.num_grid_nodes + self.num_mesh_nodes} "
        f"nodes ({self.num_grid_nodes} grid, "
        f"{self.num_mesh_nodes} mesh)"
    )
    mesh_static_dim = self.mesh_static_features.shape[1]
    self.mesh_embedder = utils.make_mlp(
        [mesh_static_dim] + self.mlp_blueprint_end
    )
    m2m_dim = self.m2m_features.shape[1]
    self.m2m_embedder = utils.make_mlp(
        [m2m_dim] + self.mlp_blueprint_end
    )

This is at least all collected in one place, but it could be even nicer with something like what's proposed above.
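
For example, the per-level embedder creation could then reduce to something like this (a sketch continuing the snippet above, using the hypothetical _create_embedder from my earlier comment; mesh_dim etc. as computed above):

# per-level embedders via the hypothetical _create_embedder; mesh_dim,
# m2m_dim, mesh_up_dim and mesh_down_dim as computed in the snippet above
for level in range(num_levels):
    self._create_embedder(f"mesh_level_{level}", mesh_dim, args.hidden_dim)
    self._create_embedder(f"m2m_level_{level}", m2m_dim, args.hidden_dim)
for level in range(num_levels - 1):
    self._create_embedder(f"mesh_up_{level}", mesh_up_dim, args.hidden_dim)
    self._create_embedder(f"mesh_down_{level}", mesh_down_dim, args.hidden_dim)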
