This repository contains an implementation of a continual learning (CL) framework and the Train-On-Request (TOR) workflow in PyTorch, together with the on-device deployment of CL on the ultra-low-power GAP9 microcontroller.
Please cite the following publication if you use our implementation of CL or on-device deployment:
@article{mei2024BMIODCL,
title={An Ultra-Low Power Wearable BMI System with Continual Learning Capabilities},
author={Lan Mei and Thorir Mar Ingolfsson and Cristian Cioflan and Victor Kartsch and Andrea Cossettini and Xiaying Wang and Luca Benini},
journal={IEEE Transactions on Biomedical Circuits and Systems},
year={2024},
doi={10.1109/TBCAS.2024.3457522}
}
Please cite the following publication if you use our TOR workflow:
@misc{mei2024trainonrequestondevicecontinuallearning,
title={Train-On-Request: An On-Device Continual Learning Workflow for Adaptive Real-World Brain Machine Interfaces},
author={Lan Mei and Cristian Cioflan and Thorir Mar Ingolfsson and Victor Kartsch and Andrea Cossettini and Xiaying Wang and Luca Benini},
year={2024},
eprint={2409.09161},
archivePrefix={arXiv},
primaryClass={eess.SP},
url={https://arxiv.org/abs/2409.09161},
}
This is an environment derived from QuantLab based on PyTorch 1.13.1. To install the prerequisites, create a conda environment with:
# PyTorch 1.13.1 (Recommended)
$> conda create --name pytorch-1.13
$> conda activate pytorch-1.13
$> conda config --env --add channels conda-forge
$> conda config --env --add channels pytorch
$> conda install python=3.8 pytorch=1.13.1 pytorch-gpu torchvision=0.14.1 torchtext=0.14.1 torchaudio=0.13.1 cudatoolkit=11.6 -c pytorch -c conda-forge
$> conda install ipython packaging parse setuptools tensorboard tqdm networkx python-graphviz scipy pandas ipdb onnx onnxruntime einops yapf tabulate
$> pip install setuptools==59.5.0 torchsummary parse coloredlogs netron
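After installation, a quick sanity check can confirm that the pinned versions actually ended up in the environment. The following is a minimal sketch and not part of this repository: the `PINS` table and the helper names `installed_version` and `check_pins` are illustrative, the version numbers are copied from the commands above, and it assumes the conda `pytorch` package registers itself under the distribution name `torch`.

```python
from importlib import metadata

# Pins copied from the install commands above; the table itself is illustrative.
PINS = {
    "torch": "1.13.1",
    "torchvision": "0.14.1",
    "torchtext": "0.14.1",
    "torchaudio": "0.13.1",
    "setuptools": "59.5.0",
}


def installed_version(name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def check_pins(pins, lookup=installed_version):
    """Map each package that does not match its pin to (expected, found)."""
    mismatches = {}
    for name, expected in pins.items():
        found = lookup(name)
        # Ignore local build tags such as "+cu116" when comparing.
        base = found.split("+")[0] if found else None
        if base != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    for name, (expected, found) in check_pins(PINS).items():
        print(f"{name}: expected {expected}, found {found}")
```

Run inside the activated `pytorch-1.13` environment, the script prints one line per mismatching or missing package and nothing if every pin is satisfied.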
To convert quantized networks from QuantLab into code for on-device deployment with DORY, the quantlib quantization library must be installed in the conda environment:
$ cd Offline-Training/quantlab-cl
$ pip install -e quantlib
The current quantlib in this repository is based on commit ba13b4957bd23c54d94b5aae78457b78341e76bf, with modifications to the quantization workflow to export our models correctly.
To generate and run the on-device implementation:
- Install GAP9-SDK under BMI-ODCL/ as indicated in the associated repository. Access to this repository can be granted by GreenWaves Technologies.
- Install Dory under BMI-ODCL/Offline-Training/ as indicated in the associated repository. Check branch_id.log to match our working branch. Replace dory/Parsers/Parser_ONNX_to_DORY.py with the file already present in the current repository under BMI-ODCL/Offline-Training/dory/Parser_ONNX_to_DORY.py.
- Install PULP-Trainlib under BMI-ODCL/On-Device-Implementation/ as indicated in the associated repository. The current repository was tested with commit 6615e084738958890ea9dd10195f8bbfe089ceb7.
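Because the dependencies above are tied to specific commits, it can help to confirm that a local checkout matches before building. The sketch below is one possible way to do that, not a script shipped with this repository: the helper names and the checkout directory `pulp-trainlib` are assumptions, while the commit hash is the one quoted above.

```python
import subprocess


def head_commit(repo_dir):
    """Return the full hash of HEAD for the git checkout at repo_dir."""
    result = subprocess.run(
        ["git", "-C", str(repo_dir), "rev-parse", "HEAD"],
        check=True,
        capture_output=True,
        text=True,
    )
    return result.stdout.strip()


def commit_matches(found, expected):
    """True if the two hashes agree; either may be an abbreviated prefix."""
    found, expected = found.strip().lower(), expected.strip().lower()
    # Refuse very short prefixes, which are too ambiguous to trust.
    if min(len(found), len(expected)) < 7:
        return False
    return found.startswith(expected) or expected.startswith(found)


if __name__ == "__main__":
    # Hypothetical checkout location; adjust to where PULP-Trainlib was cloned.
    repo = "BMI-ODCL/On-Device-Implementation/pulp-trainlib"
    tested = "6615e084738958890ea9dd10195f8bbfe089ceb7"
    found = head_commit(repo)
    print("match" if commit_matches(found, tested) else f"differs: {found}")
```

The same check applies to any other pinned dependency by changing `repo` and `tested`.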
This work uses two in-house EEG datasets for BMI. The datasets can be downloaded from this link: https://iis-people.ee.ethz.ch/~datasets/Datasets-ODCL/.
- Dataset A: An in-house EEG MM dataset for classifying left hand and right hand movements. This dataset contains seven data sessions from one subject. The stored data files are csv files.
- Dataset B: An in-house EEG MM/MI dataset for classifying left hand, right hand, tongue, and rest. This dataset contains data from five subjects and four sessions for each subject. The stored data files are binary files, which can be converted to csv files with run_conversion.m from BMI-ODCL/Preprocessing. The lists of file paths to be converted can be modified or extended in run_conversion.m. The corresponding csv files will be stored in a newly created folder: DatasetB/SubjectX_XXXX_SX/MM/csv/.
Note that only csv files will be used in classification and the conversion of all files in Dataset B should be treated as a preprocessing step.
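Before starting classification, it may be useful to check which sessions of Dataset B have already been converted. The following is a rough sketch under the assumption that every session folder follows the `<session>/MM/csv/` layout described above; the function names are hypothetical and the script is not part of the repository.

```python
from pathlib import Path


def find_session_csvs(dataset_root):
    """Group converted csv files by session folder name.

    Assumes the layout <dataset_root>/<session>/MM/csv/*.csv.
    """
    sessions = {}
    for csv_path in sorted(Path(dataset_root).glob("*/MM/csv/*.csv")):
        # parents[0] is csv/, parents[1] is MM/, parents[2] is the session folder.
        session = csv_path.parents[2].name
        sessions.setdefault(session, []).append(csv_path)
    return sessions


def sessions_without_csv(dataset_root):
    """List session folders that contain MM/ but no converted csv files."""
    converted = find_session_csvs(dataset_root)
    candidates = sorted(
        p.parent.name for p in Path(dataset_root).glob("*/MM") if p.is_dir()
    )
    return [name for name in candidates if name not in converted]


if __name__ == "__main__":
    root = "DatasetB"  # adjust to the local download location
    for name, files in find_session_csvs(root).items():
        print(f"{name}: {len(files)} csv file(s)")
    for name in sessions_without_csv(root):
        print(f"{name}: not converted yet")
```

Any session reported as not converted still needs a pass through run_conversion.m.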
This repository contains five folders:
- Preprocessing: Preprocessing code for converting the binary files of Dataset B to csv files.
- Avalanche-Implementation-CL: Implementation of CL algorithms on Dataset A using Avalanche in Python.
- Offline-Training: Offline implementations of within-session classification, transfer learning (TL) and CL workflow, and quantization using QuantLab in Python.
- On-Device-Implementation: Implementation of on-device TL/CL on GAP9.
- BMI-TOR: Implementation of the Train-On-Request (TOR) workflow with continual learning capabilities.
Detailed descriptions of and instructions for each component can be found in their respective README files.
- Lan Mei, ETH Zurich
- Thorir Mar Ingolfsson, ETH Zurich
- Cristian Cioflan, ETH Zurich
- Victor Kartsch, ETH Zurich
- Andrea Cossettini, ETH Zurich
- Xiaying Wang, ETH Zurich
Unless explicitly stated otherwise, the code is released under Apache 2.0. Please see the LICENSE file in the root of this repository for details.
As an exception, the weights:
- ./On-Device-Implementation/Backbone-Example/model.onnx
- ./On-Device-Implementation/Backbone-Example/hex/*weights.hex
- ./On-Device-Implementation/Classifier-Example/linear-data.h
- ./On-Device-Implementation/Classifier-Example/weights_fc.npy
- ./On-Device-Implementation/Classifier-Example/bias_fc.npy
- ./On-Device-Implementation/Classifier-Example-LwF/linear-data.h
- ./On-Device-Implementation/Classifier-Example-LwF/weights_fc.npy
- ./On-Device-Implementation/Classifier-Example-LwF/bias_fc.npy

and the inputs:
- ./On-Device-Implementation/Backbone-Example/hex/*inputs.hex
- ./On-Device-Implementation/Classifier-Example/inputs/
- ./On-Device-Implementation/Classifier-Example-LwF/inputs/

are released under Creative Commons Attribution-NoDerivatives 4.0 International. Please see the LICENSE file in their respective directories.
Note that the license under which the current repository is released might differ from the license of each individual package:
- Avalanche - MIT License;
- PyTorch - a mix of licenses, including the Apache 2.0 License and the 3-Clause BSD License;
- TensorBoard - Apache 2.0 License;
- NetworkX - 3-Clause BSD License;
- GraphViz - MIT License;
- matplotlib - a custom license;
- NumPy - 3-Clause BSD License;
- SciPy - 3-Clause BSD License;
- Mako - MIT License;
- Jupyter - 3-Clause BSD License;
- Pandas - 3-Clause BSD License;
- early-stopping-pytorch - MIT License.