diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2d9f842545..79507f6ee4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,12 +9,3 @@ repos:
         stages: [pre-commit]
         fail_fast: true
         verbose: true
-      - id: pylint-check
-        name: pylint-check
-        entry: pylint --rcfile=.pylintrc -rn -sn
-        language: system
-        types: [python]
-        stages: [pre-commit]
-        fail_fast: true
-        require_serial: true
-        verbose: true
diff --git a/.pylintrc b/.pylintrc
index 9829b0e210..ca5736a5f2 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -216,7 +216,7 @@ contextmanager-decorators=contextlib.contextmanager
 # List of members which are set dynamically and missed by pylint inference
 # system, and so shouldn't trigger E1101 when accessed. Python regular
 # expressions are accepted.
-generated-members=tensorflow.python
+generated-members=tensorflow.python,tensorflow.keras
 
 # Tells whether missing members accessed in mixin class should be ignored. A
 # mixin class is detected if its name ends with "mixin" (case insensitive).
diff --git a/README.md b/README.md
index 793d64bf66..d79673cd2e 100755
--- a/README.md
+++ b/README.md
@@ -159,7 +159,7 @@ See [augmentations](./tensorflow_asr/augmentations/README.md)
 
 ## TFLite Convertion
 
-After converting to tflite, the tflite model is like a function that transforms directly from an **audio signal** to **unicode code points**, then we can convert unicode points to string.
+After converting to tflite, the tflite model is like a function that transforms directly from an **audio signal** to **text and tokens**.
 
 See [tflite_convertion](./docs/tutorials/tflite.md)
 
diff --git a/docs/tutorials/tflite.md b/docs/tutorials/tflite.md
index 1f2d31cc3f..25f29b0d97 100644
--- a/docs/tutorials/tflite.md
+++ b/docs/tutorials/tflite.md
@@ -1,12 +1,66 @@
-# TFLite Conversion Tutorial
+- [TFLite Tutorial](#tflite-tutorial)
+  - [Conversion](#conversion)
+  - [Inference](#inference)
+    - [1. Input](#1-input)
+    - [2. Output](#2-output)
+    - [3. Example script](#3-example-script)
 
-## Run
+
+# TFLite Tutorial
+
+## Conversion
 
 ```bash
-python examples/train.py \
+# --bs sets the batch size; set --beam-width > 0 to enable beam search
+python3 examples/tflite.py \
     --config-path=/path/to/config.yml.j2 \
     --h5=/path/to/weight.h5 \
+    --bs=1 \
+    --beam-width=0 \
     --output=/path/to/output.tflite
 
 ## See others params
 python examples/tflite.py --help
-```
\ No newline at end of file
+```
+
+## Inference
+
+### 1. Input
+
+The input of each tflite model depends on the model's parameters and configs.
+
+The `inputs`, `inputs_length` and `previous_tokens` are the same as below for all models.
+
+```python
+schemas.PredictInput(
+    inputs=tf.TensorSpec([batch_size, None], dtype=tf.float32),
+    inputs_length=tf.TensorSpec([batch_size], dtype=tf.int32),
+    previous_tokens=tf.TensorSpec.from_tensor(self.get_initial_tokens(batch_size)),
+    previous_encoder_states=tf.TensorSpec.from_tensor(self.get_initial_encoder_states(batch_size)),
+    previous_decoder_states=tf.TensorSpec.from_tensor(self.get_initial_decoder_states(batch_size)),
+)
+```
+
+For models that don't have encoder states or decoder states, the default values are `tf.zeros([], dtype=self.dtype)` tensors for `previous_encoder_states` and `previous_decoder_states`. This is just for tflite conversion because tflite does not allow a `None` value in `input_signature`. However, the outputs `next_encoder_states` and `next_decoder_states` are still `None`, so we can simply ignore those outputs.
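+
+For a quick sanity check of what a converted model actually expects, you can print the interpreter's input details. A minimal sketch (the model path is a placeholder; the interpreter setup mirrors the inference example in this repo):
+
+```python
+import tensorflow_text as tft
+from tensorflow.lite.python import interpreter
+
+# Load the tflite model with the TF Text ops registered
+tflitemodel = interpreter.InterpreterWithCustomOps(
+    model_path="/path/to/output.tflite",
+    custom_op_registerers=tft.tflite_registrar.SELECT_TFTEXT_OPS,
+)
+for detail in tflitemodel.get_input_details():
+    print(detail["name"], detail["shape"], detail["dtype"])
+```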
+
+### 2. Output
+
+```python
+schemas.PredictOutputWithTranscript(
+    transcript=self.tokenizer.detokenize(outputs.tokens),
+    tokens=outputs.tokens,
+    next_tokens=outputs.next_tokens,
+    next_encoder_states=outputs.next_encoder_states,
+    next_decoder_states=outputs.next_decoder_states,
+)
+```
+
+This output schema supports streaming inference: each output corresponds to one input chunk of the audio signal.
+
+We then overwrite `previous_tokens`, `previous_encoder_states` and `previous_decoder_states` with `next_tokens`, `next_encoder_states` and `next_decoder_states` for the next chunk of audio signal, and continue until the end of the audio signal, as sketched below.
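+
+A minimal sketch of that loop, assuming the model was converted with `--bs=1` and keeps both encoder and decoder states (5 inputs and 5 outputs); the model path and `signal_chunks` are placeholders:
+
+```python
+import tensorflow as tf
+import tensorflow_text as tft
+from tensorflow.lite.python import interpreter
+
+blank = 0  # blank token index
+tflitemodel = interpreter.InterpreterWithCustomOps(
+    model_path="/path/to/output.tflite",
+    custom_op_registerers=tft.tflite_registrar.SELECT_TFTEXT_OPS,
+)
+input_details = tflitemodel.get_input_details()
+output_details = tflitemodel.get_output_details()
+
+# Initial feedback values: blank tokens and zero states
+tokens = tf.ones(input_details[2]["shape"], dtype=input_details[2]["dtype"]) * blank
+encoder_states = tf.zeros(input_details[3]["shape"], dtype=input_details[3]["dtype"])
+decoder_states = tf.zeros(input_details[4]["shape"], dtype=input_details[4]["dtype"])
+
+for chunk in signal_chunks:  # each chunk has shape [1, chunk_length]
+    tflitemodel.resize_tensor_input(input_details[0]["index"], chunk.shape, strict=True)
+    tflitemodel.allocate_tensors()
+    tflitemodel.set_tensor(input_details[0]["index"], chunk)
+    tflitemodel.set_tensor(input_details[1]["index"], tf.reshape(tf.shape(chunk)[1], [1]))
+    tflitemodel.set_tensor(input_details[2]["index"], tokens)
+    tflitemodel.set_tensor(input_details[3]["index"], encoder_states)
+    tflitemodel.set_tensor(input_details[4]["index"], decoder_states)
+    tflitemodel.invoke()
+    print(tflitemodel.get_tensor(output_details[0]["index"]))  # transcript of this chunk
+    # Feed next_* back in as previous_* for the next chunk
+    tokens = tflitemodel.get_tensor(output_details[2]["index"])
+    encoder_states = tflitemodel.get_tensor(output_details[3]["index"])
+    decoder_states = tflitemodel.get_tensor(output_details[4]["index"])
+```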
+
+### 3. Example script
+
+See [examples/inferences/tflite.py](../../examples/inferences/tflite.py) for more details.
\ No newline at end of file
diff --git a/examples/inferences/tflite.py b/examples/inferences/tflite.py
index 18dd22d3e9..e2dfe5b910 100644
--- a/examples/inferences/tflite.py
+++ b/examples/inferences/tflite.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import tensorflow as tf
+import tensorflow_text as tft
+from tensorflow.lite.python import interpreter
 
 from tensorflow_asr.utils import cli_util, data_util
 
@@ -20,34 +22,48 @@
 
 
 def main(
-    file_path: str,
-    tflite_path: str,
-    previous_encoder_states_shape: list = None,
-    previous_decoder_states_shape: list = None,
-    blank_index: int = 0,
+    audio_file_path: str,
+    tflite: str,
+    sample_rate: int = 16000,
+    blank: int = 0,
 ):
-    tflitemodel = tf.lite.Interpreter(model_path=tflite_path)
-    signal = data_util.read_raw_audio(file_path)
+    wav = data_util.load_and_convert_to_wav(audio_file_path, sample_rate=sample_rate)
+    signal = data_util.read_raw_audio(wav)
     signal = tf.reshape(signal, [1, -1])
     signal_length = tf.reshape(tf.shape(signal)[1], [1])
 
+    tflitemodel = interpreter.InterpreterWithCustomOps(model_path=tflite, custom_op_registerers=tft.tflite_registrar.SELECT_TFTEXT_OPS)
     input_details = tflitemodel.get_input_details()
     output_details = tflitemodel.get_output_details()
-    tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
+
+    tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape, strict=True)
     tflitemodel.allocate_tensors()
     tflitemodel.set_tensor(input_details[0]["index"], signal)
     tflitemodel.set_tensor(input_details[1]["index"], signal_length)
-    tflitemodel.set_tensor(input_details[2]["index"], tf.constant(blank_index, dtype=tf.int32))
-    if previous_encoder_states_shape:
-        tflitemodel.set_tensor(input_details[4]["index"], tf.zeros(previous_encoder_states_shape, dtype=tf.float32))
-    if previous_decoder_states_shape:
-        tflitemodel.set_tensor(input_details[5]["index"], tf.zeros(previous_decoder_states_shape, dtype=tf.float32))
+    tflitemodel.set_tensor(input_details[2]["index"], tf.ones(input_details[2]["shape"], dtype=input_details[2]["dtype"]) * blank)
+    tflitemodel.set_tensor(input_details[3]["index"], tf.zeros(input_details[3]["shape"], dtype=input_details[3]["dtype"]))
+    tflitemodel.set_tensor(input_details[4]["index"], tf.zeros(input_details[4]["shape"], dtype=input_details[4]["dtype"]))
+
     tflitemodel.invoke()
-    hyp = tflitemodel.get_tensor(output_details[0]["index"])
 
-    transcript = "".join([chr(u) for u in hyp])
+    transcript = tflitemodel.get_tensor(output_details[0]["index"])
+    tokens = tflitemodel.get_tensor(output_details[1]["index"])
+    next_tokens = tflitemodel.get_tensor(output_details[2]["index"])
+    if len(output_details) > 4:
+        next_encoder_states = tflitemodel.get_tensor(output_details[3]["index"])
+        next_decoder_states = tflitemodel.get_tensor(output_details[4]["index"])
+    elif len(output_details) > 3:
+        next_encoder_states = None
+        next_decoder_states = tflitemodel.get_tensor(output_details[3]["index"])
+    else:
+        next_encoder_states = None
+        next_decoder_states = None
+
     logger.info(f"Transcript: {transcript}")
-    return transcript
+    logger.info(f"Tokens: {tokens}")
+    logger.info(f"Next tokens: {next_tokens}")
+    logger.info(f"Next encoder states: {None if next_encoder_states is None else next_encoder_states.shape}")
+    logger.info(f"Next decoder states: {None if next_decoder_states is None else next_decoder_states.shape}")
 
 
 if __name__ == "__main__":
diff --git a/examples/models/transducer/conformer/inference/run_tflite_model.py b/examples/models/transducer/conformer/inference/run_tflite_model.py
deleted file mode 100644
index b4f6aa5cfb..0000000000
--- a/examples/models/transducer/conformer/inference/run_tflite_model.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# Copyright 2020 Huy Le Nguyen (@nglehuy)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import fire
-import tensorflow as tf
-
-from tensorflow_asr.features.speech_featurizers import read_raw_audio
-
-
-def main(
-    filename: str,
-    tflite: str = None,
-    blank: int = 0,
-    num_rnns: int = 1,
-    nstates: int = 2,
-    statesize: int = 320,
-):
-    tflitemodel = tf.lite.Interpreter(model_path=tflite)
-
-    signal = read_raw_audio(filename)
-
-    input_details = tflitemodel.get_input_details()
-    output_details = tflitemodel.get_output_details()
-    tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
-    tflitemodel.allocate_tensors()
-    tflitemodel.set_tensor(input_details[0]["index"], signal)
-    tflitemodel.set_tensor(input_details[1]["index"], tf.constant(blank, dtype=tf.int32))
-    tflitemodel.set_tensor(input_details[2]["index"], tf.zeros([num_rnns, nstates, 1, statesize], dtype=tf.float32))
-    tflitemodel.invoke()
-    hyp = tflitemodel.get_tensor(output_details[0]["index"])
-
-    print("".join([chr(u) for u in hyp]))
-
-
-if __name__ == "__main__":
-    fire.Fire(main)
diff --git a/examples/tflite.py b/examples/tflite.py
index a22c354e1a..f2d46c56ff 100644
--- a/examples/tflite.py
+++ b/examples/tflite.py
@@ -23,26 +23,27 @@
 
 def main(
     config_path: str,
-    h5: str,
     output: str,
+    h5: str = None,
     bs: int = 1,
+    beam_width: int = 0,
     repodir: str = os.path.realpath(os.path.join(os.path.dirname(__file__), "..")),
 ):
-    assert h5 and output
+    assert output
     tf.keras.backend.clear_session()
     env_util.setup_seed()
-    tf.compat.v1.enable_control_flow_v2()
 
     config = Config(config_path, training=False, repodir=repodir)
     tokenizer = tokenizers.get(config)
 
     model: BaseModel = tf.keras.models.model_from_config(config.model_config)
     model.tokenizer = tokenizer
-    model.make()
-    model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5))
+    model.make(batch_size=bs)
+    if h5 and tf.io.gfile.exists(h5):
+        model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5))
     model.summary()
 
-    app_util.convert_tflite(model=model, output=output, batch_size=bs)
+    app_util.convert_tflite(model=model, output=output, batch_size=bs, beam_width=beam_width)
 
 
 if __name__ == "__main__":
diff --git a/requirements.txt b/requirements.txt
index f494d28f2f..eb9076cf47 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,13 +8,11 @@ sounddevice~=0.4.6
 jinja2~=3.1.3
 fire~=0.5.0
 jiwer~=3.0.3
-chardet~=5.1.0
-charset-normalizer~=2.1.1
 
 # extra=dev
 pytest~=7.4.1
 black~=24.3.0
-pylint~=3.1.0
+pylint~=3.2.1
 matplotlib~=3.7.2
 pydot~=1.4.2
 graphviz~=0.20.1
diff --git a/setup.py b/setup.py
index 3eef87d682..9a6949bc9b 100644
--- a/setup.py
+++ b/setup.py
@@ -42,7 +42,7 @@ def parse_requirements(lines: List[str]):
 
 setup(
     name="TensorFlowASR",
-    version="2.0.0",
+    version="2.0.1",
     author="Huy Le Nguyen",
     author_email="nlhuy.cs.16@gmail.com",
     description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
diff --git a/tensorflow_asr/__init__.py b/tensorflow_asr/__init__.py
index d613bf9a22..5a23e7f535 100644
--- a/tensorflow_asr/__init__.py
+++ b/tensorflow_asr/__init__.py
@@ -7,6 +7,7 @@
 os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = os.environ.get("TF_FORCE_GPU_ALLOW_GROWTH", "true")
 
 import tensorflow as tf
+import keras
 from tensorflow.python.util import deprecation  # pylint: disable = no-name-in-module
 
 # might cause performance penalty if ops fallback to cpu, see https://cloud.google.com/tpu/docs/tensorflow-ops
@@ -46,9 +47,9 @@ def match_dtype_and_rank(y_t, y_p, sw):
 
 
 # monkey patch
-tf.keras.layers.Layer.output_shape = output_shape
-tf.keras.layers.Layer.build = build
-tf.keras.layers.Layer.compute_output_shape = compute_output_shape
+keras.layers.Layer.output_shape = output_shape
+keras.layers.Layer.build = build
+keras.layers.Layer.compute_output_shape = compute_output_shape
 compile_utils.match_dtype_and_rank = match_dtype_and_rank
 
 import tensorflow_asr.callbacks
diff --git a/tensorflow_asr/models/base_model.py b/tensorflow_asr/models/base_model.py
index 664e4f9825..731fc68d02 100644
--- a/tensorflow_asr/models/base_model.py
+++ b/tensorflow_asr/models/base_model.py
@@ -129,6 +129,7 @@ def make(self, input_shape=[None], prediction_shape=[None], batch_size=None, cac
         predictions = tf.keras.Input(shape=prediction_shape, batch_size=batch_size, dtype=tf.int32)
         predictions_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32)
         self._per_replica_batch_size = int(batch_size / self.distribute_strategy.num_replicas_in_sync)
+        self._batch_size = batch_size
         outputs = self(
             schemas.TrainInput(
                 inputs=signals,
@@ -277,7 +278,13 @@ def test_step(self, data):
 
     def predict_step(self, data):
         x, y_true = data
-        inputs = schemas.PredictInput(x["inputs"], x["inputs_length"])
+        inputs = schemas.PredictInput(
+            inputs=x["inputs"],
+            inputs_length=x["inputs_length"],
+            previous_tokens=self.get_initial_tokens(),
+            previous_encoder_states=self.get_initial_encoder_states(),
+            previous_decoder_states=self.get_initial_decoder_states(),
+        )
         _tokens = self.recognize(inputs=inputs).tokens
         _beam_tokens = self.recognize_beam(inputs=inputs).tokens
         return {
@@ -619,41 +626,52 @@ def fit(
 
     # -------------------------------- INFERENCE FUNCTIONS -------------------------------------
 
+    def get_initial_tokens(self, batch_size=1):
+        return tf.ones([batch_size, 1], dtype=tf.int32) * self.blank
+
+    def get_initial_encoder_states(self, batch_size=1):
+        return tf.zeros([], dtype=self.dtype)
+
+    def get_initial_decoder_states(self, batch_size=1):
+        return tf.zeros([], dtype=self.dtype)
+
     def recognize(self, inputs: schemas.PredictInput, **kwargs) -> schemas.PredictOutput:
         """Greedy decoding function that used in self.predict_step"""
         raise NotImplementedError()
 
-    def recognize_beam(self, inputs: schemas.PredictInput, **kwargs) -> schemas.PredictOutput:
+    def recognize_beam(self, inputs: schemas.PredictInput, beam_width: int = 10, **kwargs) -> schemas.PredictOutput:
         """Beam search decoding function that used in self.predict_step"""
         raise NotImplementedError()
 
     # ---------------------------------- TFLITE ----------------------------------
 
-    def make_tflite_function(self, batch_size=1):
-        @tf.function(
-            input_signature=[
-                tf.TensorSpec([batch_size, None], dtype=tf.float32),
-                tf.TensorSpec([batch_size], dtype=tf.int32),
-                tf.TensorSpec([batch_size, 1], dtype=tf.int32),
-                tf.TensorSpec(self.encoder.get_initial_state(batch_size), dtype=tf.float32) if hasattr(self.encoder, "get_initial_state") else None,
-                tf.TensorSpec(self.predict_net.get_initial_state(batch_size).get_shape(), dtype=tf.float32),
-            ],
-        )
-        def tflite_func(inputs, inputs_length, previous_tokens, previous_encoder_states, previous_decoder_states):
-            outputs = self.recognize(
-                schemas.PredictInput(
-                    inputs=inputs,
-                    inputs_length=inputs_length,
-                    previous_tokens=previous_tokens,
-                    previous_encoder_states=previous_encoder_states,
-                    previous_decoder_states=previous_decoder_states,
-                )
-            )
-            return schemas.PredictOutput(
-                tokens=self.tokenizer.detokenize_unicode_points(outputs.tokens),
-                scores=outputs.scores,
-                encoder_states=outputs.encoder_states,
-                decoder_states=outputs.decoder_states,
+    def make_tflite_function(self, batch_size: int = 1, beam_width: int = 0):
+
+        def tflite_func(inputs: schemas.PredictInput):
+            if beam_width > 0:
+                outputs = self.recognize_beam(inputs, beam_width=beam_width)
+            else:
+                outputs = self.recognize(inputs)
+            return schemas.PredictOutputWithTranscript(
+                transcript=self.tokenizer.detokenize(outputs.tokens),
+                tokens=outputs.tokens,
+                next_tokens=outputs.next_tokens,
+                next_encoder_states=outputs.next_encoder_states,
+                next_decoder_states=outputs.next_decoder_states,
             )
 
-        return tflite_func
+        input_signature = schemas.PredictInput(
+            inputs=tf.TensorSpec([batch_size, None], dtype=tf.float32),
+            inputs_length=tf.TensorSpec([batch_size], dtype=tf.int32),
+            previous_tokens=tf.TensorSpec.from_tensor(self.get_initial_tokens(batch_size)),
+            previous_encoder_states=tf.TensorSpec.from_tensor(self.get_initial_encoder_states(batch_size)),
+            previous_decoder_states=tf.TensorSpec.from_tensor(self.get_initial_decoder_states(batch_size)),
+        )
+
+        return tf.function(
+            tflite_func,
+            input_signature=[input_signature],
+            jit_compile=True,
+            reduce_retracing=True,
+            autograph=True,
+        )
diff --git a/tensorflow_asr/models/ctc/base_ctc.py b/tensorflow_asr/models/ctc/base_ctc.py
index 219ac564ae..cd48572d83 100644
--- a/tensorflow_asr/models/ctc/base_ctc.py
+++ b/tensorflow_asr/models/ctc/base_ctc.py
@@ -94,6 +94,15 @@ def call_next(
         outputs, outputs_length, next_decoder_states = self.decoder.call_next(outputs, outputs_length, previous_decoder_states)
         return outputs, outputs_length, next_encoder_states, next_decoder_states
 
+    def get_initial_tokens(self, batch_size=1):
+        return super().get_initial_tokens(batch_size)
+
+    def get_initial_encoder_states(self, batch_size=1):
+        return tf.zeros([], dtype=self.dtype)
+
+    def get_initial_decoder_states(self, batch_size=1):
+        return tf.zeros([], dtype=self.dtype)
+
     # -------------------------------- GREEDY -------------------------------------
 
     def recognize(self, inputs: schemas.PredictInput, **kwargs):
diff --git a/tensorflow_asr/models/ctc/deepspeech2.py b/tensorflow_asr/models/ctc/deepspeech2.py
index f9dff373db..7fbf62afd8 100644
--- a/tensorflow_asr/models/ctc/deepspeech2.py
+++ b/tensorflow_asr/models/ctc/deepspeech2.py
@@ -131,3 +131,9 @@ def __init__(
             **kwargs,
         )
         self.time_reduction_factor = self.encoder.time_reduction_factor
+
+    def get_initial_encoder_states(self, batch_size=1):
+        return self.encoder.get_initial_state(batch_size)
+
+    def get_initial_decoder_states(self, batch_size=1):
+        return tf.zeros([], dtype=self.dtype)
diff --git a/tensorflow_asr/models/encoders/deepspeech2.py b/tensorflow_asr/models/encoders/deepspeech2.py
index 2c50184c04..f3f81f14c7 100644
--- a/tensorflow_asr/models/encoders/deepspeech2.py
+++ b/tensorflow_asr/models/encoders/deepspeech2.py
@@ -224,6 +224,7 @@ def __init__(
             dropout=dropout,
             unroll=unroll,
             return_sequences=True,
+            return_state=True,
             use_bias=True,
             name=rnn_type,
             zero_output_for_mask=True,
@@ -233,6 +234,7 @@ def __init__(
             bias_initializer=initializer,
             dtype=self.dtype,
         )
+        self._bidirectional = bidirectional
         if bidirectional:
             self.rnn = tf.keras.layers.Bidirectional(self.rnn, name=f"b{rnn_type}", dtype=self.dtype)
         self.bn = tf.keras.layers.BatchNormalization(
@@ -249,17 +251,34 @@ def __init__(
             dtype=self.dtype,
         )
 
+    def get_initial_state(self, batch_size: int):
+        if self._bidirectional:
+            states = self.rnn.forward_layer.get_initial_state(tf.zeros([batch_size, 1, 1], dtype=self.dtype))
+            states += self.rnn.backward_layer.get_initial_state(tf.zeros([batch_size, 1, 1], dtype=self.dtype))
+        else:
+            states = self.rnn.get_initial_state(tf.zeros([batch_size, 1, 1], dtype=self.dtype))
+        return states
+
     def call(self, inputs, training=False):
         outputs, outputs_length = inputs
-        outputs = self.rnn(outputs, training=training)  # mask auto populate
+        outputs, *_ = self.rnn(outputs, training=training)  # mask auto populate
         outputs = self.bn(outputs, training=training)
         if self.rowconv is not None:
             outputs = self.rowconv(outputs, training=training)
         return outputs, outputs_length
 
+    def call_next(self, inputs, previous_encoder_states):
+        with tf.name_scope(f"{self.name}_call_next"):
+            outputs, outputs_length = inputs
+            outputs, *_states = self.rnn(outputs, training=False, initial_state=tf.unstack(previous_encoder_states, axis=0))
+            outputs = self.bn(outputs, training=False)
+            if self.rowconv is not None:
+                outputs = self.rowconv(outputs, training=False)
+            return outputs, outputs_length, tf.stack(_states)
+
     def compute_output_shape(self, input_shape):
         output_shape, output_length_shape = input_shape
-        output_shape = self.rnn.compute_output_shape(output_shape)
+        output_shape, *_ = self.rnn.compute_output_shape(output_shape)
         output_shape = self.bn.compute_output_shape(output_shape)
         if self.rowconv is not None:
             output_shape = self.rowconv.compute_output_shape(output_shape)
@@ -301,12 +320,35 @@ def __init__(
             for i in range(nlayers)
         ]
 
+    def get_initial_state(self, batch_size: int):
+        """
+        Get zero-initialized states
+
+        Returns
+        -------
+        tf.Tensor, shape [B, num_rnns, nstates, state_size]
+            Zero initialized states
+        """
+        states = []
+        for block in self.blocks:
+            states.append(tf.stack(block.get_initial_state(batch_size=batch_size), axis=0))
+        return tf.transpose(tf.stack(states, axis=0), perm=[2, 0, 1, 3])
+
     def call(self, inputs, training=False):
         outputs = inputs
         for block in self.blocks:
             outputs = block(outputs, training=training)
         return outputs
 
+    def call_next(self, inputs, previous_encoder_states):
+        outputs = inputs
+        previous_encoder_states = tf.transpose(previous_encoder_states, perm=[1, 2, 0, 3])
+        new_states = []
+        for i, block in enumerate(self.blocks):
+            *outputs, _states = block.call_next(outputs, previous_encoder_states=previous_encoder_states[i])
+            new_states.append(_states)
+        return outputs, tf.transpose(tf.stack(new_states, axis=0), perm=[2, 0, 1, 3])
+
     def compute_output_shape(self, input_shape):
         output_shape = input_shape
         for block in self.blocks:
@@ -471,6 +513,17 @@ def __init__(
         )
         self.time_reduction_factor = self.conv_module.time_reduction_factor
 
+    def get_initial_state(self, batch_size: int):
+        """
+        Get zero-initialized states
+
+        Returns
+        -------
+        tf.Tensor, shape [B, num_rnns, nstates, state_size]
+            Zero initialized states
+        """
+        return self.rnn_module.get_initial_state(batch_size=batch_size)
+
     def call(self, inputs, training=False):
         *outputs, caching = inputs
         outputs = self.conv_module(outputs, training=training)
@@ -478,6 +531,27 @@ def call(self, inputs, training=False):
         outputs = self.rnn_module(outputs, training=training)
         outputs = self.fc_module(outputs, training=training)
         return *outputs, caching
 
+    def call_next(self, features, features_length, previous_encoder_states, *args, **kwargs):
+        """
+        Recognize function for encoder network from previous encoder states
+
+        Parameters
+        ----------
+        features : tf.Tensor, shape [B, T, F, C]
+        features_length : tf.Tensor, shape [B]
+        previous_encoder_states : tf.Tensor, shape [B, nlayers, nstates, rnn_units] -> [nlayers, nstates, B, rnn_units]
+
+        Returns
+        -------
+        Tuple[tf.Tensor, tf.Tensor, tf.Tensor], shape ([B, T, dmodel], [B], [nlayers, nstates, B, rnn_units] -> [B, nlayers, nstates, rnn_units])
+        """
+        with tf.name_scope(f"{self.name}_call_next"):
+            outputs = (features, features_length)
+            outputs = self.conv_module(outputs, training=False)
+            outputs, new_encoder_states = self.rnn_module.call_next(outputs, previous_encoder_states=previous_encoder_states)
+            outputs, outputs_length = self.fc_module(outputs, training=False)
+            return outputs, outputs_length, new_encoder_states
+
     def compute_mask(self, inputs, mask=None):
         *outputs, caching = inputs
         return *self.conv_module.compute_mask(outputs, mask=mask), getattr(caching, "_keras_mask", None)
diff --git a/tensorflow_asr/models/transducer/base_transducer.py b/tensorflow_asr/models/transducer/base_transducer.py
index e632ac779c..b1eb075e9e 100644
--- a/tensorflow_asr/models/transducer/base_transducer.py
+++ b/tensorflow_asr/models/transducer/base_transducer.py
@@ -444,6 +444,15 @@ def call_next(
         ytu = tf.nn.log_softmax(ytu)
         return ytu, new_states
 
+    def get_initial_tokens(self, batch_size=1):
+        return super().get_initial_tokens(batch_size)
+
+    def get_initial_encoder_states(self, batch_size=1):
+        return tf.zeros([], dtype=self.dtype)
+
+    def get_initial_decoder_states(self, batch_size=1):
+        return self.predict_net.get_initial_state(batch_size)
+
     # -------------------------------- GREEDY -------------------------------------
 
     def recognize(self, inputs: schemas.PredictInput, max_tokens_per_frame: int = 3, **kwargs):
@@ -464,11 +473,9 @@ def recognize(self, inputs: schemas.PredictInput, max_tokens_per_frame: int = 3,
             next_decoder_states, next states of predict_net, will be used to predict next chunk of audio,
         )
         """
-        return tf.cond(
-            tf.equal(tf.shape(inputs.inputs_length)[0], 1),
-            lambda: self.recognize_single(inputs, max_tokens_per_frame=max_tokens_per_frame, **kwargs),
-            lambda: self.recognize_batch(inputs, **kwargs),
-        )
+        if self._batch_size == 1:
+            return self.recognize_single(inputs, max_tokens_per_frame=max_tokens_per_frame, **kwargs)
+        return self.recognize_batch(inputs, **kwargs)
     def recognize_batch(self, inputs: schemas.PredictInput, **kwargs):
         """
@@ -485,9 +492,9 @@
         # The current indices of the output of encoder, shape [B, 1]
         frame_indices = tf.zeros([batch_size, 1], dtype=tf.int32, name="frame_indices")
         # Previous predicted tokens, initially are blanks, shape [B, 1]
-        previous_tokens = inputs.previous_tokens or tf.ones([batch_size, 1], dtype=tf.int32, name="previous_tokens") * self.blank
+        previous_tokens = inputs.previous_tokens
         # Previous states of the prediction network, initially are zeros, shape [B, num_rnns, nstates, rnn_units]
-        previous_decoder_states = inputs.previous_decoder_states or self.predict_net.get_initial_state(batch_size)
+        previous_decoder_states = inputs.previous_decoder_states
         # Assumption that number of tokens can not exceed (2 * the size of output of encoder + 1), this is for static runs like TPU or TFLite
         max_tokens = max_frames * 2 + 1
         # All of the tokens that are getting recognized, initially are blanks, shape [B, nframes * 2 + 1]
@@ -564,7 +571,7 @@ def recognize_single(self, inputs: schemas.PredictInput, max_tokens_per_frame: i
         frame = tf.zeros([1, 1], dtype=tf.int32)
         nframes = encoded_length
 
-        previous_tokens = inputs.previous_tokens or tf.ones([1, 1], dtype=tf.int32) * self.blank
+        previous_tokens = inputs.previous_tokens
         token_index = tf.ones([], dtype=tf.int32) * -1
         tokens = tf.TensorArray(
             dtype=tf.int32,
@@ -581,7 +588,7 @@ def recognize_single(self, inputs: schemas.PredictInput, max_tokens_per_frame: i
             element_shape=tf.TensorShape([]),
         )
 
-        previous_decoder_states = inputs.previous_decoder_states or self.predict_net.get_initial_state(1)
+        previous_decoder_states = inputs.previous_decoder_states
 
         def condition(
             _frame,
@@ -815,8 +822,8 @@ def body(
 
     # -------------------------------- BEAM SEARCH -------------------------------------
 
-    def recognize_beam(self, inputs: schemas.PredictInput, **kwargs):
-        return self.recognize(inputs=inputs, **kwargs)
+    def recognize_beam(self, inputs: schemas.PredictInput, beam_width: int = 10, **kwargs):
+        return self.recognize(inputs=inputs, **kwargs)  # TODO: Implement beam search
 
     # def _perform_beam_search_batch(
     #     self,
diff --git a/tensorflow_asr/models/transducer/rnnt.py b/tensorflow_asr/models/transducer/rnnt.py
index 3af7d06329..248d83a537 100644
--- a/tensorflow_asr/models/transducer/rnnt.py
+++ b/tensorflow_asr/models/transducer/rnnt.py
@@ -98,81 +98,5 @@ def __init__(
         self.time_reduction_factor = self.encoder.time_reduction_factor
         self.dmodel = encoder_dmodel
 
-    # def encoder_inference(self, features: tf.Tensor, states: tf.Tensor):
-    #     """Infer function for encoder (or encoders)
-
-    #     Args:
-    #         features (tf.Tensor): features with shape [T, F, C]
-    #         states (tf.Tensor): previous states of encoders with shape [num_rnns, 1 or 2, 1, P]
-
-    #     Returns:
-    #         tf.Tensor: output of encoders with shape [T, E]
-    #         tf.Tensor: states of encoders with shape [num_rnns, 1 or 2, 1, P]
-    #     """
-    #     with tf.name_scope("encoder"):
-    #         outputs = tf.expand_dims(features, axis=0)
-    #         outputs, new_states = self.encoder.recognize(outputs, states)
-    #         return tf.squeeze(outputs, axis=0), new_states
-
-    # # -------------------------------- GREEDY -------------------------------------
-
-    # def recognize_tflite(self, signal, predicted, encoder_states, prediction_states):
-    #     """
-    #     Function to convert to tflite using greedy decoding (default streaming mode)
-    #     Args:
-    #         signal: tf.Tensor with shape [None] indicating a single audio signal
-    #         predicted: last predicted character with shape []
-    #         encoder_states: lastest encoder states with shape [num_rnns, 1 or 2, 1, P]
-    #         prediction_states: lastest prediction states with shape [num_rnns, 1 or 2, 1, P]
-
-    #     Return:
-    #         transcript: tf.Tensor of Unicode Code Points with shape [None] and dtype tf.int32
-    #         predicted: last predicted character with shape []
-    #         encoder_states: lastest encoder states with shape [num_rnns, 1 or 2, 1, P]
-    #         prediction_states: lastest prediction states with shape [num_rnns, 1 or 2, 1, P]
-    #     """
-    #     features = self.speech_featurizer.tf_extract(signal)
-    #     encoded, new_encoder_states = self.encoder_inference(features, encoder_states)
-    #     hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states)
-    #     transcript = self.text_featurizer.detokenize_unicode_points(hypothesis.prediction)
-    #     return transcript, hypothesis.index, new_encoder_states, hypothesis.states
-
-    # def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, prediction_states):
-    #     features = self.speech_featurizer.tf_extract(signal)
-    #     encoded, new_encoder_states = self.encoder_inference(features, encoder_states)
-    #     hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states)
-    #     indices = self.text_featurizer.normalize_indices(hypothesis.prediction)
-    #     upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1))  # [None, max_subword_length]
-
-    #     num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32)
-    #     total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step
-
-    #     stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
-    #     stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)
-
-    #     etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
-    #     etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)
-
-    #     non_blank = tf.where(tf.not_equal(upoints, 0))
-    #     non_blank_transcript = tf.gather_nd(upoints, non_blank)
-    #     non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
-    #     non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
-
-    #     return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, new_encoder_states, hypothesis.states
-
-    # -------------------------------- TFLITE -------------------------------------
-
-    # def make_tflite_function(
-    #     self,
-    #     timestamp: bool = True,
-    # ):
-    #     tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite
-    #     return tf.function(
-    #         tflite_func,
-    #         input_signature=[
-    #             tf.TensorSpec([None], dtype=tf.float32),
-    #             tf.TensorSpec([], dtype=tf.int32),
-    #             tf.TensorSpec(self.encoder.get_initial_state().get_shape(), dtype=tf.float32),
-    #             tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), dtype=tf.float32),
-    #         ],
-    #     )
+    def get_initial_encoder_states(self, batch_size=1):
+        return self.encoder.get_initial_state(batch_size)
diff --git a/tensorflow_asr/schemas.py b/tensorflow_asr/schemas.py
index b3816bde9b..4504e3ae57 100644
--- a/tensorflow_asr/schemas.py
+++ b/tensorflow_asr/schemas.py
@@ -54,3 +54,8 @@ def TrainLabel(labels, labels_length):
     ("tokens", "next_tokens", "next_encoder_states", "next_decoder_states"),
     defaults=(None, None),
 )
+PredictOutputWithTranscript = collections.namedtuple(
+    "PredictOutputWithTranscript",
"tokens", "next_tokens", "next_encoder_states", "next_decoder_states"), + defaults=(None, None), +) diff --git a/tensorflow_asr/tokenizers.py b/tensorflow_asr/tokenizers.py index b2097ea48b..a807703647 100755 --- a/tensorflow_asr/tokenizers.py +++ b/tensorflow_asr/tokenizers.py @@ -114,7 +114,7 @@ def corpus_generator(cls, decoder_config: DecoderConfig): temp_lines = f.read().splitlines() for line in temp_lines[1:]: # Skip the header of tsv file data = line.split("\t", 2)[-1] # get only transcript - data = cls.normalize_text(data, decoder_config.normalization_form).numpy() + data = cls.normalize_text(data, decoder_config).numpy() yield data @property @@ -135,9 +135,12 @@ def reset_length(self): self.max_length = 0 @classmethod - def normalize_text(cls, text: tf.Tensor, normalization_form: str = "NFKC"): - text = tft.normalize_utf8(text, normalization_form) + def normalize_text(cls, text: tf.Tensor, decoder_config: DecoderConfig): + text = tf.strings.regex_replace(text, b"\xe2\x81\x87".decode("utf-8"), "") + text = tft.normalize_utf8(text, decoder_config.normalization_form) text = tf.strings.regex_replace(text, r"\p{Cc}|\p{Cf}", " ") + text = tf.strings.regex_replace(text, decoder_config.unknown_token, "") + text = tf.strings.regex_replace(text, decoder_config.pad_token, "") text = tf.strings.regex_replace(text, r" +", " ") text = tf.strings.lower(text, encoding="utf-8") text = tf.strings.strip(text) # remove trailing whitespace @@ -159,7 +162,7 @@ def normalize_indices(self, indices: tf.Tensor) -> tf.Tensor: with tf.name_scope("normalize_indices"): minus_one = -1 * tf.ones_like(indices, dtype=tf.int32) blank_like = self.blank * tf.ones_like(indices, dtype=tf.int32) - return tf.where(indices == minus_one, blank_like, indices) + return tf.where(tf.equal(indices, minus_one), blank_like, indices) def prepand_blank(self, text: tf.Tensor) -> tf.Tensor: """Prepand blank index for transducer models""" @@ -230,7 +233,7 @@ def write_vocab_file(filepath, vocab): return cls(decoder_config) def tokenize(self, text): - text = self.normalize_text(text, self.decoder_config.normalization_form) + text = self.normalize_text(text, self.decoder_config) text = tf.strings.unicode_split(text, "UTF-8") return self.tokenizer.lookup(text) @@ -244,9 +247,10 @@ def detokenize(self, indices: tf.Tensor) -> tf.Tensor: transcripts: tf.Tensor of dtype tf.string with dim [B] """ indices = self.normalize_indices(indices) - indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.blank)) + # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.blank)) tokens = self.detokenizer.lookup(indices) tokens = tf.strings.reduce_join(tokens, axis=-1) + tokens = self.normalize_text(tokens, self.decoder_config) return tokens @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)]) @@ -307,7 +311,7 @@ def build_from_corpus(cls, decoder_config: DecoderConfig): return cls(decoder_config) def tokenize(self, text: tf.Tensor) -> tf.Tensor: - text = self.normalize_text(text, self.decoder_config.normalization_form) + text = self.normalize_text(text, self.decoder_config) indices = self.tokenizer.tokenize(text) indices = tf.cast(indices, tf.int32) return indices @@ -321,12 +325,13 @@ def detokenize(self, indices: tf.Tensor) -> tf.Tensor: Returns: transcripts: tf.Tensor of dtype tf.string with dim [B] """ - indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.blank)) - indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.unknown_index)) - indices = 
-        indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.bos_index))
-        indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.eos_index))
+        # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.blank))
+        # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.unknown_index))
+        # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.bos_index))
+        # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.eos_index))
         transcripts = self.tokenizer.detokenize(indices)
-        transcripts = tf.strings.regex_replace(transcripts, r" +", " ")
+        transcripts = self.normalize_text(transcripts, self.decoder_config)
+        # transcripts = tf.strings.regex_replace(transcripts, r" +", " ")
         return transcripts
 
     @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)])
@@ -399,7 +404,7 @@ def write_vocab_file(filepath, vocab):
         return cls(decoder_config)
 
     def tokenize(self, text: tf.Tensor) -> tf.Tensor:
-        text = self.normalize_text(text, self.decoder_config.normalization_form)
+        text = self.normalize_text(text, self.decoder_config)
         if self.decoder_config.keep_whitespace:
             text = tf.strings.regex_replace(text, " ", "| |")
         text = tf.strings.split(text, sep="|")
@@ -417,10 +422,11 @@ def detokenize(self, indices: tf.Tensor) -> tf.Tensor:
         Returns:
             transcripts: tf.Tensor of dtype tf.string with dim [B]
         """
-        indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.blank))
-        indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.unknown_index))
+        # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.blank))
+        # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.unknown_index))
         transcripts = self.tokenizer.detokenize(indices)
-        transcripts = tf.strings.regex_replace(transcripts, r" +", " ")
+        transcripts = self.normalize_text(transcripts, self.decoder_config)
+        # transcripts = tf.strings.regex_replace(transcripts, r" +", " ")
         return transcripts
 
     @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)])
diff --git a/tensorflow_asr/utils/app_util.py b/tensorflow_asr/utils/app_util.py
index de98d1eed8..5559e2c6ac 100644
--- a/tensorflow_asr/utils/app_util.py
+++ b/tensorflow_asr/utils/app_util.py
@@ -15,11 +15,12 @@
 
 import jiwer
 import tensorflow as tf
+import tensorflow_text as tf_text
 from tqdm import tqdm
 
 from tensorflow_asr.models.base_model import BaseModel
 from tensorflow_asr.tokenizers import Tokenizer
-from tensorflow_asr.utils import file_util
+from tensorflow_asr.utils import file_util, math_util
 
 logger = tf.get_logger()
 
@@ -83,13 +84,24 @@ def convert_tflite(
     model: BaseModel,
     output: str,
     batch_size: int = 1,
+    beam_width: int = 0,
 ):
-    concrete_func = model.make_tflite_function(batch_size=batch_size).get_concrete_function()
-    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    if not math_util.is_power_of_two(model.feature_extraction.nfft):
+        logger.error("NFFT must be a power of 2 for TFLite conversion")
+        overwrite_nfft = input("Do you want to overwrite nfft to the nearest power of 2? (y/n): ")
(y/n): ") + if overwrite_nfft.lower() == "y": + model.feature_extraction.nfft = math_util.next_power_of_two(model.feature_extraction.nfft) + logger.info(f"Overwritten nfft to {model.feature_extraction.nfft}") + else: + raise ValueError("NFFT must be power of 2 for TFLite conversion") + + concrete_func = model.make_tflite_function(batch_size=batch_size, beam_width=beam_width).get_concrete_function() + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func], trackable_obj=model) converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. tf.lite.OpsSet.SELECT_TF_OPS, # enable TensorFlow ops. ] + converter.allow_custom_ops = True tflite_model = converter.convert() output = file_util.preprocess_paths(output) diff --git a/tensorflow_asr/utils/math_util.py b/tensorflow_asr/utils/math_util.py index 65d912933d..ead5a13e18 100644 --- a/tensorflow_asr/utils/math_util.py +++ b/tensorflow_asr/utils/math_util.py @@ -200,7 +200,7 @@ def masked_fill( value=0, ): shape = shape_util.shape_list(tensor) - mask = tf.broadcast_to(mask, shape) + mask = tf.cast(tf.broadcast_to(mask, shape), dtype=tf.bool) values = tf.cast(tf.fill(shape, value), tensor.dtype) return tf.where(mask, tensor, values) @@ -278,3 +278,15 @@ def compute_time_length( with tf.name_scope("compute_time_length"): batch_size, time_length, *_ = shape_util.shape_list(tensor) return tf.cast(tf.repeat(time_length, batch_size, axis=0), dtype=dtype) + + +def is_power_of_two( + x: int, +): + return x != 0 and (x & (x - 1)) == 0 + + +def next_power_of_two( + x: int, +): + return 1 if x == 0 else 2 ** math.ceil(math.log2(x))