TorchServe linux-aarch64 experimental support (#3071)
* Changes for building TorchServe on linux aarch64

* Changes for building TorchServe on linux aarch64

* Added an example for linux aarch64

* Doc update for linux aarch64

* Doc update for linux aarch64

* Doc update for linux aarch64

* removed torchtext for aarch64

* lint failure

* lint failure

* Build conda binaries

* Build conda binaries

* resolving merge conflicts

* resolving merge conflicts

* update documentation

* review comments

* Updated based on review comments

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
3 people committed May 3, 2024
1 parent a69e561 commit 5c1682a
Showing 18 changed files with 202 additions and 5 deletions.
8 changes: 7 additions & 1 deletion binaries/conda/build_packages.py
@@ -22,7 +22,13 @@
 PACKAGES = ["torchserve", "torch-model-archiver", "torch-workflow-archiver"]

 # conda convert supported platforms https://docs.conda.io/projects/conda-build/en/stable/resources/commands/conda-convert.html
-PLATFORMS = ["linux-64", "osx-64", "win-64", "osx-arm64"]  # Add a new platform here
+PLATFORMS = [
+    "linux-64",
+    "osx-64",
+    "win-64",
+    "osx-arm64",
+    "linux-aarch64",
+]  # Add a new platform here

 if os.name == "nt":
     # Assumes miniconda is installed in windows
29 changes: 29 additions & 0 deletions docs/linux_aarch64.md
@@ -0,0 +1,29 @@
# TorchServe on linux aarch64 - Experimental

TorchServe has been tested on linux aarch64 and works for some of the examples.
- Tested on an Amazon Graviton 3 instance (m7g.4xlarge)

## Installation

Currently, installation from PyPI or from source works.

```
python ts_scripts/install_dependencies.py
pip install torchserve torch-model-archiver torch-workflow-archiver
```

## Optimizations

You can also enable these optimizations on Graviton 3 for improved performance. More details can be found in this [blog](https://pytorch.org/blog/optimized-pytorch-w-graviton/).
```
export DNNL_DEFAULT_FPMATH_MODE=BF16
export LRU_CACHE_CAPACITY=1024
```
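
As a minimal sketch (not part of this commit), a standalone Python script could apply the same two variables only on aarch64 hosts, assuming they are set before `torch` is imported. The helper name `graviton_env` is hypothetical; the variable names and values are the ones shown above.

```python
import os
import platform


def graviton_env(machine: str) -> dict:
    """Return the Graviton tuning variables for aarch64 hosts, else nothing."""
    if machine != "aarch64":
        return {}
    return {
        "DNNL_DEFAULT_FPMATH_MODE": "BF16",
        "LRU_CACHE_CAPACITY": "1024",
    }


# setdefault keeps any value the user has already exported in the shell.
for name, value in graviton_env(platform.machine()).items():
    os.environ.setdefault(name, value)
```

On other architectures the helper returns an empty dict, so the loop leaves the environment untouched.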

## Example

This [example](https://github.com/pytorch/serve/tree/master/examples/text_to_speech_synthesizer/SpeechT5) of text-to-speech synthesis was verified to work on Graviton 3.

## To Dos
- CI
- Regression tests
50 changes: 50 additions & 0 deletions examples/text_to_speech_synthesizer/SpeechT5/README.md
@@ -0,0 +1,50 @@
# Text to Speech synthesis with SpeechT5

This example shows text-to-speech synthesis using the SpeechT5 model. It has been verified to work on a Graviton 3 (linux-aarch64) instance.

When running this model on `linux-aarch64`, you can enable these optimizations:

```
export DNNL_DEFAULT_FPMATH_MODE=BF16
export LRU_CACHE_CAPACITY=1024
```
More details can be found in this [blog](https://pytorch.org/blog/optimized-pytorch-w-graviton/)


## Pre-requisites
```
chmod +x setup.sh
./setup.sh
```

## Download model

This saves the model artifacts to the `model_artifacts` directory:
```
huggingface-cli login
python download_model.py
```

## Create model archive

```
mkdir model_store
torch-model-archiver --model-name SpeechT5-TTS --version 1.0 --handler text_to_speech_handler.py --config-file model-config.yaml --archive-format no-archive --export-path model_store -f
mv model_artifacts/* model_store/SpeechT5-TTS/
```

## Start TorchServe

```
torchserve --start --ncs --model-store model_store --models SpeechT5-TTS
```

## Send Inference request

```
curl http://127.0.0.1:8080/predictions/SpeechT5-TTS -T sample_input.txt -o speech.wav
```

This generates an audio file `speech.wav` corresponding to the text in `sample_input.txt`
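
For callers that prefer Python over curl, the request above could be sketched with the standard library as follows. This is an illustration, not part of the commit: the function names `build_request` and `synthesize` are hypothetical, and the URL is copied from the curl command. It uses PUT because `curl -T` uploads the file with PUT.

```python
import urllib.request

# Endpoint taken from the curl command above.
PREDICT_URL = "http://127.0.0.1:8080/predictions/SpeechT5-TTS"


def build_request(text: bytes, url: str = PREDICT_URL) -> urllib.request.Request:
    """Build a PUT request carrying the input text (curl -T uploads via PUT)."""
    return urllib.request.Request(url, data=text, method="PUT")


def synthesize(input_path: str, output_path: str, url: str = PREDICT_URL) -> None:
    """Send the text file to TorchServe and save the returned audio bytes."""
    with open(input_path, "rb") as f:
        request = build_request(f.read(), url)
    with urllib.request.urlopen(request) as response:
        audio = response.read()
    with open(output_path, "wb") as out:
        out.write(audio)
```

With the server from the previous step running, `synthesize("sample_input.txt", "speech.wav")` would mirror the curl call.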
14 changes: 14 additions & 0 deletions examples/text_to_speech_synthesizer/SpeechT5/download_model.py
@@ -0,0 +1,14 @@
from datasets import load_dataset
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")

model.save_pretrained(save_directory="model_artifacts/model")
processor.save_pretrained(save_directory="model_artifacts/processor")
vocoder.save_pretrained(save_directory="model_artifacts/vocoder")
embeddings_dataset.save_to_disk("model_artifacts/speaker_embeddings")
print("Saved model artifacts to directory model_artifacts")
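
Before archiving, it can be useful to confirm that the download produced all four artifact directories. The following is a minimal sketch under that assumption; the helper `missing_artifacts` is hypothetical and not part of the commit, and the directory names are the ones passed to `save_pretrained` and `save_to_disk` above.

```python
from pathlib import Path

# Subdirectories written by the download script above.
EXPECTED = ("model", "processor", "vocoder", "speaker_embeddings")


def missing_artifacts(root: str = "model_artifacts") -> list:
    """Return the expected subdirectories that are absent under root."""
    base = Path(root)
    return [name for name in EXPECTED if not (base / name).is_dir()]
```

An empty list means every expected directory is present.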
8 changes: 8 additions & 0 deletions examples/text_to_speech_synthesizer/SpeechT5/model-config.yaml
@@ -0,0 +1,8 @@
minWorkers: 1
maxWorkers: 1
handler:
    model: "model"
    vocoder: "vocoder"
    processor: "processor"
    speaker_embeddings: "speaker_embeddings"
    output_dir: "/tmp"
1 change: 1 addition & 0 deletions examples/text_to_speech_synthesizer/SpeechT5/sample_input.txt
@@ -0,0 +1 @@
"I love San Francisco"
6 changes: 6 additions & 0 deletions examples/text_to_speech_synthesizer/SpeechT5/setup.sh
@@ -0,0 +1,6 @@
#!/bin/bash

# Needed for soundfile
sudo apt install libsndfile1 -y

pip install --upgrade transformers sentencepiece datasets[audio] soundfile
68 changes: 68 additions & 0 deletions examples/text_to_speech_synthesizer/SpeechT5/text_to_speech_handler.py
@@ -0,0 +1,68 @@
import logging
import os
import uuid

import soundfile as sf
import torch
from datasets import load_from_disk
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class SpeechT5_TTS(BaseHandler):
    def __init__(self):
        self.model = None
        self.processor = None
        self.vocoder = None
        self.speaker_embeddings = None
        self.output_dir = "/tmp"

    def initialize(self, ctx):
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        processor = ctx.model_yaml_config["handler"]["processor"]
        model = ctx.model_yaml_config["handler"]["model"]
        vocoder = ctx.model_yaml_config["handler"]["vocoder"]
        embeddings_dataset = ctx.model_yaml_config["handler"]["speaker_embeddings"]
        self.output_dir = ctx.model_yaml_config["handler"]["output_dir"]

        self.processor = SpeechT5Processor.from_pretrained(processor)
        self.model = SpeechT5ForTextToSpeech.from_pretrained(model)
        self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder)

        # load xvector containing speaker's voice characteristics from a dataset
        embeddings_dataset = load_from_disk(embeddings_dataset)
        self.speaker_embeddings = torch.tensor(
            embeddings_dataset[7306]["xvector"]
        ).unsqueeze(0)

    def preprocess(self, requests):
        assert len(requests) == 1, "This is currently supported with batch_size=1"
        req_data = requests[0]

        input_data = req_data.get("data") or req_data.get("body")

        if isinstance(input_data, (bytes, bytearray)):
            input_data = input_data.decode("utf-8")

        inputs = self.processor(text=input_data, return_tensors="pt")

        return inputs

    def inference(self, inputs):
        output = self.model.generate_speech(
            inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder
        )
        return output

    def postprocess(self, inference_output):
        path = self.output_dir + "/{}.wav".format(uuid.uuid4().hex)
        sf.write(path, inference_output.numpy(), samplerate=16000)
        with open(path, "rb") as output:
            data = output.read()
        os.remove(path)
        return [data]
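
The `postprocess` step above writes a temporary `.wav` file with `soundfile` and reads it back. As an alternative sketch, not what this commit does, the same kind of 16 kHz mono output could be encoded in memory with the standard-library `wave` module. The helper `encode_wav` is hypothetical, and its simple float-to-16-bit conversion is an assumption that may not match `soundfile`'s output byte for byte.

```python
import io
import struct
import wave


def encode_wav(samples, samplerate: int = 16000) -> bytes:
    """Encode mono float samples in [-1.0, 1.0] as 16-bit PCM WAV bytes."""
    pcm = bytearray()
    for s in samples:
        # Clip to the valid range before scaling to signed 16-bit integers.
        clipped = max(-1.0, min(1.0, float(s)))
        pcm += struct.pack("<h", int(clipped * 32767))
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)  # 2 bytes per sample = 16-bit
        wav.setframerate(samplerate)
        wav.writeframes(bytes(pcm))
    return buffer.getvalue()
```

In a handler, something like `encode_wav(inference_output.tolist())` would avoid touching the filesystem, at the cost of a Python-level loop over the samples.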
2 changes: 1 addition & 1 deletion requirements/developer.txt
@@ -15,7 +15,7 @@ pre-commit==3.3.2
 twine==4.0.2
 mypy==1.3.0
 torchpippy==0.1.1
-intel_extension_for_pytorch==2.2.0; sys_platform != 'win32' and sys_platform != 'darwin'
+intel_extension_for_pytorch==2.2.0; sys_platform != 'win32' and sys_platform != 'darwin' and platform_machine != 'aarch64'
 onnxruntime==1.17.1
 googleapis-common-protos
 onnx==1.16.0
6 changes: 6 additions & 0 deletions requirements/torch_linux_aarch64.txt
@@ -0,0 +1,6 @@
#pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
--extra-index-url https://download.pytorch.org/whl/cpu
-r torch_common.txt
torch==2.2.1; sys_platform == 'linux' and platform_machine == 'aarch64'
torchvision==0.17.1; sys_platform == 'linux' and platform_machine == 'aarch64'
torchaudio==2.2.1; sys_platform == 'linux' and platform_machine == 'aarch64'
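
The environment markers in these requirements files compare against values that Python exposes directly: `sys_platform` corresponds to `sys.platform` and `platform_machine` to `platform.machine()`. The sketch below hand-evaluates the two marker expressions from this commit for illustration; it is not a general marker parser, and the function names are hypothetical.

```python
import platform
import sys


def current_marker_env() -> dict:
    """Values pip compares against for the two markers used here."""
    return {"sys_platform": sys.platform, "platform_machine": platform.machine()}


def is_linux_aarch64(env: dict) -> bool:
    """Hand-written form of: sys_platform == 'linux' and platform_machine == 'aarch64'."""
    return env["sys_platform"] == "linux" and env["platform_machine"] == "aarch64"


def wants_ipex(env: dict) -> bool:
    """Hand-written form of the updated intel_extension_for_pytorch marker."""
    return (
        env["sys_platform"] != "win32"
        and env["sys_platform"] != "darwin"
        and env["platform_machine"] != "aarch64"
    )
```

Calling `is_linux_aarch64(current_marker_env())` on a given host indicates whether the pinned `torch`, `torchvision`, and `torchaudio` lines would apply there.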
11 changes: 8 additions & 3 deletions ts_scripts/install_dependencies.py
@@ -118,9 +118,14 @@ def install_torch_packages(self, cuda_version):
                 f"{sys.executable} -m pip install -U -r {torch_neuronx_requirements_file}"
             )
         else:
-            os.system(
-                f"{sys.executable} -m pip install -U -r requirements/torch_{platform.system().lower()}.txt"
-            )
+            if platform.machine() == "aarch64":
+                os.system(
+                    f"{sys.executable} -m pip install -U -r requirements/torch_{platform.system().lower()}_{platform.machine()}.txt"
+                )
+            else:
+                os.system(
+                    f"{sys.executable} -m pip install -U -r requirements/torch_{platform.system().lower()}.txt"
+                )

     def install_python_packages(self, cuda_version, requirements_file_path, nightly):
         check = "where" if platform.system() == "Windows" else "which"
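
The file-selection logic added in this hunk can be isolated as a small pure function, which makes the resulting paths easy to inspect. This is a sketch for illustration only; `torch_requirements_file` is a hypothetical name and does not exist in `install_dependencies.py`.

```python
import platform


def torch_requirements_file(system: str, machine: str) -> str:
    """Pick the torch requirements file the way the branch above does."""
    name = f"torch_{system.lower()}"
    if machine == "aarch64":
        # Only aarch64 gets an architecture suffix, e.g. torch_linux_aarch64.txt
        name += f"_{machine}"
    return f"requirements/{name}.txt"


print(torch_requirements_file(platform.system(), platform.machine()))
```

Apple Silicon Macs report `arm64` rather than `aarch64` from `platform.machine()`, so they keep using the unsuffixed `torch_darwin.txt` path.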
4 changes: 4 additions & 0 deletions ts_scripts/spellcheck_conf/wordlist.txt
@@ -1216,6 +1216,10 @@ libomp
 rpath
 venv
 TorchInductor
+Graviton
+aarch
+linux
+SpeechT
 Pytests
 deviceType
 XGBoost