-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 5778d1e
Showing
4 changed files
with
98 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
try.py | ||
venv | ||
.idea | ||
tinymlgen/__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# TinyML gen | ||
|
||
This is a simple package to export a model trained in Tensorflow Lite | ||
to a plain C array, ready to be used for inference on microcontrollers. | ||
|
||
### Install | ||
|
||
```shell script | ||
pip install tinymlgen | ||
``` | ||
|
||
### Use | ||
|
||
```python | ||
from tinymlgen import port | ||
|
||
if __name__ == '__main__': | ||
tf_model = create_tf_model() | ||
c_code = port(tf_model) | ||
``` | ||
|
||
### Configuration | ||
|
||
You can pass a few parameters to the `port` function: | ||
|
||
- `optimize (=True)`: apply optimizers to the exported model. | ||
Can either be a list of optimizers or a boolean, in which case | ||
`OPTIMIZE_FOR_SIZE` is applied | ||
- `variable_name (='model_data')`: give the exported array a custom name | ||
- `pretty_print (=False)`: print the array in a nicely formatted arrangement | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from distutils.core import setup | ||
setup( | ||
name = 'tinymlgen', | ||
packages = ['tinymlgen'], | ||
version = '0.1', | ||
license='MIT', | ||
description = 'Generate C code for microcontrollers from Tensorflow models', | ||
author = 'Simone Salerno', | ||
author_email = '[email protected]', | ||
url = 'https://github.com/eloquentarduino/tinymlgen', | ||
download_url = 'https://github.com/eloquentarduino/tinymlgen/archive/v_01.tar.gz', | ||
keywords = ['ML', 'microcontrollers', 'tensorflow', 'machine learning'], | ||
install_requires=[ | ||
'tensorflow', | ||
'hexdump' | ||
], | ||
classifiers=[ | ||
'Development Status :: 3 - Alpha', | ||
'Intended Audience :: Developers', | ||
'Topic :: Software Development :: Code Generators', | ||
'License :: OSI Approved :: MIT License', | ||
'Programming Language :: Python :: 3', | ||
'Programming Language :: Python :: 3.4', | ||
'Programming Language :: Python :: 3.5', | ||
'Programming Language :: Python :: 3.6', | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import re | ||
import hexdump | ||
import tensorflow as tf | ||
|
||
|
||
def port(model, optimize=True, variable_name='model_data', pretty_print=False): | ||
converter = tf.lite.TFLiteConverter.from_keras_model(model) | ||
if optimize: | ||
if isinstance(optimize, bool): | ||
optimizers = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] | ||
else: | ||
optimizers = optimize | ||
converter.optimizations = optimizers | ||
tflite_model = converter.convert() | ||
bytes = hexdump.dump(tflite_model).split(' ') | ||
c_array = ', '.join(['0x%02x' % int(byte, 16) for byte in bytes]) | ||
c = 'const unsigned char %s[] DATA_ALIGN_ATTRIBUTE = {%s};' % (variable_name, c_array) | ||
if pretty_print: | ||
c = c.replace('{', '{\n\t').replace('}', '\n}') | ||
c = re.sub(r'(0x..?, ){12}', lambda x: '%s\n\t' % x.group(0), c) | ||
c += '\nconst int %s_len = %d;' % (variable_name, len(bytes)) | ||
preamble = ''' | ||
#ifdef __has_attribute | ||
#define HAVE_ATTRIBUTE(x) __has_attribute(x) | ||
#else | ||
#define HAVE_ATTRIBUTE(x) 0 | ||
#endif | ||
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__)) | ||
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4))) | ||
#else | ||
#define DATA_ALIGN_ATTRIBUTE | ||
#endif | ||
''' | ||
return preamble + c |