Skip to content

Commit

Permalink
tests/pkg_tensorflow-lite: adapt and rename to tflite-micro
Browse files Browse the repository at this point in the history
  • Loading branch information
aabadie committed Apr 8, 2022
1 parent 21ea4e7 commit 9b826b6
Show file tree
Hide file tree
Showing 13 changed files with 84 additions and 41 deletions.
23 changes: 0 additions & 23 deletions tests/pkg_tensorflow-lite/Makefile

This file was deleted.

14 changes: 14 additions & 0 deletions tests/pkg_tflite-micro/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Ensure minimal size by default
DEVELHELP ?= 0

include ../Makefile.tests_common

USEPKG += tflite-micro

# TensorFlow-Lite crashes on M4/M7 CPUs when FPU is enabled, so disable it by
# default for now
DISABLE_MODULE += cortexm_fpu
USEMODULE += mnist
EXTERNAL_MODULE_DIRS += external_modules

include $(RIOTBASE)/Makefile.include
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,37 @@ BOARD_INSUFFICIENT_MEMORY := \
arduino-mkrfox1200 \
arduino-mkrwan1300 \
arduino-mkrzero \
arduino-nano-33-iot \
arduino-zero \
b-l072z-lrwan1 \
bastwan \
blackpill \
blackpill-128kib \
bluepill \
bluepill-128kib \
bluepill-stm32f030c8 \
calliope-mini \
cc1350-launchpad \
cc2650-launchpad \
cc2650stk \
e104-bt5010a-tb \
e104-bt5011a-tb \
e180-zg120b-tb \
esp8266-esp-12x \
esp8266-olimex-mod \
esp8266-sparkfun-thing \
feather-m0 \
feather-m0-lora \
feather-m0-wifi \
frdm-kl43z \
frdm-kw41z \
hamilton \
i-nucleo-lrwan1 \
ikea-tradfri \
im880b \
lobaro-lorabox \
lsn50 \
maple-mini \
microbit \
nrf51dk \
nrf51dongle \
Expand All @@ -23,36 +45,60 @@ BOARD_INSUFFICIENT_MEMORY := \
nucleo-f070rb \
nucleo-f072rb \
nucleo-f091rc \
nucleo-f103rb \
nucleo-f302r8 \
nucleo-f303k8 \
nucleo-f334r8 \
nucleo-f410rb \
nucleo-g070rb \
nucleo-g071rb \
nucleo-g431kb \
nucleo-g431rb \
nucleo-l011k4 \
nucleo-l031k6 \
nucleo-l053r8 \
nucleo-l073rz \
nucleo-l412kb \
nucleo-wl55jc \
olimexino-stm32 \
opencm904 \
openlabs-kw41z-mini-256kib \
pba-d-01-kw2x \
phynode-kw41z \
samd10-xmini \
samd20-xpro \
samd21-xpro \
saml10-xpro \
saml11-xpro \
saml21-xpro \
samr21-xpro \
samr30-xpro \
samr34-xpro \
seeeduino_xiao \
sensebox_samd21 \
serpente \
slstk3400a \
slstk3401a \
sltb001a \
slwstk6000b-slwrb4150a \
slwstk6220a \
sodaq-autonomo \
sodaq-explorer \
sodaq-one \
sodaq-sara-aff \
sodaq-sara-sff \
spark-core \
stk3200 \
stk3600 \
stm32f030f4-demo \
stm32f0discovery \
stm32g0316-disco \
stm32l0538-disco \
stm32mp157c-dk2 \
teensy31 \
usb-kw41z \
weact-f401cc \
wemos-zero \
yarm \
yunjia-nrf51822 \
#
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,19 @@
*/

#include <stdio.h>
#include "kernel_defines.h"
#if IS_USED(MODULE_TENSORFLOW_LITE)
#include "tensorflow/lite/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
#else
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/system_setup.h"
#endif
#include "tensorflow/lite/schema/schema_generated.h"

#include "blob/digit.h"
#include "blob/model.tflite.h"
Expand All @@ -45,39 +52,38 @@ namespace {
// The name of this function is important for Arduino compatibility.
void setup()
{
#if IS_USED(MODULE_TFLITE_MICRO)
tflite::InitializeTarget();
#endif

// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = &micro_error_reporter;

// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
model = tflite::GetModel(model_tflite);

if (model->version() != TFLITE_SCHEMA_VERSION) {
printf("Model provided is schema version %d not equal "
"to supported version %d.",
static_cast<uint8_t>(model->version()), TFLITE_SCHEMA_VERSION);
return;
}

// Explicitly load required operators
static tflite::MicroMutableOpResolver micro_mutable_op_resolver;
micro_mutable_op_resolver.AddBuiltin(
tflite::BuiltinOperator_FULLY_CONNECTED,
tflite::ops::micro::Register_FULLY_CONNECTED(), 1, 4);
micro_mutable_op_resolver.AddBuiltin(
tflite::BuiltinOperator_SOFTMAX,
tflite::ops::micro::Register_SOFTMAX(), 1, 2);
micro_mutable_op_resolver.AddBuiltin(
tflite::BuiltinOperator_QUANTIZE,
tflite::ops::micro::Register_QUANTIZE());
micro_mutable_op_resolver.AddBuiltin(
tflite::BuiltinOperator_DEQUANTIZE,
tflite::ops::micro::Register_DEQUANTIZE(), 1, 2);
// This pulls in all the operation implementations we need.
// NOLINTNEXTLINE(runtime-global-variables)
#if IS_USED(MODULE_TFLITE_MICRO)
static tflite::AllOpsResolver resolver;
#else
static tflite::ops::micro::AllOpsResolver resolver;
#endif

// Build an interpreter to run the model with.
static tflite::MicroInterpreter static_interpreter(
model, micro_mutable_op_resolver, tensor_arena, kTensorArenaSize, error_reporter);
model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;

// Allocate memory from the tensor_arena for the model's tensors.
Expand Down
File renamed without changes.
File renamed without changes.

0 comments on commit 9b826b6

Please sign in to comment.