From 13614bf2577055ce75d01ff7502cf821e95c7b6d Mon Sep 17 00:00:00 2001 From: dsethz Date: Thu, 31 Oct 2024 17:54:43 +0100 Subject: [PATCH] switched to logit loss for mixed precision training --- src/nuclai/models/mlp.py | 41 ++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/src/nuclai/models/mlp.py b/src/nuclai/models/mlp.py index 021258e..12c7639 100644 --- a/src/nuclai/models/mlp.py +++ b/src/nuclai/models/mlp.py @@ -90,15 +90,13 @@ def __init__( dim_out = dimensions[i + 1] self.net.append(nn.Linear(dim_in, dim_out, bias=self.bias)) - if i == (len(dimensions) - 2): - self.net.append(nn.Sigmoid()) - else: - self.net.append( - nn.LeakyReLU(negative_slope=1e-2, inplace=True) - ) - self.net.append(nn.BatchNorm1d(dimensions[i + 1])) - if self.dropout > 0.0: - self.net.append(nn.Dropout(p=self.dropout)) + # remove sigmoid due to numerical instability with mixed-precision + # if i == (len(dimensions) - 2): + # self.net.append(nn.Sigmoid()) + self.net.append(nn.LeakyReLU(negative_slope=1e-2, inplace=True)) + self.net.append(nn.BatchNorm1d(dimensions[i + 1])) + if self.dropout > 0.0: + self.net.append(nn.Dropout(p=self.dropout)) # add the list of modules to the current module self.net = nn.ModuleList(self.net) @@ -195,7 +193,7 @@ def training_step( # forward pass prediction = self.net(features) - loss = F.binary_cross_entropy( + loss = F.binary_cross_entropy_with_logits( prediction, labels, weight=self.loss_weight ) @@ -211,6 +209,9 @@ def training_step( def on_validation_start(self): self.f1_val = F1Score(task="binary", threshold=0.5) + # get sigmoid for predictions + self.sigmoid = nn.Sigmoid() + def validation_step( self, batch: list[torch.Tensor], batch_idx: int ) -> torch.Tensor: @@ -219,12 +220,15 @@ def validation_step( prediction = self.net(features) # log loss - loss = F.binary_cross_entropy( + loss = F.binary_cross_entropy_with_logits( prediction, labels, weight=self.loss_weight ) self.log("loss_val", loss, on_step=True, on_epoch=True, sync_dist=True) + # convert prediction to probability for F1 + prediction = self.sigmoid(prediction) + # update f1_val self.f1_val.update( torch.squeeze(prediction, -1), torch.squeeze(labels, -1) @@ -245,13 +249,16 @@ def on_test_start(self): # set up f1 metric self.f1_test = F1Score(task="binary", threshold=0.5) + # get sigmoid for predictions + self.sigmoid = nn.Sigmoid() + def test_step( self, batch: list[torch.Tensor], batch_idx: int ) -> torch.Tensor: features, labels, ids = batch prediction = self.net(features) - loss = F.binary_cross_entropy( + loss = F.binary_cross_entropy_with_logits( prediction, labels, weight=self.loss_weight ) @@ -264,6 +271,9 @@ def test_step( sync_dist=False, ) + # convert prediction to probability + prediction = self.sigmoid(prediction) + # update f1_test self.f1_test.update( torch.squeeze(prediction, -1), torch.squeeze(labels, -1) @@ -273,6 +283,7 @@ def test_step( idx = os.path.basename( self.trainer.datamodule.data_test.data[str(ids[0].item())]["path"] ) + tmp = pd.DataFrame( { "id": [idx], @@ -299,6 +310,9 @@ def on_predict_start(self): # create prediction data frame to which all predictions are appended self.prediction_data = pd.DataFrame(columns=["id", "prediction"]) + # get sigmoid for predictions + self.sigmoid = nn.Sigmoid() + def predict_step( self, batch: list[torch.Tensor], batch_idx: int ) -> torch.Tensor: @@ -306,6 +320,9 @@ def predict_step( prediction = self.net(features) + # convert prediction to probability + prediction = self.sigmoid(prediction) + # save predictions and labels idx = os.path.basename( self.trainer.datamodule.data_predict.data[str(ids[0].item())][