Skip to content

Commit

Permalink
Merge pull request #3 from aidos-lab/mantra-dev-er
Browse files Browse the repository at this point in the history
Generalized function signature for models.
  • Loading branch information
danielbinschmid authored Jun 14, 2024
2 parents 1610c72 + 6a4577d commit 4b204a8
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 10 deletions.
3 changes: 2 additions & 1 deletion models/GAT.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def __init__(self, config: GATConfig):
config.hidden_channels, config.out_channels
)

def forward(self, x, edge_index, batch):
def forward(self, batch):
x, edge_index, batch = batch.x, batch.edge_index, batch.batch
# 1. Obtain node embeddings
x = self.gat_input(x, edge_index)
for layer in self.hidden_layers:
Expand Down
3 changes: 2 additions & 1 deletion models/GCN.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __init__(self, config: GCNConfig):
config.hidden_channels, config.out_channels
)

def forward(self, x, edge_index, batch):
def forward(self, batch):
x, edge_index, batch = batch.x, batch.edge_index, batch.batch
# 1. Obtain node embeddings
x = self.conv_input(x, edge_index)
for layer in self.hidden_layers:
Expand Down
5 changes: 3 additions & 2 deletions models/MLP.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ def __init__(
config.num_hidden_neurons, config.out_channels
)

def forward(self, batch):
    """Forward pass of the MLP baseline.

    Accepts a torch_geometric-style batch object and deliberately ignores
    the graph structure (``edge_index``): the MLP operates on node
    features alone, then mean-pools node embeddings per graph.

    Args:
        batch: object exposing ``x`` (node feature matrix),
            ``edge_index`` (unused here), and ``batch`` (graph-membership
            vector mapping each node to its graph in the mini-batch).

    Returns:
        Per-graph embeddings, shape ``(num_graphs, out_channels)``.
    """
    # Unpack into distinctly named locals instead of rebinding the
    # `batch` parameter to `batch.batch`, which shadowed the Data
    # object and obscured which value was which.
    x, membership = batch.x, batch.batch
    x = self.input_layer(x)
    x = nn.functional.relu(x)
    for hidden_layer in self.hidden_layers:
        x = hidden_layer(x)
        x = nn.functional.relu(x)
    x = self.output_layer(x)
    # Average node embeddings within each graph of the mini-batch.
    return pool.global_mean_pool(x, membership)
3 changes: 2 additions & 1 deletion models/TAG.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def __init__(
config.hidden_channels, config.out_channels
)

def forward(self, x, edge_index, batch):
def forward(self, batch):
x, edge_index, batch = batch.x, batch.edge_index, batch.batch
# 1. Obtain node embeddings
x = self.conv_input(x, edge_index)
for layer in self.hidden_layers:
Expand Down
3 changes: 2 additions & 1 deletion models/TransfConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(
config.hidden_channels, config.out_channels
)

def forward(self, x, edge_index, batch):
def forward(self, batch):
x, edge_index, batch = batch.x, batch.edge_index, batch.batch
# 1. Obtain node embeddings
x = self.conv_input(x, edge_index)
for layer in self.hidden_layers:
Expand Down
17 changes: 13 additions & 4 deletions models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,22 @@ def __init__(
self.model = model
self.learning_rate = learning_rate

def forward(self, batch):
    """Run the wrapped model on a whole batch object.

    The batch is handed to the underlying model unchanged, so each model
    implementation is free to unpack whichever fields it needs.
    """
    return self.model(batch)

def general_step(self, batch, batch_idx, step: str):
batch_len = len(batch.y)
x_hat = self(batch.x, batch.edge_index, batch.batch)

# This is rather ugly, open to better solutions,
# but torch_geometric and the toponetx dataloader have rather
# different batch signatures.
if hasattr(batch, "batch"):
batch_len = batch.batch.max() + 1
else:
batch_len = batch[-1]

# Generalizing to accommodate the different signatures.
x_hat = self(batch)
# Squeeze x_hat to match the shape of y
x_hat = x_hat.squeeze()
loss = self.loss_fn(x_hat, batch.y)
Expand Down

0 comments on commit 4b204a8

Please sign in to comment.