Added pos Node Type for Diffusion Modeling Support in HydraGNN #296

Open · wants to merge 7 commits into main · Changes from all commits
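This PR adds a `pos` head type (positions predicted directly by the model, with optional center-of-mass removal), a `HybridEGNN` model that mixes cross-entropy and MSE losses, and a `silu` activation option. For context, a hypothetical configuration sketch of how a `pos` output might be declared; only the key paths visible in this diff (`output_heads`, `Variables_of_interest`, the `"pos"` type) are from the PR, and all other key names and values are assumptions:

```python
# Hypothetical config sketch (not part of this PR): a "pos" output declared
# next to a node-level output, following the config paths referenced in Base.py.
config = {
    "NeuralNetwork": {
        "Architecture": {
            "model_type": "HybridEGNN",
            "equivariance": True,
            "output_heads": {
                # node head type set via ...['output_heads']['node']['type']
                "node": {"type": "mlp"},
            },
        },
        "Variables_of_interest": {
            # "pos" marks an output built from node positions (see
            # update_predicted_values below); names here are assumed
            "type": ["node", "pos"],
        },
    },
}
```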
29 changes: 27 additions & 2 deletions hydragnn/models/Base.py
@@ -282,12 +282,15 @@ def _multihead(self):
                        head_NN.append(self.convs_node_output[inode_feature])
                        head_NN.append(self.batch_norms_node_output[inode_feature])
                        inode_feature += 1

                else:
                    raise ValueError(
                        "Unknown head NN structure for node features: "
                        + self.node_NN_type
                        + "; currently only 'mlp', 'mlp_per_node' and 'conv' are supported (can be set with config['NeuralNetwork']['Architecture']['output_heads']['node']['type'], e.g., ./examples/ci_multihead.json)"
                    )
            elif self.head_type[ihead] == "pos":
                head_NN = torch.nn.Identity()
            else:
                raise ValueError(
                    "Unknown head type"
@@ -316,7 +319,6 @@ def forward(self, data):
            x = self.activation_function(feat_layer(c))

        #### multi-head decoder part####
        # shared dense layers for graph level output
        if data.batch is None:
            x_graph = x.mean(dim=0, keepdim=True)
        else:
@@ -331,7 +333,7 @@
                output_head = headloc(x_graph_head)
                outputs.append(output_head[:, :head_dim])
                outputs_var.append(output_head[:, head_dim:] ** 2)
-           else:
+           elif type_head == "node":
                if self.node_NN_type == "conv":
                    for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
                        c, pos = conv(x=x, pos=pos, **conv_args)
@@ -342,6 +344,29 @@
                    x_node = headloc(x=x, batch=data.batch)
                outputs.append(x_node[:, :head_dim])
                outputs_var.append(x_node[:, head_dim:] ** 2)

            elif type_head == "pos":
                if self.equivariance:
                    # predicted displacement relative to the input positions
                    x_node = pos - data.pos
                    # subtract each graph's center of mass so the mean
                    # displacement per graph is zero
                    sg_num_nodes = [d.num_nodes for d in data.to_data_list()]
                    com_ten = []
                    place = 0
                    for sgnn in sg_num_nodes:
                        sg_x_node = x_node[place : place + sgnn]
                        com_ten.append(
                            sg_x_node.mean(dim=0, keepdim=True).tile(sgnn, 1)
                        )
                        place += sgnn
                    com_ten = torch.cat(com_ten, dim=0)
                    x_node = x_node - com_ten
                else:
                    x_node = pos
                # TODO: implement output_var for this type_head?
                outputs.append(x_node)
            else:
                raise NotImplementedError(
                    "Head type {} not recognized".format(type_head)
                )
        if self.var_output:
            return outputs, outputs_var
        return outputs
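The per-graph center-of-mass loop above can be sanity-checked in isolation. A minimal standalone sketch with toy tensors (`data.batch`-style graph indexing assumed; this is not code from the PR) that reproduces the same zero-mean property:

```python
# Standalone sketch of the center-of-mass removal done by the "pos" head.
import torch

x_node = torch.randn(5, 3)             # displacements for 5 nodes
batch = torch.tensor([0, 0, 0, 1, 1])  # graph id per node, as in data.batch

num_graphs = int(batch.max()) + 1
ones = torch.ones(batch.shape[0])
counts = torch.zeros(num_graphs).index_add_(0, batch, ones)     # nodes per graph
sums = torch.zeros(num_graphs, 3).index_add_(0, batch, x_node)  # per-graph sums
com = sums / counts.unsqueeze(-1)                               # centers of mass
x_node = x_node - com[batch]                                    # broadcast subtract

# each graph's mean displacement is now ~0, matching the loop in forward()
assert torch.allclose(x_node[batch == 0].mean(dim=0), torch.zeros(3), atol=1e-6)
assert torch.allclose(x_node[batch == 1].mean(dim=0), torch.zeros(3), atol=1e-6)
```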
2 changes: 0 additions & 2 deletions hydragnn/models/EGCLStack.py
@@ -26,7 +26,6 @@ def __init__(
        max_neighbours: Optional[int] = None,
        **kwargs,
    ):
        self.edge_dim = (
            0 if edge_attr_dim is None else edge_attr_dim
        )  # Must be named edge_dim to trigger use by Base
@@ -159,7 +158,6 @@ def __init__(
        self.clamp = clamp

        if self.equivariant:
            layer = nn.Linear(hidden_channels, 1, bias=False)
            torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
67 changes: 67 additions & 0 deletions hydragnn/models/HybridEGCLStack.py
@@ -0,0 +1,67 @@
##############################################################################
# Copyright (c) 2021, Oak Ridge National Laboratory #
# All rights reserved. #
# #
# This file is part of HydraGNN and is distributed under a BSD 3-clause #
# license. For the licensing terms see the LICENSE file in the top-level #
# directory. #
# #
# SPDX-License-Identifier: BSD-3-Clause #
##############################################################################
from typing import Optional

import torch
import torch.nn as nn
from torch_geometric.nn import Sequential
import torch.nn.functional as F
from .EGCLStack import EGCLStack

from hydragnn.utils.model import unsorted_segment_mean


class HybridEGCLStack(EGCLStack):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Initialize the parent class
        super().__init__(*args, **kwargs)

        # Define new loss functions
        self.cross_entropy = torch.nn.CrossEntropyLoss()
        self.mse = torch.nn.MSELoss()

    def loss_hpweighted(self, pred, value, head_index, var=None):
        """
        Override this method to split the loss between
        MSE (atom positions) and cross entropy (atom types).
        """
        # weights for different tasks as hyper-parameters
        tot_loss = 0
        tasks_loss = []
        for ihead in range(self.num_heads):
            head_pred = pred[ihead]
            pred_shape = head_pred.shape
            head_val = value[head_index[ihead]]
            value_shape = head_val.shape
            if pred_shape != value_shape:
                head_val = torch.reshape(head_val, pred_shape)

            # Calculate loss depending on the head:
            # cross entropy for atom types (head 0), MSE for position noise (head 1)
            if ihead == 0:
                head_loss = self.cross_entropy(head_pred, head_val)
            elif ihead == 1:
                head_loss = self.mse(head_pred, head_val)

            # Add the weighted head loss to the total and record it per task
            tot_loss += head_loss * self.loss_weights[ihead]
            tasks_loss.append(head_loss)

        return tot_loss, tasks_loss

    def __str__(self):
        return "HybridEGCLStack"
20 changes: 19 additions & 1 deletion hydragnn/models/create.py
@@ -24,6 +24,7 @@
from hydragnn.models.SCFStack import SCFStack
from hydragnn.models.DIMEStack import DIMEStack
from hydragnn.models.EGCLStack import EGCLStack
from hydragnn.models.HybridEGCLStack import HybridEGCLStack
from hydragnn.models.PNAEqStack import PNAEqStack
from hydragnn.models.PAINNStack import PAINNStack
from hydragnn.models.MACEStack import MACEStack
@@ -345,7 +346,24 @@ def create_model(
            num_conv_layers=num_conv_layers,
            num_nodes=num_nodes,
        )

    elif model_type == "HybridEGNN":
        model = HybridEGCLStack(
            edge_dim,
            input_dim,
            hidden_dim,
            output_dim,
            output_type,
            output_heads,
            activation_function,
            loss_function_type,
            equivariance,
            max_neighbours=max_neighbours,
            loss_weights=task_weights,
            freeze_conv=freeze_conv,
            initial_bias=initial_bias,
            num_conv_layers=num_conv_layers,
            num_nodes=num_nodes,
        )
    elif model_type == "PAINN":
        model = PAINNStack(
            # edge_dim,  # To-do add edge_features
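Selecting the new class goes through the same `model_type` dispatch as the other stacks. A hedged sketch of the relevant architecture fragment; only `"HybridEGNN"` comes from this diff, the surrounding keys and values are assumptions following the existing config layout:

```python
# Hypothetical architecture fragment: create_model(...) above instantiates
# HybridEGCLStack when the configured model_type is "HybridEGNN".
architecture_config = {
    "model_type": "HybridEGNN",
    "equivariance": True,  # HybridEGNN is registered as equivariant below
    "hidden_dim": 64,      # assumed value
    "num_conv_layers": 4,  # assumed value
}
```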
7 changes: 7 additions & 0 deletions hydragnn/preprocess/graph_samples_checks_and_updates.py
@@ -20,6 +20,7 @@

from .dataset_descriptors import AtomFeatures


## This function can be slow if the dataset is too large. Use with caution.
## Recommended: use check_if_graph_size_variable_dist instead.
def check_if_graph_size_variable(train_loader, val_loader, test_loader):
@@ -271,6 +272,12 @@ def update_predicted_values(
                ],
                (-1, 1),
            )
        elif type[item] == "pos":
            # index_counter_nodal_y = sum(node_feature_dim[: index[item]])
            feat_ = torch.reshape(
                data.pos[:, : node_feature_dim[index[item]]],
                (-1, 1),
            )
        else:
            raise ValueError("Unknown output type", type[item])
        output_feature.append(feat_)
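The new branch flattens the selected coordinate columns node by node. A toy check (shapes assumed, not code from the PR):

```python
# Toy sketch of the "pos" branch above: coordinates are flattened into a single
# column, node 0's x/y/z first, then node 1's, matching the (-1, 1) reshape.
import torch

pos = torch.tensor([[0.0, 1.0, 2.0],
                    [3.0, 4.0, 5.0]])       # data.pos for two nodes
feat_ = torch.reshape(pos[:, :3], (-1, 1))  # node_feature_dim[index[item]] == 3 assumed
assert feat_.shape == (2 * 3, 1)
```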
16 changes: 10 additions & 6 deletions hydragnn/utils/input_config_parsing/config_utils.py
@@ -136,7 +136,7 @@ def update_config(config, train_loader, val_loader, test_loader):


def update_config_equivariance(config):
-   equivariant_models = ["EGNN", "SchNet", "PNAEq", "PAINN", "MACE"]
+   equivariant_models = ["EGNN", "SchNet", "PNAEq", "PAINN", "MACE", "HybridEGNN"]
    if "equivariance" in config and config["equivariance"]:
        assert (
            config["model_type"] in equivariant_models
@@ -188,7 +188,7 @@ def update_config_NN_outputs(config, data, graph_size_variable):
        for ihead in range(len(output_type)):
            if output_type[ihead] == "graph":
                dim_item = data.y_loc[0, ihead + 1].item() - data.y_loc[0, ihead].item()
-           elif output_type[ihead] == "node":
+           elif output_type[ihead] == "node" or output_type[ihead] == "pos":
                if (
                    graph_size_variable
                    and config["Architecture"]["output_heads"]["node"]["type"]
@@ -206,10 +206,14 @@
    else:
        for ihead in range(len(output_type)):
            if output_type[ihead] != "graph":
-               raise ValueError(
-                   "y_loc is needed for outputs that are not at graph levels",
-                   output_type[ihead],
-               )
+               # raise only when y_loc is missing for a non-graph output and
+               # "dynamic_target" is false or missing
+               if not config["Variables_of_interest"].get("dynamic_target", False):
+                   raise ValueError(
+                       "y_loc is needed for outputs that are not at graph levels",
+                       output_type[ihead],
+                   )
    dims_list = config["Variables_of_interest"]["output_dim"]

    config["Architecture"]["output_dim"] = dims_list
2 changes: 2 additions & 0 deletions hydragnn/utils/model/model.py
@@ -31,6 +31,8 @@ def activation_function_selection(activation_function_string: str):
        return torch.nn.ReLU()
    elif activation_function_string == "selu":
        return torch.nn.SELU()
    elif activation_function_string == "silu":
        return torch.nn.SiLU()
    elif activation_function_string == "prelu":
        return torch.nn.PReLU()
    elif activation_function_string == "elu":