diff --git a/benchmarks/benchmark-ab.py b/benchmarks/benchmark-ab.py
index ebe48ea50a..a2a609c4e9 100644
--- a/benchmarks/benchmark-ab.py
+++ b/benchmarks/benchmark-ab.py
@@ -30,6 +30,7 @@
     "image": "",
     "docker_runtime": "",
     "backend_profiling": False,
+    "handler_profiling": False,
    "generate_graphs": False,
     "config_properties": "config.properties",
     "inference_model_url": "predictions/benchmark",
@@ -95,6 +96,12 @@ def json_provider(file_path, cmd_name):
     default=False,
     help="Enable backend profiling using CProfile. Default False",
 )
+@click.option(
+    "--handler_profiling",
+    "-hp",
+    default=False,
+    help="Enable handler profiling. Default False",
+)
 @click.option(
     "--generate_graphs",
     "-gg",
@@ -143,6 +150,7 @@ def benchmark(
     image,
     docker_runtime,
     backend_profiling,
+    handler_profiling,
     config_properties,
     inference_model_url,
     report_location,
@@ -163,6 +171,7 @@ def benchmark(
         "image": image,
         "docker_runtime": docker_runtime,
         "backend_profiling": backend_profiling,
+        "handler_profiling": handler_profiling,
         "config_properties": config_properties,
         "inference_model_url": inference_model_url,
         "report_location": report_location,
@@ -469,6 +478,17 @@ def generate_report(warm_up_lines):
     }
 
 
+def update_metrics():
+    if execution_params["handler_profiling"]:
+        opt_metrics = {
+            "handler_preprocess.txt": "ts_handler_preprocess",
+            "handler_inference.txt": "ts_handler_inference",
+            "handler_postprocess.txt": "ts_handler_postprocess",
+        }
+        metrics.update(opt_metrics)
+    return metrics
+
+
 def extract_metrics(warm_up_lines):
     with open(execution_params["metric_log"]) as f:
         lines = f.readlines()
@@ -476,6 +496,8 @@ def extract_metrics(warm_up_lines):
         click.secho(f"Dropping {warm_up_lines} warmup lines from log", fg="green")
         lines = lines[warm_up_lines:]
 
+    metrics = update_metrics()
+
     for k, v in metrics.items():
         all_lines = []
         pattern = re.compile(v)
diff --git a/benchmarks/utils/gen_model_config_json.py b/benchmarks/utils/gen_model_config_json.py
index 6b963e13f0..b9534934a1 100644
--- a/benchmarks/utils/gen_model_config_json.py
+++ b/benchmarks/utils/gen_model_config_json.py
@@ -2,11 +2,11 @@
 import copy
 import json
 import os
+
 import yaml
 
 
 def main():
-
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
@@ -22,6 +22,7 @@ def main():
     arguments = parser.parse_args()
     convert_yaml_to_json(arguments.input, arguments.output)
 
+
 MODEL_CONFIG_KEY = {
     "batch_size",
     "batch_delay",
@@ -30,12 +31,18 @@ def main():
     "concurrency",
     "workers",
     "input",
-    "processors"
+    "processors",
+    "handler_profiling",
 }
 
+
 def convert_yaml_to_json(yaml_file_path, output_dir):
-    print("convert_yaml_to_json yaml_file_path={}, output_dir={}".format(yaml_file_path, output_dir))
-    with open(yaml_file_path, 'r') as f:
+    print(
+        "convert_yaml_to_json yaml_file_path={}, output_dir={}".format(
+            yaml_file_path, output_dir
+        )
+    )
+    with open(yaml_file_path, "r") as f:
         yaml_dict = yaml.safe_load(f)
 
     for model, config in yaml_dict.items():
@@ -58,10 +65,9 @@ def convert_yaml_to_json(yaml_file_path, output_dir):
             batch_worker_list = []
             for batch_size in batch_size_list:
                 for workers in workers_list:
-                    batch_worker_list.append({
-                        "batch_size" : batch_size,
-                        "workers" : workers
-                    })
+                    batch_worker_list.append(
+                        {"batch_size": batch_size, "workers": workers}
+                    )
 
             benchmark_configs = []
             for batch_worker in batch_worker_list:
@@ -72,25 +78,34 @@ def convert_yaml_to_json(yaml_file_path, output_dir):
             for bConfig in benchmark_configs:
                 for i in range(len(processors)):
                     if type(processors[i]) is str:
-                        path = '{}/{}'.format(output_dir, processors[i])
+                        path = "{}/{}".format(output_dir, processors[i])
                         if not os.path.isdir(path):
                             continue
 
-                        benchmark_config_file = '{}/{}_w{}_b{}.json'\
-                            .format(path, model_name, bConfig["workers"], bConfig["batch_size"])
+                        benchmark_config_file = "{}/{}_w{}_b{}.json".format(
+                            path,
+                            model_name,
+                            bConfig["workers"],
+                            bConfig["batch_size"],
+                        )
                         with open(benchmark_config_file, "w") as outfile:
                             json.dump(bConfig, outfile, indent=4)
                     elif type(processors[i]) is dict:
-                        path = '{}/gpu'.format(output_dir)
+                        path = "{}/gpu".format(output_dir)
                         if not os.path.isdir(path):
                             continue
 
                         bConfig["gpus"] = processors[i]["gpus"]
-                        benchmark_config_file = '{}/{}_w{}_b{}.json'\
-                            .format(path, model_name, bConfig["workers"], bConfig["batch_size"])
+                        benchmark_config_file = "{}/{}_w{}_b{}.json".format(
+                            path,
+                            model_name,
+                            bConfig["workers"],
+                            bConfig["batch_size"],
+                        )
                         with open(benchmark_config_file, "w") as outfile:
                             json.dump(bConfig, outfile, indent=4)
                         del bConfig["gpus"]
 
+
 if __name__ == "__main__":
     main()
diff --git a/examples/benchmarking/resnet50/README.md b/examples/benchmarking/resnet50/README.md
new file mode 100644
index 0000000000..39721f6954
--- /dev/null
+++ b/examples/benchmarking/resnet50/README.md
@@ -0,0 +1,45 @@
+
+# Benchmark ResNet50 and profile the detailed split of PredictionTime
+
+This example shows how to run the benchmark ab tool on ResNet50 and identify the time spent on preprocess, inference, and postprocess.
+
+Change directory to the root of `serve`.
+For example, if `serve` is under `/home/ubuntu`, change directory to `/home/ubuntu/serve`.
+
+
+## Download the weights
+
+```
+wget https://download.pytorch.org/models/resnet50-11ad3fa6.pth
+```
+
+### Create model archive
+
+To enable profiling of the TorchServe handler, add the following config in model-config.yaml:
+```
+handler:
+  profile: true
+```
+
+```
+torch-model-archiver --model-name resnet-50 --version 1.0 --model-file ./examples/benchmarking/resnet50/model.py --serialized-file resnet50-11ad3fa6.pth --handler image_classifier --extra-files ./examples/image_classifier/index_to_name.json --config-file ./examples/benchmarking/resnet50/model-config.yaml
+
+mkdir model_store
+mv resnet-50.mar model_store/.
+```
+
+### Install dependencies for benchmark tool
+
+```
+sudo apt-get update -y
+sudo apt-get install -y apache2-utils
+pip install -r benchmarks/requirements-ab.txt
+```
+
+### Run ab tool for benchmarking
+
+```
+python benchmarks/auto_benchmark.py --input examples/benchmarking/resnet50/benchmark_profile.yaml --skip true
+```
+
+This generates the report under `/tmp/ts_benchmarking/report.md`.
diff --git a/examples/benchmarking/resnet50/benchmark_profile.yaml b/examples/benchmarking/resnet50/benchmark_profile.yaml
new file mode 100644
index 0000000000..eb2d57e204
--- /dev/null
+++ b/examples/benchmarking/resnet50/benchmark_profile.yaml
@@ -0,0 +1,16 @@
+# TorchServe version to be installed. It can be one of the following options:
+# - branch : "master"
+# - nightly: "2022.3.16"
+# - release: "0.5.3"
+# Nightly build will be installed if "ts_version" is not specified
+#ts_version:
+#  branch: &ts_version "master"
+
+# a list of model config yaml files defined in benchmarks/models_config
+# or a list of model config yaml files with full path
+models:
+  - "/home/ubuntu/serve/examples/benchmarking/resnet50/resnet50.yaml"
+
+# benchmark on "cpu" or "gpu".
+# "cpu" is set if "hardware" is not specified
+hardware: &hardware "gpu"
diff --git a/examples/benchmarking/resnet50/model-config.yaml b/examples/benchmarking/resnet50/model-config.yaml
new file mode 100644
index 0000000000..a8cbf248c4
--- /dev/null
+++ b/examples/benchmarking/resnet50/model-config.yaml
@@ -0,0 +1,2 @@
+handler:
+  profile: true
diff --git a/examples/benchmarking/resnet50/model.py b/examples/benchmarking/resnet50/model.py
new file mode 100644
index 0000000000..ac61782d3a
--- /dev/null
+++ b/examples/benchmarking/resnet50/model.py
@@ -0,0 +1,6 @@
+from torchvision.models.resnet import Bottleneck, ResNet
+
+
+class ImageClassifier(ResNet):
+    def __init__(self):
+        super(ImageClassifier, self).__init__(Bottleneck, [3, 4, 6, 3])
diff --git a/examples/benchmarking/resnet50/resnet50.yaml b/examples/benchmarking/resnet50/resnet50.yaml
new file mode 100644
index 0000000000..2f97e0a8ca
--- /dev/null
+++ b/examples/benchmarking/resnet50/resnet50.yaml
@@ -0,0 +1,24 @@
+---
+resnet50:
+  eager_mode:
+    benchmark_engine: "ab"
+    url: "file:///home/ubuntu/serve/model_store/resnet-50.mar"
+    workers:
+      - 4
+    batch_delay: 100
+    batch_size:
+      - 1
+      - 2
+      - 4
+      - 8
+      - 16
+      - 32
+      - 64
+    requests: 10000
+    concurrency: 100
+    input: "./examples/image_classifier/kitten.jpg"
+    handler_profiling: true
+    exec_env: "local"
+    processors:
+      - "cpu"
+      - "gpus": "all"
diff --git a/ts/handler_utils/timer.py b/ts/handler_utils/timer.py
new file mode 100644
index 0000000000..a747eea3c2
--- /dev/null
+++ b/ts/handler_utils/timer.py
@@ -0,0 +1,66 @@
+"""
+Decorator for timing handler methods.
+
+Use this decorator to compute the execution time for your preprocess, inference and
+postprocess methods.
+By default this feature is not enabled.
+
+To enable it, add the following section to your model-config.yaml file:
+
+handler:
+  profile: true
+
+An example of running benchmarks with profiling enabled is in
+https://github.com/pytorch/serve/tree/master/examples/benchmarking/resnet50
+
+"""
+
+import time
+
+import torch
+
+
+def timed(func):
+    def wrap_func(self, *args, **kwargs):
+        # Measure time if config specified in model_yaml_config
+        if (
+            "handler" in self.context.model_yaml_config
+            and "profile" in self.context.model_yaml_config["handler"]
+        ):
+            if self.context.model_yaml_config["handler"]["profile"]:
+                # Measure start time
+                if torch.cuda.is_available():
+                    start = torch.cuda.Event(enable_timing=True)
+                    end = torch.cuda.Event(enable_timing=True)
+                    start.record()
+                else:
+                    start = time.time()
+
+                result = func(self, *args, **kwargs)
+
+                # Measure end time
+                if torch.cuda.is_available():
+                    end.record()
+                    torch.cuda.synchronize()
+                else:
+                    end = time.time()
+
+                # Measure time taken to execute the function in milliseconds
+                if torch.cuda.is_available():
+                    duration = start.elapsed_time(end)
+                else:
+                    duration = (end - start) * 1000
+
+                # Add metrics for profiling
+                metrics = self.context.metrics
+                metrics.add_time("ts_handler_" + func.__name__, duration)
+            else:
+                # If the profile config specified in model_yaml_config is False
+                result = func(self, *args, **kwargs)
+        else:
+            # If no profile config is specified in model_yaml_config
+            result = func(self, *args, **kwargs)
+
+        return result
+
+    return wrap_func
diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py
index 2e3e716a6b..227a4ec56c 100644
--- a/ts/torch_handler/base_handler.py
+++ b/ts/torch_handler/base_handler.py
@@ -12,6 +12,8 @@
 import torch
 from pkg_resources import packaging
 
+from ts.handler_utils.timer import timed
+
 from ..utils.util import (
     check_valid_pt2_backend,
     list_classes_from_module,
@@ -77,7 +79,8 @@
 ONNX_AVAILABLE = False
 
 try:
-    import torch_tensorrt
+    import torch_tensorrt  # nopycln: import
+
     logger.info("Torch TensorRT enabled")
 except ImportError:
     logger.warning("Torch TensorRT not enabled")
@@ -265,6 +268,7 @@ def _load_pickled_model(self, model_dir, model_file, model_pt_path):
             model.load_state_dict(state_dict)
         return model
 
+    @timed
     def preprocess(self, data):
         """
         Preprocess function to convert the request input to a tensor(Torchserve supported format).
@@ -279,6 +283,7 @@ def preprocess(self, data):
 
         return torch.as_tensor(data, device=self.device)
 
+    @timed
     def inference(self, data, *args, **kwargs):
         """
         The Inference Function is used to make a prediction call on the given input request.
@@ -296,6 +301,7 @@ def inference(self, data, *args, **kwargs):
             results = self.model(marshalled_data, *args, **kwargs)
         return results
 
+    @timed
     def postprocess(self, data):
         """
         The post process function makes use of the output from the inference and converts into a
diff --git a/ts/torch_handler/image_classifier.py b/ts/torch_handler/image_classifier.py
index f43eac5e6f..ef194d3924 100644
--- a/ts/torch_handler/image_classifier.py
+++ b/ts/torch_handler/image_classifier.py
@@ -5,8 +5,10 @@
 import torch.nn.functional as F
 from torchvision import transforms
 
+from ts.handler_utils.timer import timed
+
+from ..utils.util import map_class_to_label
 from .vision_handler import VisionHandler
-from ..utils.util import map_class_to_label
 
 
 class ImageClassifier(VisionHandler):
@@ -18,13 +20,14 @@ class ImageClassifier(VisionHandler):
     topk = 5
     # These are the standard Imagenet dimensions
     # and statistics
-    image_processing = transforms.Compose([
-        transforms.Resize(256),
-        transforms.CenterCrop(224),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                             std=[0.229, 0.224, 0.225])
-    ])
+    image_processing = transforms.Compose(
+        [
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ]
+    )
 
     def set_max_result_classes(self, topk):
         self.topk = topk
@@ -32,6 +35,7 @@ def set_max_result_classes(self, topk):
     def get_max_result_classes(self):
         return self.topk
 
+    @timed
     def postprocess(self, data):
         ps = F.softmax(data, dim=1)
         probs, classes = torch.topk(ps, self.topk, dim=1)
diff --git a/ts/torch_handler/unit_tests/test_utils/mock_context.py b/ts/torch_handler/unit_tests/test_utils/mock_context.py
index 4ee1aeb4ec..287074c6eb 100644
--- a/ts/torch_handler/unit_tests/test_utils/mock_context.py
+++ b/ts/torch_handler/unit_tests/test_utils/mock_context.py
@@ -21,6 +21,7 @@ def __init__(
         model_file="model.py",
         gpu_id="0",
         model_name="mnist",
+        model_yaml_config_file=None,
     ):
         self.manifest = {"model": {}}
         if model_pt_file:
@@ -36,6 +37,12 @@ def __init__(
         self.explain = False
         self.metrics = MetricsStore(uuid.uuid4(), model_name)
 
+        self.model_yaml_config = {}
+
+        if model_yaml_config_file:
+            self.model_yaml_config = get_yaml_config(
+                os.path.join(model_dir, model_yaml_config_file)
+            )
 
     def get_request_header(self, idx, exp):
         if idx and exp:
diff --git a/ts/torch_handler/vision_handler.py b/ts/torch_handler/vision_handler.py
index 0ad08af327..9d7778b41d 100644
--- a/ts/torch_handler/vision_handler.py
+++ b/ts/torch_handler/vision_handler.py
@@ -11,6 +11,8 @@
 from captum.attr import IntegratedGradients
 from PIL import Image
 
+from ts.handler_utils.timer import timed
+
 from .base_handler import BaseHandler
 
 
@@ -27,6 +29,7 @@ def initialize(self, context):
         if not properties.get("limit_max_image_pixels"):
             Image.MAX_IMAGE_PIXELS = None
 
+    @timed
     def preprocess(self, data):
         """The preprocess function of MNIST program converts the input data
         to a float tensor
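
For reference, the sketch below shows how a custom handler could opt into the timers introduced by this patch. It is illustrative only and not part of the change: the handler name and method bodies are made up, and it assumes the model archive was built with a model-config.yaml containing `handler: profile: true`, since `@timed` only records metrics when that flag is set.

```
# Illustrative sketch only -- not part of this patch.
# Assumes model-config.yaml contains:
#   handler:
#     profile: true
from ts.handler_utils.timer import timed
from ts.torch_handler.base_handler import BaseHandler


class MyProfiledHandler(BaseHandler):
    # Overriding a method shadows the decorated version in BaseHandler,
    # so decorate the override again to keep emitting the timing metric.
    @timed
    def preprocess(self, data):
        # Recorded as ts_handler_preprocess (in milliseconds) via self.context.metrics
        return super().preprocess(data)

    @timed
    def postprocess(self, data):
        # Recorded as ts_handler_postprocess
        return data.argmax(dim=1).tolist()
```

With `handler_profiling: true` in the benchmark model YAML, `benchmark-ab.py` then picks these metrics up through `update_metrics()` and writes them to `handler_preprocess.txt`, `handler_inference.txt`, and `handler_postprocess.txt`.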