Training and distributing models
The README contains a lot of information on how to train models (see here), but for completeness, we will repeat some of that here.
Currently, the framework is tailored to a single use case (Smart Watch Gestures classification). Using it on a very different use case will therefore require modifications. However, the goal is to make the framework more generic in the future, so that it is easier to apply to custom use cases.
When using this framework, it is a good idea to set up a virtual environment:
virtualenv -ppython3 venv --clear
./venv/Scripts/activate
pip install -r requirements.txt
Tested with Python 3.7.9 on Windows 10.
To train a model, simply run:
python main.py
The script supports multiple arguments. To see the supported arguments, run python main.py -h.
To visualize the training history, use TensorBoard, for example:
tensorboard --logdir ./output/logs/gesture_classifier_arch_rnn/
An example of the training history for a Recurrent Neural Network (RNN) is shown below:
The figure shows the macro-averaged F1-score at each step during training, with the black curve for the training set and the blue curve for the validation set. The best model reached a macro-averaged F1-score of 99.66 % on the validation set, across all 20 classes.
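For readers unfamiliar with the metric, the following is a minimal sketch of how a macro-averaged F1-score can be computed: the F1-score is calculated per class and the per-class scores are averaged with equal weight. The function name `macro_f1` and the toy labels are illustrative assumptions; this is not necessarily how the framework computes the metric internally.

```python
from collections import Counter


def macro_f1(y_true, y_pred):
    """Return the unweighted mean of the per-class F1-scores."""
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and of equal length")

    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1  # correct prediction for class t
        else:
            fp[p] += 1  # class p was predicted wrongly
            fn[t] += 1  # class t was missed

    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        # F1 = 2*TP / (2*TP + FP + FN)
        denom = 2 * tp[c] + fp[c] + fn[c]
        scores.append(2 * tp[c] / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Toy example with 3 classes (hypothetical labels, not from the dataset).
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 2, 2, 2]
print(round(macro_f1(y_true, y_pred), 4))  # per-class F1: 1.0, 0.6667, 0.8
```

Because every class contributes equally, rare gestures influence the score as much as frequent ones.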
In order to use the pretrained model on a mobile device, the model needs to be converted to a compatible format.
TensorFlow Lite (TF-Lite for short, file extension .tflite) is one of the most common formats for storing models intended for mobile devices. This is in part because TensorFlow provides an inference engine tailored for mobile devices, which is also called TF-Lite.
To convert the model, simply run:
python dss/keras2tflite.py -m /path/to/pretrained/saved_model/ -o /path/to/save/converted/model.tflite
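As a rough illustration of what such a conversion involves, the sketch below uses TensorFlow's `tf.lite.TFLiteConverter.from_saved_model` API. It is not the contents of dss/keras2tflite.py, which may differ; the function names and the long option names `--model` and `--output` are assumptions made for this example (only `-m` and `-o` appear in the command above).

```python
import argparse
from pathlib import Path


def build_parser():
    """Command-line interface mirroring the -m / -o flags used above."""
    parser = argparse.ArgumentParser(description="Convert a SavedModel to TF-Lite (sketch).")
    # The long option names are assumptions made for this example.
    parser.add_argument("-m", "--model", required=True, help="path to the SavedModel directory")
    parser.add_argument("-o", "--output", required=True, help="path of the .tflite file to write")
    return parser


def convert(saved_model_dir, output_path):
    """Convert a SavedModel directory to a .tflite file and return its path."""
    import tensorflow as tf  # imported lazily so the parser works without TensorFlow

    converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir))
    tflite_model = converter.convert()  # serialized model as bytes

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_bytes(tflite_model)
    return output_path


# Typical usage from a script:
#   args = build_parser().parse_args()
#   convert(args.model, args.output)
```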
To use the converted model in the app, add it to the sw_app/assets/ directory (see here). This directory already contains a model.tflite file, which can be removed if a new model is added.
However, for Flutter to know that this model exists, it needs to know where the model is located (and what it is called, since the name may differ).
This can be done by replacing the existing model entry in the pubspec.yaml file with the appropriate model, e.g., - assets/custom_cooler_model.tflite.
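The two manual steps above (copying the model and registering it in pubspec.yaml) are easy to get out of sync. The following is a small, hypothetical helper that copies a model into the assets directory and reports whether pubspec.yaml already lists it. The function name `add_model_to_app` is an assumption, as is the location of pubspec.yaml directly inside the app directory; the helper does not modify pubspec.yaml.

```python
import shutil
from pathlib import Path


def add_model_to_app(model_path, app_dir="sw_app"):
    """Copy a .tflite model into <app_dir>/assets/ and check pubspec.yaml.

    Returns (path_of_copied_model, listed), where `listed` tells whether
    pubspec.yaml contains an entry "- assets/<model file name>".
    """
    model_path = Path(model_path)
    app_dir = Path(app_dir)

    assets_dir = app_dir / "assets"
    assets_dir.mkdir(parents=True, exist_ok=True)
    target = assets_dir / model_path.name
    shutil.copy2(model_path, target)

    # Assumption: pubspec.yaml sits at the root of the app directory.
    pubspec = app_dir / "pubspec.yaml"
    entry = f"- assets/{model_path.name}"
    listed = pubspec.exists() and any(
        line.strip() == entry
        for line in pubspec.read_text(encoding="utf-8").splitlines()
    )
    return target, listed


# Typical usage:
#   target, listed = add_model_to_app("custom_cooler_model.tflite")
#   if not listed:
#       print("Remember to add the model to pubspec.yaml")
```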