
Add Distillation with a Chunked, Fused Linear JS-divergence Loss #408

Draft

austin362667 wants to merge 14 commits into base: main
Conversation

austin362667
Contributor

Summary

Knowledge Distillation

Knowledge Distillation (KD; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper “student” model that speeds up inference by transferring the capabilities of a pre-trained, expensive “teacher” model into it.

In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let z_t and z_s represent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperature T. When ground truth labels y are available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth.

The combined loss function is defined as:

$$\mathcal{L} = \mathcal{L}_{\text{distill}}(\text{softmax}(\mathbf{z_t}, T), \text{softmax}(\mathbf{z_s}, T)) + \lambda \mathcal{L}_{CE}(\mathbf{y}, \mathbf{z_s}),$$

Here, $\lambda$ is a hyperparameter that balances the distillation loss and the supervised objective.
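For reference, here is a minimal PyTorch sketch of this combined objective. It is not the fused, chunked implementation in this PR; the function and argument names (`combined_distillation_loss`, `T`, `lam`) are illustrative, and KL divergence stands in for the generic distillation loss term:

```python
import torch
import torch.nn.functional as F

def combined_distillation_loss(student_logits, teacher_logits, labels, T=2.0, lam=1.0):
    # Soft objective: match the softened teacher and student distributions at temperature T.
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    distill = F.kl_div(log_p_student, p_teacher, reduction="batchmean")

    # Hard objective: standard cross-entropy against the ground-truth labels y.
    ce = F.cross_entropy(student_logits, labels)

    # lam corresponds to the lambda weighting in the formula above.
    return distill + lam * ce
```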

Shared DistillationBase

To support various distillation learning objectives, this PR adds a LigerFusedLinearDistillationBase, which is essentially the same design as the one proposed by @hongpeng-guo in discussion #371 (comment). Thank you @hongpeng-guo for thinking this through.
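As a rough illustration of the chunking idea (a hypothetical sketch, not the actual LigerFusedLinearDistillationBase API; names like `chunked_distillation_loss` and `chunk_size` are made up), the fused linear + loss avoids materializing the full student and teacher logit matrices at once:

```python
import torch

def chunked_distillation_loss(student_hidden, teacher_hidden,
                              student_weight, teacher_weight,
                              loss_fn, chunk_size=1024):
    # Walk over the flattened tokens in chunks so the (tokens x vocab)
    # logit matrices for student and teacher are never fully materialized.
    n_tokens = student_hidden.shape[0]
    total_loss = student_hidden.new_zeros(())
    for start in range(0, n_tokens, chunk_size):
        end = min(start + chunk_size, n_tokens)
        # "Fused" linear: project only this chunk of hidden states to logits.
        student_logits = student_hidden[start:end] @ student_weight.t()
        with torch.no_grad():
            teacher_logits = teacher_hidden[start:end] @ teacher_weight.t()
        # loss_fn is the pluggable distillation loss (e.g., the JSD below).
        total_loss = total_loss + loss_fn(student_logits, teacher_logits) * (end - start)
    return total_loss / n_tokens
```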

Jensen-Shannon Divergence Loss

In addition to the base class, this PR implements the Jensen-Shannon Divergence (JSD) loss as the soft learning objective in the distillation setting. This component can be replaced with other losses (e.g., KL divergence) passed as distillation_loss_fn.

JSD is defined as the average of the KL divergences between each distribution and the mean distribution:

$$\text{JSD}(P || Q) = \frac{1}{2} \text{KL}(P || M) + \frac{1}{2} \text{KL}(Q || M), \quad \text{where } M = \frac{1}{2}(P + Q)$$

Here, P and Q are the two probability distributions, and M is their average.
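A minimal PyTorch reference for this definition (illustrative only; the PR's kernel computes it in a chunked, fused fashion rather than over full logit tensors, and the name `js_divergence` is not the actual API):

```python
import torch
import torch.nn.functional as F

def js_divergence(student_logits, teacher_logits):
    p = F.softmax(teacher_logits, dim=-1)  # P: teacher distribution
    q = F.softmax(student_logits, dim=-1)  # Q: student distribution
    m = 0.5 * (p + q)                      # M: the mean distribution

    # F.kl_div(m.log(), target) computes KL(target || M).
    kl_pm = F.kl_div(m.log(), p, reduction="batchmean")
    kl_qm = F.kl_div(m.log(), q, reduction="batchmean")
    return 0.5 * (kl_pm + kl_qm)
```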

TODO

  • Investigate why the chunked implementation is so slow compared to the naive approach.
  • Integrate temperature scaling.

Testing Done

Yes.

Benchmark plots: jsd_loss_memory and jsd_loss_speed (memory and speed of the chunked JSD loss versus the naive implementation).

  • Hardware Type: A100 40G
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Signed-off-by: Austin Liu <[email protected]>

Add Testing Naive Distillation Base

Signed-off-by: Austin Liu <[email protected]>

Add Chunked JSD Tests and Benchmarks

Signed-off-by: Austin Liu <[email protected]>

Fix call

Signed-off-by: Austin Liu <[email protected]>

Fix Test Usage

Signed-off-by: Austin Liu <[email protected]>

Remove beta

Signed-off-by: Austin Liu <[email protected]>

Fix test params

Signed-off-by: Austin Liu <[email protected]>

Fix call

Signed-off-by: Austin Liu <[email protected]>

Fix ignore_index

Signed-off-by: Austin Liu <[email protected]>

Fix weights dimension

Signed-off-by: Austin Liu <[email protected]>

Fix assign dimension

Signed-off-by: Austin Liu <[email protected]>

Fix assign dimension

Signed-off-by: Austin Liu <[email protected]>

Fix teacher bias

Signed-off-by: Austin Liu <[email protected]>

Reshape input

Signed-off-by: Austin Liu <[email protected]>

Fix mean

Signed-off-by: Austin Liu <[email protected]>

Remove alpha

Signed-off-by: Austin Liu <[email protected]>

Fix t

Signed-off-by: Austin Liu <[email protected]>

Fix t

Signed-off-by: Austin Liu <[email protected]>

Fix t scaling

Signed-off-by: Austin Liu <[email protected]>

Remove teacher tests

Signed-off-by: Austin Liu <[email protected]>

Fix t scaling

Signed-off-by: Austin Liu <[email protected]>

Fix beta

Signed-off-by: Austin Liu <[email protected]>

Fix beta

Signed-off-by: Austin Liu <[email protected]>

WIP

Signed-off-by: Austin Liu <[email protected]>
Clean up

Signed-off-by: Austin Liu <[email protected]>

Format

Signed-off-by: Austin Liu <[email protected]>

Fix

Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>

Fix tol

Signed-off-by: Austin Liu <[email protected]>
@austin362667 austin362667 changed the title Add Support for Knowledge Distillation with a chunked, fused linear JS-divergence Loss Add Distillation with a Chunked, Fused Linear JS-divergence Loss Nov 27, 2024