Commit 452d72e

Improve numerical stability of softmax before AUROC by force-casting to float32

nathanpainchaud committed Jul 5, 2024
1 parent 80ae047 · commit 452d72e
Showing 1 changed file with 3 additions and 8 deletions.
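As background for the change below, here is a minimal sketch (not from the repository) of the failure mode being addressed: applying scipy's softmax directly to float16 logits, such as those produced by AMP models, can accumulate enough round-off that rows no longer sum to 1, whereas casting to float32 first keeps the row sums tight.

import numpy as np
from scipy.special import softmax

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10)).astype(np.float16)  # stand-in for AMP model outputs

probs_f16 = softmax(logits, axis=1)                     # computed entirely in float16
probs_f32 = softmax(logits.astype(np.float32), axis=1)  # cast first, as this commit does

print(probs_f16.sum(axis=1))  # row sums may deviate noticeably from 1.0
print(probs_f32.sum(axis=1))  # row sums stay close to 1.0 within float32 precision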
11 changes: 3 additions & 8 deletions didactic/data/cardinal/predict.py
@@ -329,17 +329,12 @@ def _write_prediction_scores(
 
     # Convert the classification logits/probabilities to numpy arrays
     for attr, attr_pred in classification_out.items():
-        attr_pred = np.array(attr_pred)
+        # Convert to numpy array and ensure float32, to avoid numerical instabilities in case of float16 values
+        # coming from AMP models. This is especially important for softmax, which is sensitive to small values.
+        attr_pred = np.array(attr_pred, dtype=np.float32)
         if (attr_pred < 0).any() or (attr_pred > 1).any():
             # If output were logits, compute probabilities from logits
             attr_pred = softmax(attr_pred, axis=1)
-            # Rescale the output of the softmax to make sure that it sums to 1
-            # In theory this should be handled by the softmax function itself, but scipy's implementation
-            # returned values summing to 0.9995, which was not close enough to 1 for some downstream cases
-            # Also, wrap the scaling inside a while loop to handle numerical instability, since a one-time
-            # scaling was not always enough to get close enough to a sum of 1
-            while not np.allclose(1, attr_pred.sum(axis=1)):
-                attr_pred /= attr_pred.sum(axis=1, keepdims=True)
         classification_out[attr] = attr_pred
 
     if subset_categorical_data:
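For reference, the updated loop body behaves like this standalone helper (a sketch; the logits_to_probabilities name and wrapper are ours, for illustration only):

import numpy as np
from scipy.special import softmax

def logits_to_probabilities(attr_pred):
    # Ensure float32 to avoid numerical instabilities from float16 (AMP) values
    attr_pred = np.array(attr_pred, dtype=np.float32)
    # Values outside [0, 1] indicate logits rather than probabilities
    if (attr_pred < 0).any() or (attr_pred > 1).any():
        attr_pred = softmax(attr_pred, axis=1)
    return attr_pred

With the float32 cast in place, the defensive renormalization loop that compensated for float16 round-off (row sums like 0.9995) is no longer needed, which is why this commit deletes it.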
