asr speech recognition example (#2047)
* asr speech recognition example
* added details in README according to maaquib
* Apply suggestions from code review

Co-authored-by: Aaqib <[email protected]>
Co-authored-by: Ankith Gunapal <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
1 parent: 656a30d
Commit: 242895c
Showing 11 changed files with 196 additions and 0 deletions.
File: 00_save_jit_model.sh
#!/bin/bash

CONTAINER=pytorch/torchserve:0.8.2-cpu

docker run --rm \
    -v $PWD:/home/model-server \
    --entrypoint /bin/bash \
    --workdir /home/model-server \
    $CONTAINER \
    -c \
    "pip install -r requirements.txt && python save_jit_model.py
    "
File: 01_create_model_archive.sh
#!/bin/bash

CONTAINER=pytorch/torchserve:0.8.2-cpu
# CONTAINER=763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:1.11.0-cpu-py38-ubuntu20.04-e3

# create mar
docker run --rm \
    -v $PWD:/home/model-server \
    --entrypoint /bin/bash \
    --workdir /home/model-server \
    $CONTAINER \
    -c \
    "torch-model-archiver \
        --model-name rnnt \
        --version 1.0 \
        --serialized-file decoder_jit.pt \
        --handler handler.py \
        --extra-files 1089-134686.trans.txt \
        --requirements-file requirements.txt \
        --force \
    && mkdir -p model-store \
    && mv rnnt.mar model-store/
    "

# serve; /home/model-server/config.properties has pre-defined model-store location
docker run --rm --network host \
    -p 8080:8080 -p 8081:8081 -p 8082:8082 -p 7070:7070 -p 7071:7071 \
    -v $PWD:/home/model-server \
    $CONTAINER
File: 02_configure_server.sh
#!/bin/bash

# torchserve --start --ncs --ts-config /home/model-server/config.properties

curl -X POST "http://localhost:8081/models?url=rnnt.mar"
curl -X PUT "http://localhost:8081/models/rnnt?min_worker=1"
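A quick way to confirm the registration and worker scale-up worked (a sketch, not one of the committed scripts) is to query the TorchServe management API on the same port 8081 for the rnnt model's status:

```python
# Sketch: describe the registered model via the management API.
# Assumes the server started by 01_create_model_archive.sh is still running.
import requests

print(requests.get("http://localhost:8081/models/rnnt").json())
```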
File: 03_predict.py
import requests

url = "http://127.0.0.1:8080/predictions/rnnt"

with open('1089-134686-0000.wav', 'rb') as f:
    audio_bytes = f.read()

r = requests.post(url, data=audio_bytes)

print(r.text)
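As an optional follow-up (a sketch, not part of the committed example), the prediction can be checked against the reference transcript shipped in 1089-134686.trans.txt; the parsing below assumes the `<utterance-id> <TRANSCRIPT>` line format shown in that file.

```python
# Sketch: request a transcription and compare it with the reference transcript.
import requests

url = "http://127.0.0.1:8080/predictions/rnnt"

with open('1089-134686-0000.wav', 'rb') as f:
    prediction = requests.post(url, data=f.read()).text.strip().lower()

with open('1089-134686.trans.txt') as f:
    # Each line: "<utterance-id> <TRANSCRIPT IN UPPER CASE>"
    reference = f.readline().split(maxsplit=1)[1].strip().lower()

print("prediction:", prediction)
print("reference :", reference)
print("exact match:", prediction == reference)
```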
File: 1089-134686-0000.wav (binary file not shown)
File: 1089-134686.trans.txt
1089-134686-0000 HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE
File: README.md
### ASR (Automatic Speech Recognition) Example

In this example we use TorchServe to serve an ASR model that converts WAV audio to text. There are four steps in this process. First, we download a pretrained Emformer RNN-T model and save it in TorchScript (JIT) format. Second, we create the model archive and start the model server. Third, we configure the model server by registering the model and adding one worker. Last, we send a WAV file to the model endpoint to get a text prediction.

#### Steps to run:
1. Save the ASR model in JIT format.
```bash
./00_save_jit_model.sh
```
2. Create the model archive and start the server.
```bash
./01_create_model_archive.sh

output:
2023-01-10T20:46:39,660 [INFO ] pool-3-thread-2 TS_METRICS - MemoryUtilization.Percent:3.2|Level:Host|hostname:ip-172-31-15-90,timestamp:1673383599
```
3. Configure the model server: register the model and add a worker.
```bash
./02_configure_server.sh

Output:
{
  "status": "Model \"rnnt\" Version: 1.0 registered with 0 initial workers. Use scale workers API to add workers for the model."
}
{
  "status": "Processing worker updates..."
}
```
4. Get prediction results.
```
python3 03_predict.py

output:
he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fat and sauce
```
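Not part of the committed README, but a natural last step once the walkthrough is done: unregister the model through the management API (a sketch assuming the server from step 2 is still reachable on port 8081). Stopping the Docker container, or `torchserve --stop` for a local install, then shuts the server down.

```python
# Sketch: unregister version 1.0 of the rnnt model via the management API.
import requests

print(requests.delete("http://localhost:8081/models/rnnt/1.0").json())
```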
File: config.properties
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
number_of_netty_threads=32
job_queue_size=1000
model_store=/home/model-server/model-store
workflow_store=/home/model-server/wf-store
install_py_dep_per_model=true
File: handler.py
import os
import tempfile

import torch
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH


class ModelHandler(object):
    """
    A custom model handler implementation.
    """

    def __init__(self):
        self._context = None
        self.initialized = False
        self.model = None
        self.device = None

    def initialize(self, context):
        """
        Invoked by TorchServe to load the model.
        :param context: context contains model server system properties
        :return:
        """

        # load the model
        self.manifest = context.manifest

        properties = context.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        # Read the serialized (TorchScript .pt) model file
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        self.model = torch.jit.load(model_pt_path)

        self.initialized = True

    def handle(self, data, context):
        """
        Invoked by TorchServe for a prediction request.
        Pre-processes the input, runs the model, and post-processes the prediction output.
        :param data: Input data for prediction
        :param context: Initial context contains model server system properties.
        :return: prediction output
        """
        if isinstance(data, list):
            data = data[0]
        data = data.get("data") or data.get("body")

        # data is a bytearray holding the raw wav payload; write it to a
        # temporary file so torchaudio can decode it.
        fp = tempfile.TemporaryFile()
        fp.write(data)
        fp.seek(0)

        waveform, sample_rate = torchaudio.load(fp)
        waveform = waveform.squeeze()

        feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()

        decoder = self.model

        token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()

        with torch.no_grad():
            features, length = feature_extractor(waveform)
            hypos = decoder(features, length, 1)

        # return batch_size = 1
        return [token_processor(hypos[0].tokens)]
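For debugging the model outside TorchServe, the handler's pipeline can be exercised directly with torchaudio. The sketch below is not part of the committed example; it assumes decoder_jit.pt (produced in step 1) and the sample 1089-134686-0000.wav sit in the current directory.

```python
# Sketch: run the same feature-extract -> decode -> detokenize steps as handler.py,
# but locally against the scripted decoder and the sample wav file.
import torch
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH

decoder = torch.jit.load("decoder_jit.pt")
feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()
token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()

waveform, sample_rate = torchaudio.load("1089-134686-0000.wav")

with torch.no_grad():
    features, length = feature_extractor(waveform.squeeze())
    hypos = decoder(features, length, 1)  # beam width 1, as in the handler

print(token_processor(hypos[0].tokens))
```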
File: requirements.txt
--extra-index-url https://download.pytorch.org/whl/cpu
torchaudio
sentencepiece
File: save_jit_model.py
import torch
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH


jit_model = 'decoder_jit.pt'
model = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()

model.eval()

mj = torch.jit.script(model)

mj.save(jit_model)
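A quick optional sanity check (a sketch, not part of the example): the scripted decoder written above should load back with torch.jit.load, without needing the original Python class definitions.

```python
# Sketch: reload the TorchScript decoder to confirm the export round-trips.
import torch

reloaded = torch.jit.load("decoder_jit.pt")
print(type(reloaded))
```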