Skip to content

Commit

Permalink
Merge pull request #17 from florencejt/bugs/attentionGNNmulticlass
Browse files Browse the repository at this point in the history
Bug fix in AttentionWeightedGNN with multiclass classification
  • Loading branch information
florencejt authored Jan 12, 2024
2 parents b99d4de + 78cc82b commit 5bb530f
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions fusilli/fusionmodels/tabularfusion/attention_weighted_GNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ def training_step(self, batch, batch_idx):

y_hat, weights = self.forward((x1, x2))

if self.prediction_task == "multiclass":
# turn the labels into one hot vectors
y = F.one_hot(y, num_classes=self.multiclass_dimensions).to(torch.float32)

loss = F.mse_loss(y_hat.squeeze(), y.to(torch.float32).squeeze())
self.log('train_loss', loss, logger=None)
return loss
Expand All @@ -186,6 +190,11 @@ def validation_step(self, batch, batch_idx):
"""
x1, x2, y = batch
y_hat, weights = self.forward((x1, x2))

if self.prediction_task == "multiclass":
# turn the labels into one hot vectors
y = F.one_hot(y, num_classes=self.multiclass_dimensions).to(torch.float32)

loss = F.mse_loss(y_hat.squeeze(), y.to(torch.float32).squeeze())
self.log('val_loss', loss, logger=None)
return loss
Expand Down

0 comments on commit 5bb530f

Please sign in to comment.