
Training and distributing models

André Pedersen edited this page Jan 24, 2023 · 1 revision

The README contains a lot of information on how to train models, but for completeness, some of it is repeated here.

Disclaimer

Currently, the framework is tailored to a single use case (Smart Watch Gestures classification). Using it for a very different use case will therefore require modifications. The goal, however, is to make the framework more generic in the future so that it is easier to apply to custom use cases.

Setup

When using this framework, it is a good idea to set up a virtual environment:

virtualenv -ppython3 venv --clear
./venv/Scripts/activate
pip install -r requirements.txt

Tested with Python 3.7.9 on Windows 10.

Usage

To train a model, simply run:

python main.py

The script supports several arguments. To list them, run python main.py -h.

Training history

To visualize the training history, use TensorBoard, for example:

tensorboard --logdir ./output/logs/gesture_classifier_arch_rnn/

An example of the training history for a Recurrent Neural Network (RNN) is shown below:

The figure shows the macro-averaged F1-score at each training step, with the black curve for the training set and the blue curve for the validation set. The best model reached a macro-averaged F1-score of 99.66% on the validation set across all 20 classes.
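To make the metric concrete, here is a minimal stdlib-only sketch of macro-averaged F1: compute F1 = 2TP / (2TP + FP + FN) separately for each class, then take the unweighted mean over classes. The function name `macro_f1` is hypothetical, and the framework's own metric code may be implemented differently (for example via a library). This sketch averages over all `num_classes` and scores a class that never occurs in either labels or predictions as 0.0; libraries may treat that edge case differently.

```python
from collections import Counter
from typing import Sequence


def macro_f1(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> float:
    """Macro-averaged F1: per-class F1, then the unweighted mean over classes."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")

    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1  # correct prediction for class t
        else:
            fp[p] += 1  # class p was predicted, but wrongly
            fn[t] += 1  # true class t was missed

    f1_scores = []
    for c in range(num_classes):
        denom = 2 * tp[c] + fp[c] + fn[c]
        # denom == 0 only if class c appears in neither labels nor predictions
        f1_scores.append(2 * tp[c] / denom if denom else 0.0)

    # Every class counts equally, regardless of how many samples it has
    return sum(f1_scores) / num_classes


if __name__ == "__main__":
    y_true = [0, 0, 1, 1, 2, 2]
    y_pred = [0, 1, 1, 1, 2, 0]
    print(macro_f1(y_true, y_pred, num_classes=3))
```

In this toy example, the per-class F1-scores are 0.5, 0.8 and 2/3, so macro-F1 is 59/90 (about 0.656), while plain accuracy is 4/6 (about 0.667). Because macro-averaging gives every class equal weight, poor performance on a rare class is not hidden by good performance on common ones.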

Converting model to TF-Lite

To use the pretrained model on a mobile device, the model needs to be converted to a compatible format.

TensorFlow Lite (TF-Lite for short; model files use the .tflite extension) is a widely used format for deploying models on mobile devices. This is largely because TensorFlow provides a dedicated inference engine tailored for mobile devices, which is also called TF-Lite.

To convert the model, simply run:

python dss/keras2tflite.py -m /path/to/pretrained/saved_model/ -o /path/to/save/converted/model.tflite
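The internals of dss/keras2tflite.py are not shown on this page; converters like this typically build on TensorFlow's `tf.lite.TFLiteConverter`. After running it, a cheap stdlib-only sanity check is to confirm the output file really is a TF-Lite model before copying it into the app. This sketch relies on the fact that TF-Lite models are FlatBuffers whose schema declares the file identifier `TFL3`, stored at bytes 4 to 8 of the file. The helper names `has_tflite_identifier` and `looks_like_tflite` are hypothetical, and the check only inspects the header; it does not prove the model is valid or that it runs.

```python
import sys
from pathlib import Path

# TF-Lite's FlatBuffer schema declares the file identifier "TFL3".
# In a FlatBuffer, the identifier follows the 4-byte root-table offset.
TFLITE_IDENTIFIER = b"TFL3"


def has_tflite_identifier(data: bytes) -> bool:
    """Return True if the first 8 bytes carry the TF-Lite file identifier."""
    return len(data) >= 8 and data[4:8] == TFLITE_IDENTIFIER


def looks_like_tflite(path) -> bool:
    """Header-only sanity check for a .tflite file (not a full validation)."""
    with Path(path).open("rb") as f:
        return has_tflite_identifier(f.read(8))


if __name__ == "__main__":
    # Usage (hypothetical script name): python check_tflite.py model.tflite
    for arg in sys.argv[1:]:
        status = "ok" if looks_like_tflite(arg) else "NOT a TF-Lite model"
        print(f"{arg}: {status}")
```

Reading only the first 8 bytes keeps the check fast even for large models; for a real functional test, loading the model in a TF-Lite interpreter and running one inference is the more thorough option.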

Using TF-Lite model with Flutter app

To use the model in the Flutter app, add it to the sw_app/assets/ directory. This directory already contains a model.tflite file, which can be removed when a new model is added.

However, Flutter also needs to know where the model is located and what it is called (the file name may differ from model.tflite).

To register the new model, replace the existing model entry in the pubspec.yaml file with the appropriate one, e.g., - assets/custom_cooler_model.tflite.
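Flutter reads asset declarations from the `assets:` list under the `flutter:` key in pubspec.yaml. The sketch below assumes the app's pubspec lists the model as `- assets/model.tflite` under that key; the `EXAMPLE_PUBSPEC` excerpt is illustrative, so check the actual file. `replace_model_asset` is a hypothetical helper that swaps one asset entry with a plain text edit (no YAML parser), so comments and formatting elsewhere in the file are preserved. Editing the line by hand works just as well; this only makes the exact change explicit.

```python
import re

# Illustrative excerpt of the app's pubspec.yaml (assumed; the real file has more keys).
EXAMPLE_PUBSPEC = """\
flutter:
  assets:
    - assets/model.tflite
"""


def replace_model_asset(pubspec_text: str, old: str, new: str) -> str:
    """Swap a '- <old>' asset entry for '- <new>', keeping its indentation.

    Raises ValueError if no entry for `old` is found.
    """
    # [ \t]* instead of \s* so the match never spans line breaks;
    # the lookahead leaves any trailing whitespace or '\r' untouched.
    pattern = re.compile(
        rf"^([ \t]*-[ \t]*){re.escape(old)}(?=[ \t]*\r?$)", re.MULTILINE
    )
    # A function replacement avoids backslash escapes in `new` being interpreted
    updated, count = pattern.subn(lambda m: m.group(1) + new, pubspec_text)
    if count == 0:
        raise ValueError(f"asset entry {old!r} not found")
    return updated


if __name__ == "__main__":
    print(replace_model_asset(
        EXAMPLE_PUBSPEC,
        "assets/model.tflite",
        "assets/custom_cooler_model.tflite",
    ))
```

Applied to the real file, this would amount to reading sw_app/pubspec.yaml (assumed location), passing its text through `replace_model_asset`, and writing the result back. Raising an error when the old entry is missing avoids silently leaving the app pointing at a model that no longer exists.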