Skip to content

Commit

Permalink
Lce benchmark and interpreter flags (#717)
Browse files Browse the repository at this point in the history
* add boolean flag allowing to register indirect BGEMM kernel

* added boolean flag to LceInterpreter allowing it to register indirect BGEMM kernels

* added boolean flag use_xnnpack in LceInterpreter to explicitly activate/deactivate XNNPACK delegate

* lce_benchmark_model:
- added two cmdline flags use_reference_bconv/use_indirect_bgemm as global variables in lce_benchmark_main.cc to register respective kernels
- implemented LceBenchmarkTfLiteModel as a child class of BenchmarkTfLiteModel to use builtin flags instead of manually parsing them in lce_benchmark_main.cc
- modified lce_benchmark_main.cc to use LceBenchmarkTfLiteModel; the global flags are set upon calling the overridden Run() method by passing them as an internal reference
- added build options for lce_benchmark_tflite_model.h same as in TFLite's benchmark_tflite_model.h

* add include of LceBenchmarkTfLiteModel for cmake build

* srcs for Makefile

* added warning when use_reference_bconv and use_indirect_bgemm are both set to true in lce_ops_register.h

* adapted BUILD file to include TFLite logging header

* include TFLite logging header for interpreter_wrapper_lite build
  • Loading branch information
simonmaurer committed Mar 10, 2022
1 parent d8193e4 commit 4cb8e72
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 13 deletions.
15 changes: 15 additions & 0 deletions larq_compute_engine/tflite/benchmark/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,23 @@ tf_cc_binary(
],
}),
deps = [
"//larq_compute_engine/tflite/benchmark:lce_benchmark_tflite_model_lib",
"//larq_compute_engine/tflite/kernels:lce_op_kernels",
"@org_tensorflow//tensorflow/lite/tools:logging",
],
)

cc_library(
name = "lce_benchmark_tflite_model_lib",
srcs = ["lce_benchmark_tflite_model.cc"],
hdrs = ["lce_benchmark_tflite_model.h"],
copts = tflite_copts() + select({
"@org_tensorflow//tensorflow:ios": [
"-xobjective-c++",
],
"//conditions:default": [],
}),
deps = [
"@org_tensorflow//tensorflow/lite/tools/benchmark:benchmark_tflite_model_lib",
],
)
12 changes: 9 additions & 3 deletions larq_compute_engine/tflite/benchmark/lce_benchmark_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,29 @@ limitations under the License.
==============================================================================*/

#include <iostream>
#include <string>

#include "absl/base/attributes.h"
#include "larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.h"
#include "larq_compute_engine/tflite/kernels/lce_ops_register.h"
#include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"
#include "tensorflow/lite/tools/logging.h"

bool use_reference_bconv = false;
bool use_indirect_bgemm = false;

// Weak override of TFLite's op-registration hook. Registers the LCE custom
// ops, forwarding the two global flags above; their values are written by
// LceBenchmarkTfLiteModel::Run() after command-line parsing, before the
// benchmark constructs the interpreter.
void ABSL_ATTRIBUTE_WEAK
RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {
  compute_engine::tflite::RegisterLCECustomOps(resolver, use_reference_bconv,
                                               use_indirect_bgemm);
}

namespace tflite {
namespace benchmark {

int Main(int argc, char** argv) {
TFLITE_LOG(INFO) << "STARTING!";
BenchmarkTfLiteModel benchmark;
LceBenchmarkTfLiteModel benchmark(LceBenchmarkTfLiteModel::DefaultParams(),
use_reference_bconv, use_indirect_bgemm);
if (benchmark.Run(argc, argv) != kTfLiteOk) {
TFLITE_LOG(ERROR) << "Benchmarking failed.";
return EXIT_FAILURE;
Expand Down
74 changes: 74 additions & 0 deletions larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Modifications copyright (C) 2022 Larq Contributors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.h"

#include "tensorflow/lite/tools/logging.h"

namespace tflite {
namespace benchmark {

// Returns the stock TFLite benchmark parameter set extended with the two
// LCE-specific boolean parameters, both defaulting to false.
BenchmarkParams LceBenchmarkTfLiteModel::DefaultParams() {
  BenchmarkParams params = BenchmarkTfLiteModel::DefaultParams();
  // Both LCE kernel-selection parameters share the same bool/false shape.
  for (const char* name : {"use_reference_bconv", "use_indirect_bgemm"}) {
    params.AddParam(name, BenchmarkParam::Create<bool>(false));
  }
  return params;
}

// Constructs the benchmark. The two bool references alias caller-owned flag
// storage (the globals in lce_benchmark_main.cc); Run(argc, argv) later
// writes the parsed command-line values through them.
// NOTE(review): the referenced bools must outlive this object.
LceBenchmarkTfLiteModel::LceBenchmarkTfLiteModel(BenchmarkParams params,
                                                 bool& use_reference_bconv,
                                                 bool& use_indirect_bgemm)
    : BenchmarkTfLiteModel(std::move(params)),
      use_reference_bconv(use_reference_bconv),
      use_indirect_bgemm(use_indirect_bgemm) {}

// Returns the base-class flag list extended with the two LCE-specific
// boolean flags, both backed by params_.
std::vector<Flag> LceBenchmarkTfLiteModel::GetFlags() {
  std::vector<Flag> flags = BenchmarkTfLiteModel::GetFlags();
  std::vector<Flag> lce_flags = {
      CreateFlag<bool>(
          "use_reference_bconv", &params_,
          "When true, uses the reference implementation of LceBconv2d."),
      // Fix: the two adjacent literals previously concatenated to
      // "kernel ofLceBconv2d." (missing space between "of" and "LceBconv2d").
      CreateFlag<bool>("use_indirect_bgemm", &params_,
                       "When true, uses the optimized indirect BGEMM kernel "
                       "of LceBconv2d.")};

  flags.insert(flags.end(), lce_flags.begin(), lce_flags.end());

  return flags;
}

// Logs the base-class parameters, then the two LCE-specific ones.
void LceBenchmarkTfLiteModel::LogParams() {
  BenchmarkTfLiteModel::LogParams();
  // Mirror the base class: the "verbose" param gates non-default-only output.
  const bool log_verbosely = params_.Get<bool>("verbose");
  LOG_BENCHMARK_PARAM(bool, "use_reference_bconv", "Use reference Bconv",
                      log_verbosely);
  LOG_BENCHMARK_PARAM(bool, "use_indirect_bgemm", "Use indirect BGEMM",
                      log_verbosely);
}

// Parses the command-line flags first so the LCE flag values can be
// published through the stored references before BenchmarkTfLiteModel::Run()
// builds the interpreter (RegisterSelectedOps reads them at op-resolver
// registration time — ordering here is load-bearing).
TfLiteStatus LceBenchmarkTfLiteModel::Run(int argc, char** argv) {
  TF_LITE_ENSURE_STATUS(ParseFlags(argc, argv));
  // Export parsed values to the caller-owned bools captured at construction.
  use_reference_bconv = params_.Get<bool>("use_reference_bconv");
  use_indirect_bgemm = params_.Get<bool>("use_indirect_bgemm");

  return BenchmarkTfLiteModel::Run();
}

} // namespace benchmark
} // namespace tflite
47 changes: 47 additions & 0 deletions larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Modifications copyright (C) 2022 Larq Contributors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef COMPUTE_ENGINE_TFLITE_BENCHMARK_LCE_BENCHMARK_TFLITE_MODEL_H_
#define COMPUTE_ENGINE_TFLITE_BENCHMARK_LCE_BENCHMARK_TFLITE_MODEL_H_

#include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"

namespace tflite {
namespace benchmark {

// Benchmarks a TFLite model by running the tflite interpreter.
// Extends BenchmarkTfLiteModel with two LCE-specific command-line flags,
// use_reference_bconv and use_indirect_bgemm; Run(argc, argv) writes the
// parsed values through the bool references supplied at construction.
class LceBenchmarkTfLiteModel : public BenchmarkTfLiteModel {
 public:
  // The bool references must outlive this object; they receive the parsed
  // flag values when Run(argc, argv) is called.
  explicit LceBenchmarkTfLiteModel(BenchmarkParams params,
                                   bool& use_reference_bconv,
                                   bool& use_indirect_bgemm);

  // Base-class flags plus the two LCE flags.
  std::vector<Flag> GetFlags() override;
  // Also logs the LCE-specific parameters.
  void LogParams() override;
  // Base-class defaults plus the two LCE params (both default to false).
  static BenchmarkParams DefaultParams();

  // Keep the base-class Run() overload visible alongside our override.
  using BenchmarkTfLiteModel::Run;
  // Parses flags, publishes the LCE values through the stored references,
  // then delegates to BenchmarkTfLiteModel::Run().
  TfLiteStatus Run(int argc, char** argv);

 private:
  bool& use_reference_bconv;  // aliases caller-owned flag storage
  bool& use_indirect_bgemm;   // aliases caller-owned flag storage
};

} // namespace benchmark
} // namespace tflite

#endif // COMPUTE_ENGINE_TFLITE_BENCHMARK_LCE_BENCHMARK_TFLITE_MODEL_H_
1 change: 1 addition & 0 deletions larq_compute_engine/tflite/build_make/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ LCE_EXAMPLE_SRCS := \
examples/lce_minimal.cc

LCE_BENCHMARK_SRCS := \
larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.cc \
larq_compute_engine/tflite/benchmark/lce_benchmark_main.cc

# These target-specific makefiles should modify or replace options like
Expand Down
1 change: 1 addition & 0 deletions larq_compute_engine/tflite/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ cc_library(
"@org_tensorflow//tensorflow/lite:type_to_tflitetype",
"@org_tensorflow//tensorflow/lite/kernels/internal:kernel_utils",
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
"@org_tensorflow//tensorflow/lite/tools:logging",
"@ruy//ruy/profiler:instrumentation",
],
alwayslink = 1,
Expand Down
23 changes: 20 additions & 3 deletions larq_compute_engine/tflite/kernels/lce_ops_register.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,34 @@

#include "tensorflow/lite/context.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/tools/logging.h"

// This file contains forward declaration of all custom ops
// implemented in LCE which can be used to link against LCE library.

namespace compute_engine {
namespace tflite {

using namespace ::tflite;

TfLiteRegistration* Register_QUANTIZE();
TfLiteRegistration* Register_DEQUANTIZE();
TfLiteRegistration* Register_BCONV_2D();
TfLiteRegistration* Register_BCONV_2D_REF();
TfLiteRegistration* Register_BCONV_2D_OPT_INDIRECT_BGEMM();
TfLiteRegistration* Register_BMAXPOOL_2D();

// Calling this function on a TF Lite mutable op resolver registers all LCE
// custom ops with the resolver.
inline void RegisterLCECustomOps(::tflite::MutableOpResolver* resolver,
const bool use_reference_bconv = false) {
const bool use_reference_bconv = false,
const bool use_indirect_bgemm = false) {
if (use_reference_bconv && use_indirect_bgemm) {
TFLITE_LOG(WARN)
<< "WARNING: 'use_reference_bconv' and `use_indirect_bgemm` "
"are both set to true. use_indirect_bgemm==true "
"will have no effect.";
}
resolver->AddCustom("LceQuantize",
compute_engine::tflite::Register_QUANTIZE());
resolver->AddCustom("LceDequantize",
Expand All @@ -28,8 +39,14 @@ inline void RegisterLCECustomOps(::tflite::MutableOpResolver* resolver,
resolver->AddCustom("LceBconv2d",
compute_engine::tflite::Register_BCONV_2D_REF());
} else {
resolver->AddCustom("LceBconv2d",
compute_engine::tflite::Register_BCONV_2D());
if (use_indirect_bgemm) {
resolver->AddCustom(
"LceBconv2d",
compute_engine::tflite::Register_BCONV_2D_OPT_INDIRECT_BGEMM());
} else {
resolver->AddCustom("LceBconv2d",
compute_engine::tflite::Register_BCONV_2D());
}
}
resolver->AddCustom("LceBMaxPool2d",
compute_engine::tflite::Register_BMAXPOOL_2D());
Expand Down
10 changes: 9 additions & 1 deletion larq_compute_engine/tflite/python/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class Interpreter(InterpreterBase):
flatbuffer_model: A serialized Larq Compute Engine model in the flatbuffer format.
num_threads: The number of threads used by the interpreter.
use_reference_bconv: When True, uses the reference implementation of LceBconv2d.
use_indirect_bgemm: When True, uses the optimized indirect BGEMM kernel of LceBconv2d.
use_xnnpack: When True, uses the XNNPack delegate of TFLite.
# Attributes
input_types: Returns a list of input types.
Expand All @@ -40,11 +42,17 @@ def __init__(
flatbuffer_model: bytes,
num_threads: int = 1,
use_reference_bconv: bool = False,
use_indirect_bgemm: bool = False,
use_xnnpack: bool = False,
):
from larq_compute_engine.tflite.python import interpreter_wrapper_lite

super().__init__(
interpreter_wrapper_lite.LiteInterpreter(
flatbuffer_model, num_threads, use_reference_bconv
flatbuffer_model,
num_threads,
use_reference_bconv,
use_indirect_bgemm,
use_xnnpack,
)
)
21 changes: 15 additions & 6 deletions larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ class LiteInterpreterWrapper
public:
LiteInterpreterWrapper(const pybind11::bytes& flatbuffer,
const int num_threads = 1,
const bool use_reference_bconv = false);
const bool use_reference_bconv = false,
const bool use_indirect_bgemm = false,
const bool use_xnnpack = false);
~LiteInterpreterWrapper(){};

private:
Expand All @@ -25,7 +27,8 @@ class LiteInterpreterWrapper

LiteInterpreterWrapper::LiteInterpreterWrapper(
const pybind11::bytes& flatbuffer, const int num_threads,
const bool use_reference_bconv) {
const bool use_reference_bconv, const bool use_indirect_bgemm,
const bool use_xnnpack) {
// Make a copy of the flatbuffer because it can get deallocated after the
// constructor is done
flatbuffer_ = static_cast<std::string>(flatbuffer);
Expand All @@ -37,9 +40,14 @@ LiteInterpreterWrapper::LiteInterpreterWrapper(
}

// Build the interpreter
resolver_ = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
compute_engine::tflite::RegisterLCECustomOps(resolver_.get(),
use_reference_bconv);
if (use_xnnpack) {
resolver_ = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
} else {
resolver_ = std::make_unique<
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
}
compute_engine::tflite::RegisterLCECustomOps(
resolver_.get(), use_reference_bconv, use_indirect_bgemm);

tflite::InterpreterBuilder builder(*model_, *resolver_);
builder(&interpreter_, num_threads);
Expand All @@ -51,7 +59,8 @@ LiteInterpreterWrapper::LiteInterpreterWrapper(

PYBIND11_MODULE(interpreter_wrapper_lite, m) {
pybind11::class_<LiteInterpreterWrapper>(m, "LiteInterpreter")
.def(pybind11::init<const pybind11::bytes&, const int, const bool>())
.def(pybind11::init<const pybind11::bytes&, const int, const bool,
const bool, const bool>())
.def_property("input_types", &LiteInterpreterWrapper::get_input_types,
nullptr)
.def_property("output_types", &LiteInterpreterWrapper::get_output_types,
Expand Down

0 comments on commit 4cb8e72

Please sign in to comment.