Conversion of mobilevit to Keras 3 with tf.keras backwards compatibility #18827
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Thanks for the PR!
    def conv_block(x, filters=16, kernel_size=3, strides=2):
        conv_layer = layers.Conv2D(
            filters, kernel_size, strides=strides, activation=tf.nn.swish, padding="same"
Use "swish"
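A minimal sketch of the suggested fix, following the `conv_block` helper quoted above: passing the string `"swish"` lets Keras resolve the activation through the active backend, whereas `tf.nn.swish` ties the layer to TensorFlow.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def conv_block(x, filters=16, kernel_size=3, strides=2):
    # "swish" is resolved by Keras itself, so the block stays
    # backend-agnostic under Keras 3.
    conv_layer = layers.Conv2D(
        filters, kernel_size, strides=strides, activation="swish", padding="same"
    )
    return conv_layer(x)
```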
    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)

    if tf.math.equal(x.shape[-1], output_channels) and strides == 1:
Use if x.shape[-1] == output_channels
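A small sketch of why the plain comparison is preferred (the helper name `needs_projection` is hypothetical, not from the PR): for statically known shapes, `x.shape[-1]` is a plain Python int, so `==` evaluates eagerly on any backend, while `tf.math.equal` returns a symbolic tensor and is TensorFlow-specific.

```python
def needs_projection(input_channels, output_channels, strides):
    # Static channel counts are plain Python ints, so a regular `==`
    # comparison works across Keras 3 backends; tf.math.equal would
    # yield a symbolic tensor instead of a bool here.
    return not (input_channels == output_channels and strides == 1)
```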
    import tensorflow as tf

    # For versions <TF2.13 change the above import to:
Remove these 2 lines
    models with the same or higher complexity ([MobileNetV3](https://arxiv.org/abs/1905.02244),
    for example), while being efficient on mobile devices.

    Note: This example should be run with Tensorflow 2.13 and higher.
Remove outdated line
    inference with TFLite models, check out
    [this official resource](https://www.tensorflow.org/lite/performance/post_training_quantization).

    You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/mobile-vit-xxs)
Remove these 2 outdated lines
        tf.lite.OpsSet.SELECT_TF_OPS,  # Enable TensorFlow ops.
    ]
    tflite_model = converter.convert()
    open("mobilevit_xxs.tflite", "wb").write(tflite_model)
Did you check that the TFLite model works?
I didn't perform inference with the TFLite model, but I verified that this line saves the TFLite model.
Can you try using the TFLite interpreter to check that it works as expected? https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_python
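The verification step the reviewer asks for can be sketched as a round trip with the TFLite interpreter. This is a minimal, self-contained example: the tiny `Dense` model and the `tiny_model.tflite` filename are stand-ins for illustration, not the MobileViT model from the PR.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Stand-in model (the real PR converts MobileViT).
model = keras.Sequential([
    keras.layers.Input(shape=(8,)),
    keras.layers.Dense(4, activation="relu"),
    keras.layers.Dense(2),
])

# Convert to TFLite and write the flatbuffer to disk.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("tiny_model.tflite", "wb") as f:
    f.write(tflite_model)

# Load and run the saved model with the TFLite interpreter.
interpreter = tf.lite.Interpreter(model_path="tiny_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

sample = np.random.rand(1, 8).astype(np.float32)
interpreter.set_tensor(input_details[0]["index"], sample)
interpreter.invoke()
tflite_out = interpreter.get_tensor(output_details[0]["index"])

# Compare against the original Keras model's prediction.
keras_out = model.predict(sample, verbose=0)
```

If the two outputs agree to within floating-point tolerance, the converted model "works as expected" in the sense the reviewer means.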
@fchollet Completed all the changes you asked for. I also checked that the converted TFLite model is working as expected.
LGTM, thank you!
I'm not able to render the example as a notebook, the kernel dies at the line that writes the serialized model (even with 60GB of RAM). Are you able to run
I couldn't find
@fchollet Since I have made the changes to work with the tensorflow backend and it's merged to be available in
Conversion of mobilevit.py to Keras 3 with the tensorflow backend, as per stage 1 of the issue "Keras.io examples conversion gameplan" #18468.
Provided solutions to all the problems I raised in Issue #18613.