
Convert to Keras 3: Knowledge Distillation example #18493

Merged: 7 commits into keras-team:master on Sep 26, 2023

Conversation

@awsaf49 (Contributor) commented Sep 25, 2023

This PR adds the Knowledge Distillation example from keras.io/examples. Sadly, this example is not backend-agnostic, presumably because of tf.GradientTape(), which is TensorFlow-specific.
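The coupling sits in the Distiller's custom train_step, which opens a tf.GradientTape to compute gradients for the student. As a rough illustration only, here is a minimal sketch of that pattern with Keras 3 imports; it is simplified from the actual example (metric tracking and the alpha-weighted student loss on hard labels are omitted), so any detail beyond what the formatted diff later in this thread shows is an assumption:

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # the converted script pins the TF backend

import keras
from keras import ops
import tensorflow as tf


class Distiller(keras.Model):
    """Simplified distiller: trains `student` to mimic a frozen `teacher`."""

    def __init__(self, student, teacher, temperature=10):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.temperature = temperature
        self.distillation_loss_fn = keras.losses.KLDivergence()

    def train_step(self, data):
        x, _ = data  # hard labels are ignored in this simplified sketch
        teacher_predictions = self.teacher(x, training=False)

        # tf.GradientTape is the TensorFlow-only piece; the JAX and PyTorch
        # backends would each need their own gradient machinery here.
        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)
            distillation_loss = (
                self.distillation_loss_fn(
                    ops.softmax(teacher_predictions / self.temperature, axis=1),
                    ops.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

        # Only the student's weights are updated.
        gradients = tape.gradient(distillation_loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.student.trainable_variables)
        )
        return {"distillation_loss": distillation_loss}
```

Compile such a model with an optimizer (for instance, `distiller.compile(optimizer=keras.optimizers.Adam())`) before calling `fit`.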

@awsaf49 (Contributor, Author) commented Sep 25, 2023

@codecov-commenter commented Sep 25, 2023

Codecov Report

All modified lines are covered by tests ✅

Comparison is base (29a954a) 77.40% compared to head (01b3701) 73.77%.
Report is 13 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18493      +/-   ##
==========================================
- Coverage   77.40%   73.77%   -3.64%     
==========================================
  Files         331      331              
  Lines       31972    31984      +12     
  Branches     6241     6246       +5     
==========================================
- Hits        24749    23596    -1153     
- Misses       5646     6857    +1211     
+ Partials     1577     1531      -46     
Flag               Coverage Δ
keras              73.70% <ø> (-3.62%) ⬇️
keras-jax          ?
keras-numpy        56.15% <ø> (-0.02%) ⬇️
keras-tensorflow   62.03% <ø> (-0.15%) ⬇️
keras-torch        63.96% <ø> (-0.02%) ⬇️

Flags with carried forward coverage won't be shown.

see 28 files with indirect coverage changes

☔ View full report in Codecov by Sentry.

@fchollet (Collaborator) commented:

Formatted diff:

--- a/examples/keras_io/tensorflow/vision/knowledge_distillation.py
+++ b/examples/keras_io/tensorflow/vision/knowledge_distillation.py
@@ -1,12 +1,11 @@
 """
 Title: Knowledge Distillation
 Author: [Kenneth Borup](https://twitter.com/Kennethborup)
+Converted to Keras 3: [Md Awsafur Rahman](https://awsaf49.github.io)
 Date created: 2020/09/01
 Last modified: 2020/09/01
 Description: Implementation of classical Knowledge Distillation.
-Accelerator: GPU
 """
-
 """
 ## Introduction to Knowledge Distillation
 
@@ -29,12 +28,16 @@
 ## Setup
 """
 
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+from keras import ops
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
 import numpy as np
 
-
 """
 ## Construct `Distiller()` class
 
@@ -50,8 +53,10 @@
 - An optimizer for the student and (optional) metrics to evaluate performance
 
 In the `train_step` method, we perform a forward pass of both the teacher and student,
-calculate the loss with weighting of the `student_loss` and `distillation_loss` by `alpha` and
-`1 - alpha`, respectively, and perform the backward pass. Note: only the student weights are updated,
+calculate the loss with weighting of the `student_loss` and `distillation_loss` by
+`alpha` and
+`1 - alpha`, respectively, and perform the backward pass. Note: only the student weights
+are updated,
 and therefore we only calculate the gradients for the student weights.
 
 In the `test_step` method, we evaluate the student model on the provided dataset.
@@ -111,8 +116,8 @@ def train_step(self, data):
             # as 1/T^2, multiply them by T^2 when using both hard and soft targets.
             distillation_loss = (
                 self.distillation_loss_fn(
-                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
-                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
+                    ops.softmax(teacher_predictions / self.temperature, axis=1),
+                    ops.softmax(student_predictions / self.temperature, axis=1),
                 )
                 * self.temperature**2
             )
@@ -168,7 +173,7 @@ def test_step(self, data):
     [
         keras.Input(shape=(28, 28, 1)),
         layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
-        layers.LeakyReLU(alpha=0.2),
+        layers.LeakyReLU(negative_slope=0.2),
         layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
         layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
         layers.Flatten(),
@@ -182,7 +187,7 @@ def test_step(self, data):
     [
         keras.Input(shape=(28, 28, 1)),
         layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
-        layers.LeakyReLU(alpha=0.2),
+        layers.LeakyReLU(negative_slope=0.2),
         layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
         layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
         layers.Flatten(),
@@ -198,7 +203,8 @@ def test_step(self, data):
 ## Prepare the dataset
 
 The dataset used for training the teacher and distilling the teacher is
-[MNIST](https://keras.io/api/datasets/mnist/), and the procedure would be equivalent for any other
+[MNIST](https://keras.io/api/datasets/mnist/), and the procedure would be equivalent for
+any other
 dataset, e.g. [CIFAR-10](https://keras.io/api/datasets/cifar10/), with a suitable choice
 of models. Both the student and teacher are trained on the training set and evaluated on
 the test set.
@@ -284,4 +290,4 @@ def test_step(self, data):
 You should expect the teacher to have accuracy around 97.6%, the student trained from
 scratch should be around 97.6%, and the distilled student should be around 98.1%. Remove
 or try out different seeds to use different weight initializations.
-"""
\ No newline at end of file
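Besides swapping tf.nn.softmax for keras.ops.softmax, the other rename in the diff is LeakyReLU's alpha argument becoming negative_slope in Keras 3. For illustration, here is the converted student model assembled as a standalone snippet; note that the trailing Dense(10) layer and the name="student" argument follow the keras.io example rather than the hunks shown above:

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # the example still targets the TF backend

import keras
from keras import layers

# Student network as converted in the diff; in Keras 3, LeakyReLU takes
# `negative_slope` where tf.keras used `alpha`.
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),  # 10 MNIST classes, logits output
    ],
    name="student",
)
student.summary()
```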

@fchollet (Collaborator) left a comment:

Thanks for the PR!

@fchollet (Collaborator) left a comment:

Very nice!

@google-ml-butler bot added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels on Sep 26, 2023
@fchollet merged commit 0484165 into keras-team:master on Sep 26, 2023
@google-ml-butler bot removed the ready to pull (Ready to be merged into the codebase) and kokoro:force-run labels on Sep 26, 2023
@awsaf49 deleted the ex_kd branch on September 26, 2023 at 11:10