
Implementations of some losses for imbalanced data, such as focal_loss, dice_loss, DSC Loss, and GHM Loss.


549385454/NLP-Loss-Pytorch

 
 


Implementations of losses for imbalanced NLP tasks, such as focal_loss, dice_loss, DSC Loss, and GHM Loss, and of adversarial training methods such as FGM, FGSM, PGD, and FreeAT.

Loss Summary

This repository implements the following losses for imbalanced data:

| Loss | Paper |
|---|---|
| Weighted CE Loss | UNet Architectures in Multiplanar Volumetric Segmentation -- Validated on Three Knee MRI Cohorts |
| Focal Loss | Focal Loss for Dense Object Detection |
| Dice Loss | V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation |
| DSC Loss | Dice Loss for Data-imbalanced NLP Tasks |
| GHM Loss | Gradient Harmonized Single-stage Detector |
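To make two of the simpler entries concrete, here is a minimal plain-PyTorch sketch, not the repository's code: cross-entropy with inverse-frequency class weights (one common weighting choice, assumed here; the repository may weight classes differently) and the soft Dice loss in the V-Net form 1 - 2·Σpg / (Σp² + Σg²), computed per class over the batch. The helper names `inverse_frequency_weights` and `dice_loss` and the `smooth` term are hypothetical.

```python
import torch
import torch.nn.functional as F


def inverse_frequency_weights(targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: per-class weights n / (C * n_c).

    Classes absent from `targets` are treated as if they had one sample.
    """
    counts = torch.bincount(targets, minlength=num_classes).float()
    return counts.sum() / (num_classes * counts.clamp(min=1.0))


def dice_loss(logits: torch.Tensor, targets: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Soft Dice loss, V-Net form (squared terms in the denominator).

    logits: (batch_size, num_classes); targets: (batch_size,) class indices.
    `smooth` (an assumption) keeps the ratio defined for classes absent from the batch.
    """
    probs = F.softmax(logits, dim=-1)                                  # (N, C)
    one_hot = F.one_hot(targets, num_classes=logits.size(-1)).float()  # (N, C)
    intersection = (probs * one_hot).sum(dim=0)                        # per class
    denominator = (probs ** 2).sum(dim=0) + (one_hot ** 2).sum(dim=0)
    dice = (2 * intersection + smooth) / (denominator + smooth)
    return 1 - dice.mean()                                             # average over classes


# Usage: a batch where class 0 dominates.
logits = torch.randn(8, 4, requires_grad=True)
targets = torch.tensor([0, 0, 0, 0, 0, 1, 2, 3])
weights = inverse_frequency_weights(targets, num_classes=4)  # class 0 gets the smallest weight
loss = F.cross_entropy(logits, targets, weight=weights) + dice_loss(logits, targets)
loss.backward()
```

Summing the two terms is only to show both calls; whether to combine them, and with what scale, is a modeling choice.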

How to use?

Usage examples for every loss are in test_loss.py.

Here is a minimal usage example:

import torch
from unbalanced_loss.focal_loss import MultiFocalLoss

batch_size, num_class = 64, 10
loss_fn = MultiFocalLoss(num_class=num_class, gamma=2.0, reduction='mean')

logits = torch.rand(batch_size, num_class, requires_grad=True)  # (batch_size, num_class)
targets = torch.randint(0, num_class, size=(batch_size,))  # (batch_size,) class indices

loss = loss_fn(logits, targets)
loss.backward()
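For reference, a minimal sketch of the standard multi-class focal loss, FL(p_t) = -alpha_t·(1 - p_t)^gamma·log(p_t), computed from logits in plain PyTorch. This is an assumption about what MultiFocalLoss computes, not its actual code; `focal_loss` and its `alpha` argument are hypothetical names.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, targets, gamma=2.0, alpha=None, reduction="mean"):
    """Multi-class focal loss FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t).

    logits: (batch_size, num_class) raw scores; targets: (batch_size,) class indices;
    alpha: optional (num_class,) tensor of per-class weights (hypothetical argument).
    """
    log_probs = F.log_softmax(logits, dim=-1)
    log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-prob of the gold class
    pt = log_pt.exp()
    loss = -((1 - pt) ** gamma) * log_pt  # (1 - p_t)**gamma down-weights easy examples
    if alpha is not None:
        loss = alpha.to(logits.device)[targets] * loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # reduction == "none"


batch_size, num_class = 64, 10
logits = torch.randn(batch_size, num_class, requires_grad=True)
targets = torch.randint(0, num_class, size=(batch_size,))
loss = focal_loss(logits, targets, gamma=2.0)
loss.backward()
```

With gamma=0 and no alpha the expression reduces to ordinary cross-entropy, which makes a quick sanity check against `F.cross_entropy`.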

Adversarial Training Summary

The following adversarial training methods are implemented. More details are in adversarial_training/README.md.

| Method | Full name / paper |
|---|---|
| FGM | Fast Gradient Method |
| FGSM | Fast Gradient Sign Method |
| PGD | Towards Deep Learning Models Resistant to Adversarial Attacks |
| FreeAT | Free Adversarial Training |
| FreeLB | Free Large Batch Adversarial Training |
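As an illustration of the simplest method in the table, here is a minimal sketch of the FGM pattern commonly used in NLP: after the clean backward pass, the embedding matrix is shifted by r = epsilon·g/‖g‖₂, a second forward/backward pass accumulates the adversarial gradient, and the original embeddings are restored before the optimizer step. It is not this repository's implementation; `ToyClassifier`, `FGMSketch`, `fgm_train_step`, `emb_name`, and `epsilon` are hypothetical names.

```python
import torch
import torch.nn as nn


class ToyClassifier(nn.Module):
    """Hypothetical model: token embedding -> mean pooling -> linear classifier."""

    def __init__(self, vocab_size=100, dim=16, num_classes=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, tokens):
        return self.fc(self.embedding(tokens).mean(dim=1))


class FGMSketch:
    """FGM on embedding weights: r_adv = epsilon * g / ||g||_2 (hypothetical helper)."""

    def __init__(self, model, emb_name="embedding", epsilon=1.0):
        self.model = model
        self.emb_name = emb_name  # substring that identifies embedding parameters
        self.epsilon = epsilon
        self.backup = {}

    def attack(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(self.epsilon * param.grad / norm)

    def restore(self):
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}


def fgm_train_step(model, fgm, optimizer, tokens, labels):
    loss_fn = nn.CrossEntropyLoss()
    optimizer.zero_grad()
    loss = loss_fn(model(tokens), labels)
    loss.backward()        # gradients on the clean batch
    fgm.attack()           # move embeddings one step along the gradient
    adv_loss = loss_fn(model(tokens), labels)
    adv_loss.backward()    # accumulate gradients from the adversarial batch
    fgm.restore()          # undo the perturbation before the update
    optimizer.step()
    return loss.item(), adv_loss.item()


torch.manual_seed(0)
model = ToyClassifier()
fgm = FGMSketch(model, epsilon=1.0)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
tokens = torch.randint(0, 100, (8, 5))
labels = torch.randint(0, 4, (8,))
clean_loss, adv_loss = fgm_train_step(model, fgm, optimizer, tokens, labels)
```

Because each step runs two forward/backward passes instead of one, FGM roughly doubles the per-step cost.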

How to use?

A simple BERT classification demo is in test_bert.py.

To use adversarial training with your own model, adapt the train function in PGD.py to your model's inputs, then call it as below:

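import transformers
from model import bert_classification
from adversarial_training.PGD import PGD

batch_size, num_class = 64, 10
# model = your_model()
model = bert_classification()
AT_Model = PGD(model)
# transformers.AdamW is deprecated in recent transformers releases;
# torch.optim.AdamW is the suggested replacement
optimizer = transformers.AdamW(model.parameters(), lr=0.001)

# token, segment, mask, label: one batch of BERT inputs and labels from your data loader
# adapt train_bert in PGD.py if your model takes different inputs
outputs, loss = AT_Model.train_bert(token, segment, mask, label, optimizer)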
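When adapting the train function, it may help to see the usual K-step PGD loop on embedding weights in isolation. The sketch below follows a widely used NLP pattern: back up the clean gradients, take K normalized ascent steps of size alpha while projecting the total perturbation onto an L2 ball of radius epsilon, add the final adversarial gradient onto the clean one, then restore the embeddings. It is not this repository's PGD class; `ToyClassifier`, `PGDSketch`, `pgd_train_step`, `emb_name`, `epsilon`, and `alpha` are hypothetical.

```python
import torch
import torch.nn as nn


class ToyClassifier(nn.Module):
    """Hypothetical stand-in for bert_classification: embedding -> mean pool -> linear."""

    def __init__(self, vocab_size=100, dim=16, num_classes=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, tokens):
        return self.fc(self.embedding(tokens).mean(dim=1))


class PGDSketch:
    """K-step PGD on embedding weights, projected onto an L2 ball of radius epsilon."""

    def __init__(self, model, emb_name="embedding", epsilon=1.0, alpha=0.3):
        self.model, self.emb_name = model, emb_name
        self.epsilon, self.alpha = epsilon, alpha
        self.emb_backup, self.grad_backup = {}, {}

    def _emb_params(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                yield name, param

    def attack(self, is_first_attack=False):
        for name, param in self._emb_params():
            if is_first_attack:
                self.emb_backup[name] = param.data.clone()
            if param.grad is None:
                continue
            norm = torch.norm(param.grad)
            if norm != 0 and not torch.isnan(norm):
                param.data.add_(self.alpha * param.grad / norm)  # ascent step
                r = param.data - self.emb_backup[name]             # total perturbation
                if torch.norm(r) > self.epsilon:                   # project onto the ball
                    r = self.epsilon * r / torch.norm(r)
                param.data = self.emb_backup[name] + r

    def restore(self):
        for name, param in self._emb_params():
            if name in self.emb_backup:
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def backup_grad(self):
        self.grad_backup = {n: p.grad.clone() for n, p in self.model.named_parameters()
                            if p.grad is not None}

    def restore_grad(self):
        for name, param in self.model.named_parameters():
            if name in self.grad_backup:
                param.grad = self.grad_backup[name]


def pgd_train_step(model, pgd, optimizer, tokens, labels, k=3):
    loss_fn = nn.CrossEntropyLoss()
    optimizer.zero_grad()
    loss = loss_fn(model(tokens), labels)
    loss.backward()                  # clean gradients
    pgd.backup_grad()
    for t in range(k):
        pgd.attack(is_first_attack=(t == 0))
        if t != k - 1:
            model.zero_grad()        # next ascent step uses only the adversarial gradient
        else:
            pgd.restore_grad()       # last step: add adversarial grads onto the clean ones
        adv_loss = loss_fn(model(tokens), labels)
        adv_loss.backward()
    pgd.restore()                    # remove the perturbation before updating
    optimizer.step()
    return loss.item(), adv_loss.item()
```

Each step here runs one clean pass plus k adversarial passes, so with k=3 it costs roughly four times a normal step.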

Adversarial Training Results Comparison

| Method | Time cost (s/epoch) | Best accuracy |
|---|---|---|
| Normal (no adversarial attack) | 23.77 | 0.773 |
| FGSM | 45.95 | 0.7936 |
| FGM | 47.28 | 0.8008 |
| PGD (k=3) | 87.50 | 0.7963 |
| FreeAT (k=3) | 93.26 | 0.7896 |
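The trade-off in the results table is easier to read as a slowdown relative to normal training versus an accuracy gain. The short script below only re-derives those two ratios from the table's numbers; it adds no new measurements.

```python
# Numbers copied from the results table above: (seconds per epoch, best accuracy).
results = {
    "Normal": (23.77, 0.773),
    "FGSM": (45.95, 0.7936),
    "FGM": (47.28, 0.8008),
    "PGD(k=3)": (87.50, 0.7963),
    "FreeAT(k=3)": (93.26, 0.7896),
}

base_secs, base_acc = results["Normal"]
summary = {
    # (slowdown vs. normal training, accuracy gain in percentage points)
    name: (secs / base_secs, 100 * (acc - base_acc))
    for name, (secs, acc) in results.items()
}

for name, (slowdown, gain) in summary.items():
    print(f"{name:<12} {slowdown:5.2f}x  {gain:+5.2f} acc points")
```

With these numbers FGM costs about 2x per epoch for roughly +2.8 accuracy points, while PGD (k=3) costs about 3.7x for roughly +2.3 points. This is consistent with the number of forward/backward passes per batch in the usual formulations (two for FGM and FGSM, k+1 for PGD), though that explanation is an assumption rather than something measured here.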
