switched to logit loss for mixed precision training
dsethz committed Oct 31, 2024
1 parent 6d8e223 commit 13614bf
Showing 1 changed file with 29 additions and 12 deletions.
src/nuclai/models/mlp.py (29 additions & 12 deletions)
@@ -90,15 +90,13 @@ def __init__(
             dim_out = dimensions[i + 1]
             self.net.append(nn.Linear(dim_in, dim_out, bias=self.bias))
 
-            if i == (len(dimensions) - 2):
-                self.net.append(nn.Sigmoid())
-            else:
-                self.net.append(
-                    nn.LeakyReLU(negative_slope=1e-2, inplace=True)
-                )
-                self.net.append(nn.BatchNorm1d(dimensions[i + 1]))
-                if self.dropout > 0.0:
-                    self.net.append(nn.Dropout(p=self.dropout))
+            # remove sigmoid due to numerical instability with mixed-precision
+            # if i == (len(dimensions) - 2):
+            #     self.net.append(nn.Sigmoid())
+            self.net.append(nn.LeakyReLU(negative_slope=1e-2, inplace=True))
+            self.net.append(nn.BatchNorm1d(dimensions[i + 1]))
+            if self.dropout > 0.0:
+                self.net.append(nn.Dropout(p=self.dropout))
 
         # add the list of modules to the current module
         self.net = nn.ModuleList(self.net)
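
Note: why the sigmoid head had to go. Under half precision, sigmoid saturates to exactly 1.0 for moderately large logits, after which the log(1 - p) term inside plain BCE becomes log(0) = -inf; PyTorch's CUDA autocast refuses to run F.binary_cross_entropy at all for this reason. A minimal standalone sketch of the failure mode (toy values, not part of this repo):

import torch
import torch.nn.functional as F

# sigmoid(12) ~= 0.9999939, which rounds to exactly 1.0 in float16,
# so the naive BCE term log(1 - p) evaluates to log(0) = -inf.
logit = torch.tensor([12.0])
target = torch.tensor([0.0])

p = torch.sigmoid(logit).half().float()  # simulate an fp16 activation
naive = -torch.log1p(-p)                 # inf: the probability saturated

# The logits-based loss folds sigmoid and log into one numerically
# stable (log-sum-exp) expression, so the same input stays finite.
stable = F.binary_cross_entropy_with_logits(logit, target)

print(naive.item(), stable.item())  # inf  ~12.0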
@@ -195,7 +193,7 @@ def training_step(
 
         # forward pass
         prediction = self.net(features)
-        loss = F.binary_cross_entropy(
+        loss = F.binary_cross_entropy_with_logits(
             prediction, labels, weight=self.loss_weight
         )
 
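This swap in training_step is what actually unblocks mixed-precision training: binary_cross_entropy_with_logits is on autocast's fp32-safe list, whereas the sigmoid-plus-BCE pair is rejected outright on CUDA. A standalone sketch of the difference (toy linear model, not the repo's MLP):

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
net = torch.nn.Linear(16, 1).to(device)
x = torch.randn(8, 16, device=device)
y = torch.randint(0, 2, (8, 1), device=device).float()

with torch.autocast(device_type=device):
    logits = net(x)
    # fine under autocast: computed in fp32 with the log-sum-exp trick
    loss = F.binary_cross_entropy_with_logits(logits, y)
    # the pre-commit version would fail here under CUDA autocast with
    # "binary_cross_entropy ... unsafe to autocast":
    # loss = F.binary_cross_entropy(torch.sigmoid(logits), y)

loss.backward()

Worth checking after such a swap: weight keeps its per-element rescaling meaning in both loss variants, so self.loss_weight behaves as before; balancing positives against negatives would instead go through the separate pos_weight argument.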
@@ -211,6 +209,9 @@
     def on_validation_start(self):
         self.f1_val = F1Score(task="binary", threshold=0.5)
 
+        # get sigmoid for predictions
+        self.sigmoid = nn.Sigmoid()
+
     def validation_step(
         self, batch: list[torch.Tensor], batch_idx: int
     ) -> torch.Tensor:
@@ -219,12 +220,15 @@ def validation_step(
         prediction = self.net(features)
 
         # log loss
-        loss = F.binary_cross_entropy(
+        loss = F.binary_cross_entropy_with_logits(
             prediction, labels, weight=self.loss_weight
         )
 
         self.log("loss_val", loss, on_step=True, on_epoch=True, sync_dist=True)
 
+        # convert prediction to probability for F1
+        prediction = self.sigmoid(prediction)
+
         # update f1_val
         self.f1_val.update(
             torch.squeeze(prediction, -1), torch.squeeze(labels, -1)
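
Because the network now emits raw logits, the validation path converts them to probabilities before updating the metric, keeping them in the [0, 1] range that threshold=0.5 partitions. A small self-contained example of the metric's behaviour (values invented for illustration):

import torch
from torchmetrics import F1Score

f1 = F1Score(task="binary", threshold=0.5)

logits = torch.tensor([2.0, -1.5, 0.3, -0.2])
labels = torch.tensor([1, 0, 1, 1])

probs = torch.sigmoid(logits)  # [0.88, 0.18, 0.57, 0.45]
f1.update(probs, labels)       # thresholded to [1, 0, 1, 0]
print(f1.compute())            # tensor(0.8000): TP=2, FP=0, FN=1

Recent torchmetrics releases document that floating-point predictions outside [0, 1] are treated as logits and sigmoided internally, but the explicit conversion keeps the intent visible and independent of metric internals.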
@@ -245,13 +249,16 @@ def on_test_start(self):
         # set up f1 metric
         self.f1_test = F1Score(task="binary", threshold=0.5)
 
+        # get sigmoid for predictions
+        self.sigmoid = nn.Sigmoid()
+
     def test_step(
         self, batch: list[torch.Tensor], batch_idx: int
     ) -> torch.Tensor:
         features, labels, ids = batch
 
         prediction = self.net(features)
-        loss = F.binary_cross_entropy(
+        loss = F.binary_cross_entropy_with_logits(
             prediction, labels, weight=self.loss_weight
         )
 
@@ -264,6 +271,9 @@ def test_step(
             sync_dist=False,
         )
 
+        # convert prediction to probability
+        prediction = self.sigmoid(prediction)
+
         # update f1_test
         self.f1_test.update(
             torch.squeeze(prediction, -1), torch.squeeze(labels, -1)
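
The torch.squeeze(..., -1) calls exist because the MLP head presumably outputs shape (batch, 1) for this binary task, while the metric expects flat (batch,) tensors; squeezing only the last dimension avoids accidentally collapsing a batch of size 1. A quick shape check:

import torch

prediction = torch.rand(1, 1)  # batch of one sample, one output unit

print(torch.squeeze(prediction).shape)      # torch.Size([])  -- too much
print(torch.squeeze(prediction, -1).shape)  # torch.Size([1]) -- intended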
@@ -273,6 +283,7 @@ def test_step(
         idx = os.path.basename(
             self.trainer.datamodule.data_test.data[str(ids[0].item())]["path"]
         )
+
         tmp = pd.DataFrame(
             {
                 "id": [idx],
@@ -299,13 +310,19 @@ def on_predict_start(self):
         # create prediction data frame to which all predictions are appended
         self.prediction_data = pd.DataFrame(columns=["id", "prediction"])
 
+        # get sigmoid for predictions
+        self.sigmoid = nn.Sigmoid()
+
     def predict_step(
         self, batch: list[torch.Tensor], batch_idx: int
     ) -> torch.Tensor:
         features, ids = batch
 
         prediction = self.net(features)
+
+        # convert prediction to probability
+        prediction = self.sigmoid(prediction)
 
         # save predictions and labels
         idx = os.path.basename(
             self.trainer.datamodule.data_predict.data[str(ids[0].item())][
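
The predict path mirrors the others: the network returns logits, and only the explicit sigmoid turns them into the probabilities stored in prediction_data. A standalone sketch with a hypothetical stand-in for the repo's MLP (layer sizes invented):

import torch
import torch.nn as nn

# hypothetical stand-in: same kind of head as the patched model,
# i.e. a final Linear with no Sigmoid on top
net = nn.Sequential(nn.Linear(16, 8), nn.LeakyReLU(1e-2), nn.Linear(8, 1))
net.eval()
sigmoid = nn.Sigmoid()

with torch.no_grad():
    logits = net(torch.randn(2, 16))  # unbounded raw scores
    probs = sigmoid(logits)           # mapped into (0, 1) for reporting

print(logits.squeeze(-1))
print(probs.squeeze(-1))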
