Profile TorchServe Handler (preprocess vs inference vs post-process) (#2470)

* Profile TS Handler using ab tool

* Added an example

* Added an example

* handler class not needed

* Add model_yaml_config to MockContext

* remove unnecessary config

* based on review comments

* Added details on how to enable this

* Added details on how to enable this

* lint fix

* lint fix

* lint fix

---------

Co-authored-by: Geeta Chauhan <[email protected]>
agunapal and chauhang committed Aug 24, 2023
1 parent d47b14d commit 03ad862
Showing 12 changed files with 239 additions and 23 deletions.
22 changes: 22 additions & 0 deletions benchmarks/benchmark-ab.py
@@ -30,6 +30,7 @@
"image": "",
"docker_runtime": "",
"backend_profiling": False,
"handler_profiling": False,
"generate_graphs": False,
"config_properties": "config.properties",
"inference_model_url": "predictions/benchmark",
@@ -95,6 +96,12 @@ def json_provider(file_path, cmd_name):
default=False,
help="Enable backend profiling using CProfile. Default False",
)
@click.option(
"--handler_profiling",
"-hp",
default=False,
help="Enable handler profiling. Default False",
)
@click.option(
"--generate_graphs",
"-gg",
@@ -143,6 +150,7 @@ def benchmark(
image,
docker_runtime,
backend_profiling,
handler_profiling,
config_properties,
inference_model_url,
report_location,
@@ -163,6 +171,7 @@
"image": image,
"docker_runtime": docker_runtime,
"backend_profiling": backend_profiling,
"handler_profiling": handler_profiling,
"config_properties": config_properties,
"inference_model_url": inference_model_url,
"report_location": report_location,
@@ -469,13 +478,26 @@ def generate_report(warm_up_lines):
}


def update_metrics():
if execution_params["handler_profiling"]:
opt_metrics = {
"handler_preprocess.txt": "ts_handler_preprocess",
"handler_inference.txt": "ts_handler_inference",
"handler_postprocess.txt": "ts_handler_postprocess",
}
metrics.update(opt_metrics)
return metrics


def extract_metrics(warm_up_lines):
with open(execution_params["metric_log"]) as f:
lines = f.readlines()

click.secho(f"Dropping {warm_up_lines} warmup lines from log", fg="green")
lines = lines[warm_up_lines:]

metrics = update_metrics()

for k, v in metrics.items():
all_lines = []
pattern = re.compile(v)
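Note on the diff above: when `--handler_profiling` is enabled, `update_metrics()` merges three extra entries into the benchmark's metric map. A minimal sketch of that mapping (the base `metrics` dict and the exact metrics-log line format are not shown in this diff and are left out here):

```
# Sketch only: keys are the per-metric output files the benchmark writes,
# values are the regex patterns extract_metrics() searches for in the
# TorchServe metrics log. Together these rows give the preprocess /
# inference / postprocess split of PredictionTime in the generated report.
handler_metrics = {
    "handler_preprocess.txt": "ts_handler_preprocess",
    "handler_inference.txt": "ts_handler_inference",
    "handler_postprocess.txt": "ts_handler_postprocess",
}
```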
43 changes: 29 additions & 14 deletions benchmarks/utils/gen_model_config_json.py
@@ -2,11 +2,11 @@
import copy
import json
import os

import yaml


def main():

parser = argparse.ArgumentParser()

parser.add_argument(
@@ -22,6 +22,7 @@ def main():
arguments = parser.parse_args()
convert_yaml_to_json(arguments.input, arguments.output)


MODEL_CONFIG_KEY = {
"batch_size",
"batch_delay",
@@ -30,12 +31,18 @@
"concurrency",
"workers",
"input",
"processors"
"processors",
"handler_profiling",
}


def convert_yaml_to_json(yaml_file_path, output_dir):
print("convert_yaml_to_json yaml_file_path={}, output_dir={}".format(yaml_file_path, output_dir))
with open(yaml_file_path, 'r') as f:
print(
"convert_yaml_to_json yaml_file_path={}, output_dir={}".format(
yaml_file_path, output_dir
)
)
with open(yaml_file_path, "r") as f:
yaml_dict = yaml.safe_load(f)

for model, config in yaml_dict.items():
@@ -58,10 +65,9 @@ def convert_yaml_to_json(yaml_file_path, output_dir):
batch_worker_list = []
for batch_size in batch_size_list:
for workers in workers_list:
batch_worker_list.append({
"batch_size" : batch_size,
"workers" : workers
})
batch_worker_list.append(
{"batch_size": batch_size, "workers": workers}
)

benchmark_configs = []
for batch_worker in batch_worker_list:
@@ -72,25 +78,34 @@
for bConfig in benchmark_configs:
for i in range(len(processors)):
if type(processors[i]) is str:
path = '{}/{}'.format(output_dir, processors[i])
path = "{}/{}".format(output_dir, processors[i])
if not os.path.isdir(path):
continue

benchmark_config_file = '{}/{}_w{}_b{}.json'\
.format(path, model_name, bConfig["workers"], bConfig["batch_size"])
benchmark_config_file = "{}/{}_w{}_b{}.json".format(
path,
model_name,
bConfig["workers"],
bConfig["batch_size"],
)
with open(benchmark_config_file, "w") as outfile:
json.dump(bConfig, outfile, indent=4)
elif type(processors[i]) is dict:
path = '{}/gpu'.format(output_dir)
path = "{}/gpu".format(output_dir)
if not os.path.isdir(path):
continue

bConfig["gpus"] = processors[i]["gpus"]
benchmark_config_file = '{}/{}_w{}_b{}.json'\
.format(path, model_name, bConfig["workers"], bConfig["batch_size"])
benchmark_config_file = "{}/{}_w{}_b{}.json".format(
path,
model_name,
bConfig["workers"],
bConfig["batch_size"],
)
with open(benchmark_config_file, "w") as outfile:
json.dump(bConfig, outfile, indent=4)
del bConfig["gpus"]


if __name__ == "__main__":
main()
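For reference, `convert_yaml_to_json()` above writes one JSON benchmark config per (processor, workers, batch_size) combination. A small sketch of the resulting file naming, under assumed values — the directory, `model_name`, and the extra fields carried over from the model YAML are illustrative; only the naming scheme and the `batch_size`/`workers` handling come from this file:

```
# Hedged reconstruction of the output naming used by convert_yaml_to_json().
output_dir = "/tmp/benchmark"  # assumed
model_name = "resnet50"        # assumed; derived from the model YAML in the real code
bConfig = {"batch_size": 8, "workers": 4, "handler_profiling": True}

benchmark_config_file = "{}/{}/{}_w{}_b{}.json".format(
    output_dir, "gpu", model_name, bConfig["workers"], bConfig["batch_size"]
)
print(benchmark_config_file)  # -> /tmp/benchmark/gpu/resnet50_w4_b8.json
```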
45 changes: 45 additions & 0 deletions examples/benchmarking/resnet50/README.md
@@ -0,0 +1,45 @@

# Benchmark ResNet50 and profile the detailed split of PredictionTime

This example shows how to run the benchmark `ab` tool on ResNet50 and identify the time spent on preprocess, inference, and postprocess.

Change directory to the root of `serve`.
For example, if `serve` is under `/home/ubuntu`, change directory to `/home/ubuntu/serve`.


## Download the weights

```
wget https://download.pytorch.org/models/resnet50-11ad3fa6.pth
```

### Create model archive

To enable profiling of the TorchServe handler, add the following config to `model-config.yaml`:
```
handler:
  profile: true
```

```
torch-model-archiver --model-name resnet-50 --version 1.0 --model-file ./examples/benchmarking/resnet50/model.py --serialized-file resnet50-11ad3fa6.pth --handler image_classifier --extra-files ./examples/image_classifier/index_to_name.json --config-file ./examples/benchmarking/resnet50/model-config.yaml
mkdir model_store
mv resnet-50.mar model_store/.
```

### Install dependencies for benchmark tool

```
sudo apt-get update -y
sudo apt-get install -y apache2-utils
pip install -r benchmarks/requirements-ab.txt
```

### Run ab tool for benchmarking

```
python benchmarks/auto_benchmark.py --input examples/benchmarking/resnet50/benchmark_profile.yaml --skip true
```

This generates the report under `/tmp/ts_benchmarking/report.md`.
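With profiling enabled, the handler timings are extracted under the metric names added in `benchmarks/benchmark-ab.py` above. A hedged sketch for pulling those rows out of the generated report — the report's exact layout is not part of this diff, so this simply surfaces any lines mentioning the new metrics:

```
# Illustrative only: scan the report produced by the benchmark run above.
with open("/tmp/ts_benchmarking/report.md") as f:
    for line in f:
        if any(m in line for m in ("ts_handler_preprocess",
                                   "ts_handler_inference",
                                   "ts_handler_postprocess")):
            print(line.rstrip())
```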
16 changes: 16 additions & 0 deletions examples/benchmarking/resnet50/benchmark_profile.yaml
@@ -0,0 +1,16 @@
# TorchServe version to be installed. It can be one of the following options:
# - branch : "master"
# - nightly: "2022.3.16"
# - release: "0.5.3"
# Nightly build will be installed if "ts_version" is not specified
#ts_version:
# branch: &ts_version "master"

# a list of model config yaml files defined in benchmarks/models_config
# or a list of model config yaml files given with full paths
models:
- "/home/ubuntu/serve/examples/benchmarking/resnet50/resnet50.yaml"

# benchmark on "cpu" or "gpu".
# "cpu" is set if "hardware" is not specified
hardware: &hardware "gpu"
2 changes: 2 additions & 0 deletions examples/benchmarking/resnet50/model-config.yaml
@@ -0,0 +1,2 @@
handler:
  profile: true
6 changes: 6 additions & 0 deletions examples/benchmarking/resnet50/model.py
@@ -0,0 +1,6 @@
from torchvision.models.resnet import Bottleneck, ResNet


class ImageClassifier(ResNet):
    def __init__(self):
        super(ImageClassifier, self).__init__(Bottleneck, [3, 4, 6, 3])
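A quick sanity check (not part of the commit): `Bottleneck` with layers `[3, 4, 6, 3]` is the standard ResNet-50 layout, so the `resnet50-11ad3fa6.pth` weights downloaded in the README should load into this model directly. The sketch assumes the weights file sits in the working directory:

```
import torch
from torchvision.models.resnet import Bottleneck, ResNet

# Same layout as the ImageClassifier defined above.
model = ResNet(Bottleneck, [3, 4, 6, 3])
state_dict = torch.load("resnet50-11ad3fa6.pth", map_location="cpu")
model.load_state_dict(state_dict)  # expected: no missing or unexpected keys
model.eval()
```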
24 changes: 24 additions & 0 deletions examples/benchmarking/resnet50/resnet50.yaml
@@ -0,0 +1,24 @@
---
resnet50:
  eager_mode:
    benchmark_engine: "ab"
    url: "file:///home/ubuntu/serve/model_store/resnet-50.mar"
    workers:
      - 4
    batch_delay: 100
    batch_size:
      - 1
      - 2
      - 4
      - 8
      - 16
      - 32
      - 64
    requests: 10000
    concurrency: 100
    input: "./examples/image_classifier/kitten.jpg"
    handler_profiling: true
    exec_env: "local"
    processors:
      - "cpu"
      - "gpus": "all"
66 changes: 66 additions & 0 deletions ts/handler_utils/timer.py
@@ -0,0 +1,66 @@
"""
Decorator for timing handler methods
Use this decorator to compute the execution time for your preprocess, inference and
postprocess methods.
By default this feature is not enabled.
To enable it, add the following section to your model-config.yaml file:
handler:
  profile: true
An example of running benchmarks with profiling enabled is at
https://github.com/pytorch/serve/tree/master/examples/benchmarking/resnet50
"""

import time

import torch


def timed(func):
    def wrap_func(self, *args, **kwargs):
        # Measure time if config specified in model_yaml_config
        if (
            "handler" in self.context.model_yaml_config
            and "profile" in self.context.model_yaml_config["handler"]
        ):
            if self.context.model_yaml_config["handler"]["profile"]:
                # Measure start time
                if torch.cuda.is_available():
                    start = torch.cuda.Event(enable_timing=True)
                    end = torch.cuda.Event(enable_timing=True)
                    start.record()
                else:
                    start = time.time()

                result = func(self, *args, **kwargs)

                # Measure end time
                if torch.cuda.is_available():
                    end.record()
                    torch.cuda.synchronize()
                else:
                    end = time.time()

                # Measure time taken to execute the function in milliseconds
                if torch.cuda.is_available():
                    duration = start.elapsed_time(end)
                else:
                    duration = (end - start) * 1000

                # Add metrics for profiling
                metrics = self.context.metrics
                metrics.add_time("ts_handler_" + func.__name__, duration)
            else:
                # If profile config specified in model_yaml_config is False
                result = func(self, *args, **kwargs)
        else:
            # If no profile config specified in model_yaml_config
            result = func(self, *args, **kwargs)

        return result

    return wrap_func
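Usage note: this commit already wraps `BaseHandler.preprocess`, `inference`, and `postprocess` with `@timed` (see the `base_handler.py` diff below), so stock handlers only need `profile: true` under `handler:` in `model-config.yaml`. A hedged sketch of a custom handler reusing the decorator — the class name and method body are illustrative, not part of this commit:

```
from ts.handler_utils.timer import timed
from ts.torch_handler.base_handler import BaseHandler


class ProfiledHandler(BaseHandler):  # hypothetical custom handler
    @timed
    def inference(self, data, *args, **kwargs):
        # Overrides are timed the same way as the base methods: with
        # handler.profile set to true in model-config.yaml, this call is
        # reported as the ts_handler_inference metric.
        return self.model(data, *args, **kwargs)
```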
8 changes: 7 additions & 1 deletion ts/torch_handler/base_handler.py
@@ -12,6 +12,8 @@
import torch
from pkg_resources import packaging

from ts.handler_utils.timer import timed

from ..utils.util import (
check_valid_pt2_backend,
list_classes_from_module,
@@ -77,7 +79,8 @@
ONNX_AVAILABLE = False

try:
import torch_tensorrt
import torch_tensorrt # nopycln: import

logger.info("Torch TensorRT enabled")
except ImportError:
logger.warning("Torch TensorRT not enabled")
@@ -265,6 +268,7 @@ def _load_pickled_model(self, model_dir, model_file, model_pt_path):
model.load_state_dict(state_dict)
return model

@timed
def preprocess(self, data):
"""
Preprocess function to convert the request input to a tensor(Torchserve supported format).
@@ -279,6 +283,7 @@ def preprocess(self, data):

return torch.as_tensor(data, device=self.device)

@timed
def inference(self, data, *args, **kwargs):
"""
The Inference Function is used to make a prediction call on the given input request.
@@ -296,6 +301,7 @@ def inference(self, data, *args, **kwargs):
results = self.model(marshalled_data, *args, **kwargs)
return results

@timed
def postprocess(self, data):
"""
The post process function makes use of the output from the inference and converts into a