asr speech recognition example (#2047)
* asr speech recognition example

* added details in README according to maaquib's feedback

* Apply suggestions from code review

---------

Co-authored-by: Aaqib <[email protected]>
Co-authored-by: Ankith Gunapal <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
5 people authored Aug 31, 2023
1 parent 656a30d commit 242895c
Showing 11 changed files with 196 additions and 0 deletions.
12 changes: 12 additions & 0 deletions examples/asr_rnnt_emformer/00_save_jit_model.sh
@@ -0,0 +1,12 @@
#!/bin/bash

CONTAINER=pytorch/torchserve:0.8.2-cpu

docker run --rm \
    -v "$PWD":/home/model-server \
    --entrypoint /bin/bash \
    --workdir /home/model-server \
    $CONTAINER \
    -c "pip install -r requirements.txt && python save_jit_model.py"
29 changes: 29 additions & 0 deletions examples/asr_rnnt_emformer/01_create_model_archive.sh
@@ -0,0 +1,29 @@
#!/bin/bash

CONTAINER=pytorch/torchserve:0.8.2-cpu
# CONTAINER=763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:1.11.0-cpu-py38-ubuntu20.04-e3

# create mar
docker run --rm \
    -v "$PWD":/home/model-server \
    --entrypoint /bin/bash \
    --workdir /home/model-server \
    $CONTAINER \
    -c \
    "torch-model-archiver \
        --model-name rnnt \
        --version 1.0 \
        --serialized-file decoder_jit.pt \
        --handler handler.py \
        --extra-files 1089-134686.trans.txt \
        --requirements-file requirements.txt \
        --force \
    && mkdir -p model-store \
    && mv rnnt.mar model-store/
    "

# serve; /home/model-server/config.properties has a pre-defined model-store location
docker run --rm --network host \
    -p 8080:8080 -p 8081:8081 -p 8082:8082 -p 7070:7070 -p 7071:7071 \
    -v "$PWD":/home/model-server \
    $CONTAINER
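Once the serving container is running, a quick way to confirm it came up is TorchServe's health-check endpoint on the inference port. A minimal sketch, run from a second shell and assuming the server is reachable on localhost:8080:

```python
import requests

# TorchServe exposes a health check on the inference port
r = requests.get("http://localhost:8080/ping")
print(r.json())  # a healthy server returns {"status": "Healthy"}
```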
6 changes: 6 additions & 0 deletions examples/asr_rnnt_emformer/02_configure_server.sh
@@ -0,0 +1,6 @@
#!/bin/bash

# torchserve --start --ncs --ts-config /home/model-server/config.properties

curl -X POST "http://localhost:8081/models?url=rnnt.mar"
curl -X PUT "http://localhost:8081/models/rnnt?min_worker=1"
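After registering the model and requesting a worker, the management API's describe-model endpoint reports whether the worker actually came up. A small sketch, assuming the management port is 8081 as in config.properties:

```python
import requests

# describe the registered model; the response includes its worker list and status
r = requests.get("http://localhost:8081/models/rnnt")
print(r.json())
```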
10 changes: 10 additions & 0 deletions examples/asr_rnnt_emformer/03_predict.py
@@ -0,0 +1,10 @@
import requests

url = "http://127.0.0.1:8080/predictions/rnnt"

# read the sample clip as raw bytes and POST it to the model endpoint
with open('1089-134686-0000.wav', 'rb') as f:
    audio_bytes = f.read()

r = requests.post(url, data=audio_bytes)

print(r.text)
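The LibriSpeech Emformer RNN-T bundle operates on 16 kHz audio, so it can be worth checking a clip's sample rate before posting it. A small sketch that is not part of the original example, using torchaudio's metadata reader and the bundle's sample_rate attribute:

```python
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH

info = torchaudio.info("1089-134686-0000.wav")
expected = EMFORMER_RNNT_BASE_LIBRISPEECH.sample_rate  # 16000 for this bundle
assert info.sample_rate == expected, f"resample {info.sample_rate} Hz to {expected} Hz first"
```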
Binary file added examples/asr_rnnt_emformer/1089-134686-0000.wav
Binary file not shown.
1 change: 1 addition & 0 deletions examples/asr_rnnt_emformer/1089-134686.trans.txt
@@ -0,0 +1 @@
1089-134686-0000 HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE
37 changes: 37 additions & 0 deletions examples/asr_rnnt_emformer/README.md
@@ -0,0 +1,37 @@
### ASR (Automatic Speech Recognition) Example

In this example we use TorchServe to serve an ASR model that converts WAV audio to text. There are four steps in this process. First, we download a pretrained Emformer RNN-T model and save it in TorchScript (JIT) format. Second, we create the model archive and start the model server. Third, we configure the model server with one worker. Finally, we send a WAV file to the model endpoint to get a text prediction.

#### Steps to run:
- 1. Save the ASR model in TorchScript (JIT) format.
```bash
./00_save_jit_model.sh
```
- 2. Create the model archive and start the model server
```bash
./01_create_model_archive.sh

output:
2023-01-10T20:46:39,660 [INFO ] pool-3-thread-2 TS_METRICS - MemoryUtilization.Percent:3.2|Level:Host|hostname:ip-172-31-15-90,timestamp:1673383599
```
- 3. Configure the model server: register the model and add a worker.
```bash
./02_configure_server.sh

Output:
{
"status": "Model \"rnnt\" Version: 1.0 registered with 0 initial workers. Use scale workers API to add workers for the model."
}
{
"status": "Processing worker updates..."
}

```

- 4. Get prediction results
```bash
python3 03_predict.py
output:
he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fat and sauce
```
8 changes: 8 additions & 0 deletions examples/asr_rnnt_emformer/config.properties
@@ -0,0 +1,8 @@
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
number_of_netty_threads=32
job_queue_size=1000
model_store=/home/model-server/model-store
workflow_store=/home/model-server/wf-store
install_py_dep_per_model=true
77 changes: 77 additions & 0 deletions examples/asr_rnnt_emformer/handler.py
@@ -0,0 +1,77 @@
import os
import tempfile

import torch
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH


class ModelHandler(object):
    """
    A custom model handler implementation.
    """

    def __init__(self):
        self._context = None
        self.initialized = False
        self.model = None
        self.device = None

    def initialize(self, context):
        """
        Invoked by TorchServe to load the model.
        :param context: context contains model server system properties
        :return:
        """

        # load the model
        self.manifest = context.manifest

        properties = context.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        # read the serialized TorchScript model file
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        self.model = torch.jit.load(model_pt_path)

        self.initialized = True

    def handle(self, data, context):
        """
        Invoked by TorchServe for a prediction request.
        Pre-processes the input, runs the model, and post-processes the prediction output.
        :param data: Input data for prediction
        :param context: Initial context contains model server system properties.
        :return: prediction output
        """
        if isinstance(data, list):
            data = data[0]
        data = data.get("data") or data.get("body")  # raw audio arrives as a bytearray

        # write the bytes to a temporary file so torchaudio can load them
        fp = tempfile.TemporaryFile()
        fp.write(data)
        fp.seek(0)

        waveform, sample_rate = torchaudio.load(fp)
        waveform = waveform.squeeze()

        feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()

        decoder = self.model

        token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()

        with torch.no_grad():
            features, length = feature_extractor(waveform)
            hypos = decoder(features, length, 1)

        # batch size is 1, so return a single-element list
        return [token_processor(hypos[0].tokens)]
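The handler's inference path can also be exercised outside TorchServe, which makes it easier to debug. A minimal sketch that mirrors handle() but uses the bundle's own decoder in place of the archived TorchScript file; the wav file name is the sample clip shipped with this example:

```python
import torch
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH

waveform, sample_rate = torchaudio.load("1089-134686-0000.wav")
waveform = waveform.squeeze()

feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()
decoder = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()
token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()

with torch.no_grad():
    features, length = feature_extractor(waveform)
    hypos = decoder(features, length, 1)  # beam width of 1, as in the handler

print(token_processor(hypos[0].tokens))
```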
3 changes: 3 additions & 0 deletions examples/asr_rnnt_emformer/requirements.txt
@@ -0,0 +1,3 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torchaudio
sentencepiece
13 changes: 13 additions & 0 deletions examples/asr_rnnt_emformer/save_jit_model.py
@@ -0,0 +1,13 @@
import torch
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH


jit_model = 'decoder_jit.pt'

# download the pretrained Emformer RNN-T decoder from the torchaudio pipeline
model = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()

model.eval()

# script the decoder with TorchScript and save it for serving
mj = torch.jit.script(model)

mj.save(jit_model)
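A quick follow-up check, assuming the script above has been run in the same directory, is to reload the artifact and confirm it deserializes cleanly:

```python
import torch

# reload the TorchScript decoder produced by save_jit_model.py
loaded = torch.jit.load('decoder_jit.pt')
print(type(loaded))  # a scripted module wrapping the RNN-T decoder
```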
