Skip to content

Commit

Permalink
fix: tflite conversion and inference
Browse files Browse the repository at this point in the history
  • Loading branch information
nglehuy committed May 19, 2024
1 parent ec3ccc2 commit a4d411d
Show file tree
Hide file tree
Showing 20 changed files with 317 additions and 230 deletions.
9 changes: 0 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,3 @@ repos:
stages: [pre-commit]
fail_fast: true
verbose: true
- id: pylint-check
name: pylint-check
entry: pylint --rcfile=.pylintrc -rn -sn
language: system
types: [python]
stages: [pre-commit]
fail_fast: true
require_serial: true
verbose: true
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=tensorflow.python
generated-members=tensorflow.python,tensorflow.keras

# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ See [augmentations](./tensorflow_asr/augmentations/README.md)

## TFLite Convertion

After converting to tflite, the tflite model is like a function that transforms directly from an **audio signal** to **unicode code points**, then we can convert unicode points to string.
After converting to tflite, the tflite model is like a function that transforms directly from an **audio signal** to **text and tokens**

See [tflite_convertion](./docs/tutorials/tflite.md)

Expand Down
62 changes: 58 additions & 4 deletions docs/tutorials/tflite.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,66 @@
# TFLite Conversion Tutorial
- [TFLite Tutorial](#tflite-tutorial)
- [Conversion](#conversion)
- [Inference](#inference)
- [1. Input](#1-input)
- [2. Output](#2-output)
- [3. Example script](#3-example-script)

## Run

# TFLite Tutorial

## Conversion

```bash
python examples/train.py \
python3 examples/train.py \
--config-path=/path/to/config.yml.j2 \
--h5=/path/to/weight.h5 \
--bs=1 \ # Batch size
--beam-width=0 \ # Beam width, set >0 to enable beam search
--output=/path/to/output.tflite
## See others params
python examples/tflite.py --help
```
```

## Inference

### 1. Input

Input of each tflite depends on the models' parameters and configs.

The `inputs`, `inputs_length` and `previous_tokens` are still the same as bellow for all models.

```python
schemas.PredictInput(
inputs=tf.TensorSpec([batch_size, None], dtype=tf.float32),
inputs_length=tf.TensorSpec([batch_size], dtype=tf.int32),
previous_tokens=tf.TensorSpec.from_tensor(self.get_initial_tokens(batch_size)),
previous_encoder_states=tf.TensorSpec.from_tensor(self.get_initial_encoder_states(batch_size)),
previous_decoder_states=tf.TensorSpec.from_tensor(self.get_initial_decoder_states(batch_size)),
)
```

For models that don't have encoder states or decoder states, the default values are `tf.zeros([], dtype=self.dtype)` tensors for `previous_encoder_states` and `previous_decoder_states`. This is just for tflite conversion because tflite does not allow `None` value in `input_signature`. However, the output `next_encoder_states` and `next_decoder_states` are still `None`, so we can simply ignore those outputs.

### 2. Output

```python
schemas.PredictOutputWithTranscript(
transcript=self.tokenizer.detokenize(outputs.tokens),
tokens=outputs.tokens,
next_tokens=outputs.next_tokens,
next_encoder_states=outputs.next_encoder_states,
next_decoder_states=outputs.next_decoder_states,
)
```

This is for supporting streaming inference.

Each output corresponds to the input = each chunk of audio signal.

Then we can overwrite `previous_tokens`, `previous_encoder_states` and `previous_decoder_states` with `next_tokens`, `next_encoder_states` and `next_decoder_states` for the next chunk of audio signal.

And continue until the end of the audio signal.

### 3. Example script

See [examples/inferences/tflite.py](../../examples/inferences/tflite.py) for more details.
48 changes: 32 additions & 16 deletions examples/inferences/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,41 +13,57 @@
# limitations under the License.

import tensorflow as tf
import tensorflow_text as tft
from tensorflow.lite.python import interpreter

from tensorflow_asr.utils import cli_util, data_util

logger = tf.get_logger()


def main(
file_path: str,
tflite_path: str,
previous_encoder_states_shape: list = None,
previous_decoder_states_shape: list = None,
blank_index: int = 0,
audio_file_path: str,
tflite: str,
sample_rate: int = 16000,
blank: int = 0,
):
tflitemodel = tf.lite.Interpreter(model_path=tflite_path)
signal = data_util.read_raw_audio(file_path)
wav = data_util.load_and_convert_to_wav(audio_file_path, sample_rate=sample_rate)
signal = data_util.read_raw_audio(wav)
signal = tf.reshape(signal, [1, -1])
signal_length = tf.reshape(tf.shape(signal)[1], [1])

tflitemodel = interpreter.InterpreterWithCustomOps(model_path=tflite, custom_op_registerers=tft.tflite_registrar.SELECT_TFTEXT_OPS)
input_details = tflitemodel.get_input_details()
output_details = tflitemodel.get_output_details()
tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)

tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape, strict=True)
tflitemodel.allocate_tensors()
tflitemodel.set_tensor(input_details[0]["index"], signal)
tflitemodel.set_tensor(input_details[1]["index"], signal_length)
tflitemodel.set_tensor(input_details[2]["index"], tf.constant(blank_index, dtype=tf.int32))
if previous_encoder_states_shape:
tflitemodel.set_tensor(input_details[4]["index"], tf.zeros(previous_encoder_states_shape, dtype=tf.float32))
if previous_decoder_states_shape:
tflitemodel.set_tensor(input_details[5]["index"], tf.zeros(previous_decoder_states_shape, dtype=tf.float32))
tflitemodel.set_tensor(input_details[2]["index"], tf.ones(input_details[2]["shape"], dtype=input_details[2]["dtype"]) * blank)
tflitemodel.set_tensor(input_details[3]["index"], tf.zeros(input_details[3]["shape"], dtype=input_details[3]["dtype"]))
tflitemodel.set_tensor(input_details[4]["index"], tf.zeros(input_details[4]["shape"], dtype=input_details[4]["dtype"]))

tflitemodel.invoke()
hyp = tflitemodel.get_tensor(output_details[0]["index"])

transcript = "".join([chr(u) for u in hyp])
transcript = tflitemodel.get_tensor(output_details[0]["index"])
tokens = tflitemodel.get_tensor(output_details[1]["index"])
next_tokens = tflitemodel.get_tensor(output_details[2]["index"])
if len(output_details) > 4:
next_encoder_states = tflitemodel.get_tensor(output_details[3]["index"])
next_decoder_states = tflitemodel.get_tensor(output_details[4]["index"])
elif len(output_details) > 3:
next_encoder_states = None
next_decoder_states = tflitemodel.get_tensor(output_details[3]["index"])
else:
next_encoder_states = None
next_decoder_states = None

logger.info(f"Transcript: {transcript}")
return transcript
logger.info(f"Tokens: {tokens}")
logger.info(f"Next tokens: {next_tokens}")
logger.info(f"Next encoder states: {None if next_encoder_states is None else next_encoder_states.shape}")
logger.info(f"Next decoder states: {None if next_decoder_states is None else next_decoder_states.shape}")


if __name__ == "__main__":
Expand Down
47 changes: 0 additions & 47 deletions examples/models/transducer/conformer/inference/run_tflite_model.py

This file was deleted.

13 changes: 7 additions & 6 deletions examples/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,27 @@

def main(
config_path: str,
h5: str,
output: str,
h5: str = None,
bs: int = 1,
beam_width: int = 0,
repodir: str = os.path.realpath(os.path.join(os.path.dirname(__file__), "..")),
):
assert h5 and output
assert output
tf.keras.backend.clear_session()
env_util.setup_seed()
tf.compat.v1.enable_control_flow_v2()

config = Config(config_path, training=False, repodir=repodir)
tokenizer = tokenizers.get(config)

model: BaseModel = tf.keras.models.model_from_config(config.model_config)
model.tokenizer = tokenizer
model.make()
model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5))
model.make(batch_size=bs)
if h5 and tf.io.gfile.exists(h5):
model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5))
model.summary()

app_util.convert_tflite(model=model, output=output, batch_size=bs)
app_util.convert_tflite(model=model, output=output, batch_size=bs, beam_width=beam_width)


if __name__ == "__main__":
Expand Down
4 changes: 1 addition & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@ sounddevice~=0.4.6
jinja2~=3.1.3
fire~=0.5.0
jiwer~=3.0.3
chardet~=5.1.0
charset-normalizer~=2.1.1

# extra=dev
pytest~=7.4.1
black~=24.3.0
pylint~=3.1.0
pylint~=3.2.1
matplotlib~=3.7.2
pydot~=1.4.2
graphviz~=0.20.1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def parse_requirements(lines: List[str]):

setup(
name="TensorFlowASR",
version="2.0.0",
version="2.0.1",
author="Huy Le Nguyen",
author_email="[email protected]",
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
Expand Down
7 changes: 4 additions & 3 deletions tensorflow_asr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = os.environ.get("TF_FORCE_GPU_ALLOW_GROWTH", "true")

import tensorflow as tf
import keras
from tensorflow.python.util import deprecation # pylint: disable = no-name-in-module

# might cause performance penalty if ops fallback to cpu, see https://cloud.google.com/tpu/docs/tensorflow-ops
Expand Down Expand Up @@ -46,9 +47,9 @@ def match_dtype_and_rank(y_t, y_p, sw):


# monkey patch
tf.keras.layers.Layer.output_shape = output_shape
tf.keras.layers.Layer.build = build
tf.keras.layers.Layer.compute_output_shape = compute_output_shape
keras.layers.Layer.output_shape = output_shape
keras.layers.Layer.build = build
keras.layers.Layer.compute_output_shape = compute_output_shape
compile_utils.match_dtype_and_rank = match_dtype_and_rank

import tensorflow_asr.callbacks
Expand Down
Loading

0 comments on commit a4d411d

Please sign in to comment.