diff --git a/.gitignore b/.gitignore index 2ea92832..5c5311bb 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,8 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +dataset/ +initial_model/ # PyInstaller # Usually these files are written by a python script from a template diff --git a/core/testcasecontroller/algorithm/paradigm/base.py b/core/testcasecontroller/algorithm/paradigm/base.py index cf36cd4e..3fe39267 100644 --- a/core/testcasecontroller/algorithm/paradigm/base.py +++ b/core/testcasecontroller/algorithm/paradigm/base.py @@ -124,6 +124,6 @@ def build_paradigm_job(self, paradigm_type): ) # pylint: disable=E1101 if paradigm_type == ParadigmType.MULTIEDGE_INFERENCE.value: - return self.modules_funcs.get(ModuleType.BASEMODEL.value)() + return self.module_instances.get(ModuleType.BASEMODEL.value) return None diff --git a/core/testcasecontroller/algorithm/paradigm/multiedge_inference/multiedge_inference.py b/core/testcasecontroller/algorithm/paradigm/multiedge_inference/multiedge_inference.py index cf8ef521..4085eafd 100644 --- a/core/testcasecontroller/algorithm/paradigm/multiedge_inference/multiedge_inference.py +++ b/core/testcasecontroller/algorithm/paradigm/multiedge_inference/multiedge_inference.py @@ -16,6 +16,10 @@ import os +# pylint: disable=E0401 +import onnx + +from core.common.log import LOGGER from core.common.constant import ParadigmType from core.testcasecontroller.algorithm.paradigm.base import ParadigmBase @@ -63,8 +67,15 @@ def run(self): """ job = self.build_paradigm_job(ParadigmType.MULTIEDGE_INFERENCE.value) - - inference_result = self._inference(job, self.initial_model) + if not job.__dict__.get('model_parallel'): + inference_result = self._inference(job, self.initial_model) + else: + if 'partition' in dir(job): + models_dir, map_info = job.partition(self.initial_model) + else: + models_dir, map_info = self._partition(job.__dict__.get('partition_point_list'), + self.initial_model, os.path.dirname(self.initial_model)) + inference_result = self._inference_mp(job, models_dir, map_info) return inference_result, self.system_metric_info @@ -77,3 +88,26 @@ def _inference(self, job, trained_model): job.load(trained_model) infer_res = job.predict(inference_dataset.x, train_dataset=train_dataset) return infer_res + + def _inference_mp(self, job, models_dir, map_info): + inference_dataset = self.dataset.load_data(self.dataset.test_url, "inference") + inference_output_dir = os.path.join(self.workspace, "output/inference/") + os.environ["RESULT_SAVED_URL"] = inference_output_dir + job.load(models_dir, map_info) + infer_res = job.predict(inference_dataset.x) + return infer_res + + # pylint: disable=W0718, C0103 + def _partition(self, partition_point_list, initial_model_path, sub_model_dir): + map_info = dict({}) + for idx, point in enumerate(partition_point_list): + input_names = point['input_names'] + output_names = point['output_names'] + sub_model_path = sub_model_dir + '/' + 'sub_model_' + str(idx+1) + '.onnx' + try: + onnx.utils.extract_model(initial_model_path, + sub_model_path, input_names, output_names) + except Exception as e: + LOGGER.info(str(e)) + map_info[sub_model_path.split('/')[-1]] = point['device_name'] + return sub_model_dir, map_info diff --git a/examples/imagenet/multiedge_inference_bench/README.md b/examples/imagenet/multiedge_inference_bench/README.md new file mode 100644 index 00000000..ae0fd01e --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/README.md @@ -0,0 +1,104 @@ +# Benchmarking of Image Clasification for High Mobility Scenarios + +In high-mobility scenarios such as highways and high-speed railways, the connection between personal terminal devices and cloud servers is significantly weakened. However, in recent years, artificial intelligence technology has permeated every aspect of our lives, and we also need to use artificial intelligence technologies with high computational and storage demands and sensitive to latency in high-mobility scenarios. For example, even when driving through a tunnel with a weak network environment, we may still need to use AI capabilities such as image classification. Therefore, in the event that edge devices lose connection with the cloud, offloading AI computing tasks to adjacent edge devices and achieving computational aggregation based on the mutual collaboration between devices, to complete computing tasks that traditionally require cloud-edge collaboration, has become an issue worth addressing. This benchmarking job aims to simulate such scenario: using multiple heterogeneous computing units on the edge (such as personal mobile phones, tablets, bracelets, laptops, and other computing devices) for collaborative ViT inference, enabling image classification to be completed with lower latency using devices that are closer to the edge, thereby enhancing the user experience.After running benchmarking jobs, a report will be generated. + +With Ianvs installed and related environment prepared, users is then able to run the benchmarking process using the following steps. If you haven't installed Ianvs, please refer to [how-to-install-ianvs](../../../docs/guides/how-to-install-ianvs.md). + +## Prerequisites + +To setup the environment, run the following commands: +```shell +cd +pip install ./examples/resources/third_party/* +pip install -r requirements.txt +cd ./examples/imagenet/multiedge_inference_bench/ +pip install -r requirements.txt +cd +mkdir dataset initial_model +``` +Please refer to [this link](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html) and ensure that the versions of CUDA and cuDNN are compatible with the version of ONNX Runtime. + +Note that it is advisable to avoid using lower versions of the ONNX library, as they are very time-consuming when performing computational graph partitioning. The version of onnx-cuda-cudnn we used in our tests is as follows: +![onnx_version](images/onnx_version.png) + +## Step 1. Prepare Dataset +Download [ImageNet 2012 dataset](https://image-net.org/download.php) and put it under /dataset in the following structure: + +``` +dataset + |------ILSVRC2012_devkit_t12.tar.gz + |------ILSVRC2012_img_val.tar +``` +Then, you need to process the dataset and generate the _train.txt_ and _val.txt_: + +```shell +cd +python ./examples/imagenet/multiedge_inference_bench/testalgorithms/manual/dataset.py +``` + +## Step 2. Prepare Model + +Next, download pretrained model via [[huggingface]](https://huggingface.co/optimum/vit-base-patch16-224/tree/main), rename it to vit-base-patch16-224.onnx and put it under /initial_model/ + +## Step 3. Run Benchmarking Job - Manual +We are now ready to run the ianvs for benchmarking image classification for high mobility scenarios on the ImageNet dataset. + +```python +ianvs -f ./examples/imagenet/multiedge_inference_bench/classification_job_manual.yaml +``` + +The benchmarking process takes a few minutes and varies depending on devices. + +## Step 4. Check the Result + +Finally, the user can check the result of benchmarking on the console and also in the output path (/ianvs/multiedge_inference_bench/workspace) defined in the benchmarking config file (classification_job.yaml). + +The final output might look like this: +![result](images/result.png) + +You can view the graphical representation of relevant metrics in /ianvs/multiedge_inference_bench/workspace/images/, such as the following: +![plot](images/plot.png) + +To compare the running conditions of the model with and without parallelism in the multiedge inference scenario, you can modify the value of --devices_info in base_model.py to devices_one.yaml to view the relevant metrics when the model runs on a single device. + +## Step 5. Run Benchmarking Job - Automatic +We offer a profiling-based and memory matching partition algorithm to compare with the method of manually specifying partitioning points. This method prioritizes the memory matching between the computational subgraph and the device. First, we profile the initial model on the CPU to collect memory usage, the number of parameters, computational cost, and the input and output data shapes for each layer, as well as the total number of layers and their names in the entire model. To facilitate subsequent integration, we have implemented profiling for three types of transformer models: vit, bert, and deit. Secondly, based on the results of the profiling and the device information provided in devices.yaml, we can identify the partitioning point that matches the device memory through a single traversal and perform model partitioning. + +You should first run the following command to generate a profiling result: +```shell +cd +python ./examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/profiler.py +``` + +Then you will find a profiler_results.yml file in the /examples/imagenet/multiedge_inference_bench/testalgorithms/automatic directory, just like this: +![profiler_result](images/profiler_results.png) + +Then you can run the following command to perform benchmarking: +```shell +ianvs -f ./examples/imagenet/multiedge_inference_bench/classification_job_auto.yaml +``` + +After running, you will see the profit from the automatic method compared with the manual method. +![result](images/auto_result.png) + +## Explanation for devices.yaml + +This file defines the specific information of edge-side multi-devices and the model's partition points. The devices section includes the computing resource type, memory, frequency, and bandwidth for each device. The partition_points section defines the input and output names of each computational subgraph and their mapping relationships with devices. This benchmarking job achieves the partitioning of the computational graph and model parallelism by manually defining partition points. You can implement custom partitioning algorithms based on the rich device information in devices.yaml. + +## Custom Partitioning Algorithms + +How to partition an ONNX model based on device information is an interesting question. You can solve this issue using greedy algorithms, dynamic programming algorithms, or other innovative graph algorithms to achieve optimal resource utilization and the lowest inference latency. + +More partitioning algorithms will be added in the future and you can customize their own partition methods in basemodel.py, they only need to comply with the input and output specifications defined by the interface as follows: + +``` +def partiton(self, initial_model): + ## 1. parsing + ## 2. modeling + ## 3. partition + return models_dir, map_info +``` + +Hope you have a perfect journey in solving this problem! + + diff --git a/examples/imagenet/multiedge_inference_bench/classification_job_automatic.yaml b/examples/imagenet/multiedge_inference_bench/classification_job_automatic.yaml new file mode 100644 index 00000000..4e236e71 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/classification_job_automatic.yaml @@ -0,0 +1,72 @@ +benchmarkingjob: + # job name of benchmarking; string type; + name: "classification_job" + # the url address of job workspace that will reserve the output of tests; string type; + workspace: "./multiedge_inference_bench/workspace" + + # the url address of test environment configuration file; string type; + # the file format supports yaml/yml; + testenv: "./examples/imagenet/multiedge_inference_bench/testenv/testenv.yaml" + + # the configuration of test object + test_object: + # test type; string type; + # currently the option of value is "algorithms",the others will be added in succession. + type: "algorithms" + # test algorithm configuration files; list type; + algorithms: + # algorithm name; string type; + - name: "classification" + # # the url address of test algorithm configuration file; string type; + # # the file format supports yaml/yml; + url: "./examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/classification_algorithm.yaml" + + # the configuration of ranking leaderboard + rank: + # rank leaderboard with metric of test case's evaluation and order ; list type; + # the sorting priority is based on the sequence of metrics in the list from front to back; + sort_by: [ { "mota": "descend" } ] + + # visualization configuration + visualization: + # mode of visualization in the leaderboard; string type; + # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen. + # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown. + mode: "selected_only" + # method of visualization for selected dataitems; string type; + # currently the options of value are as follows: + # 1> "print_table": print selected dataitems; + method: "print_table" + + # selected dataitem configuration + # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics", + # so that the selected columns will be shown. + selected_dataitem: + # currently the options of value are as follows: + # 1> "all": select all paradigms in the leaderboard; + # 2> paradigms in the leaderboard, e.g., "singletasklearning" + paradigms: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all modules in the leaderboard; + # 2> modules in the leaderboard, e.g., "basemodel" + modules: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all hyperparameters in the leaderboard; + # 2> hyperparameters in the leaderboard, e.g., "momentum" + hyperparameters: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all metrics in the leaderboard; + # 2> metrics in the leaderboard, e.g., "f1_score" + metrics: [ "all" ] + + # model of save selected and all dataitems in workspace; string type; + # currently the options of value are as follows: + # 1> "selected_and_all": save selected and all dataitems; + # 2> "selected_only": save selected dataitems; + save_mode: "selected_and_all" + + + + + + diff --git a/examples/imagenet/multiedge_inference_bench/classification_job_manual.yaml b/examples/imagenet/multiedge_inference_bench/classification_job_manual.yaml new file mode 100644 index 00000000..1b251c38 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/classification_job_manual.yaml @@ -0,0 +1,72 @@ +benchmarkingjob: + # job name of benchmarking; string type; + name: "classification_job" + # the url address of job workspace that will reserve the output of tests; string type; + workspace: "./multiedge_inference_bench/workspace" + + # the url address of test environment configuration file; string type; + # the file format supports yaml/yml; + testenv: "./examples/imagenet/multiedge_inference_bench/testenv/testenv.yaml" + + # the configuration of test object + test_object: + # test type; string type; + # currently the option of value is "algorithms",the others will be added in succession. + type: "algorithms" + # test algorithm configuration files; list type; + algorithms: + # algorithm name; string type; + - name: "classification" + # # the url address of test algorithm configuration file; string type; + # # the file format supports yaml/yml; + url: "./examples/imagenet/multiedge_inference_bench/testalgorithms/manual/classification_algorithm.yaml" + + # the configuration of ranking leaderboard + rank: + # rank leaderboard with metric of test case's evaluation and order ; list type; + # the sorting priority is based on the sequence of metrics in the list from front to back; + sort_by: [ { "mota": "descend" } ] + + # visualization configuration + visualization: + # mode of visualization in the leaderboard; string type; + # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen. + # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown. + mode: "selected_only" + # method of visualization for selected dataitems; string type; + # currently the options of value are as follows: + # 1> "print_table": print selected dataitems; + method: "print_table" + + # selected dataitem configuration + # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics", + # so that the selected columns will be shown. + selected_dataitem: + # currently the options of value are as follows: + # 1> "all": select all paradigms in the leaderboard; + # 2> paradigms in the leaderboard, e.g., "singletasklearning" + paradigms: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all modules in the leaderboard; + # 2> modules in the leaderboard, e.g., "basemodel" + modules: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all hyperparameters in the leaderboard; + # 2> hyperparameters in the leaderboard, e.g., "momentum" + hyperparameters: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all metrics in the leaderboard; + # 2> metrics in the leaderboard, e.g., "f1_score" + metrics: [ "all" ] + + # model of save selected and all dataitems in workspace; string type; + # currently the options of value are as follows: + # 1> "selected_and_all": save selected and all dataitems; + # 2> "selected_only": save selected dataitems; + save_mode: "selected_and_all" + + + + + + diff --git a/examples/imagenet/multiedge_inference_bench/images/auto_result.png b/examples/imagenet/multiedge_inference_bench/images/auto_result.png new file mode 100644 index 00000000..e75443db Binary files /dev/null and b/examples/imagenet/multiedge_inference_bench/images/auto_result.png differ diff --git a/examples/imagenet/multiedge_inference_bench/images/onnx_version.png b/examples/imagenet/multiedge_inference_bench/images/onnx_version.png new file mode 100644 index 00000000..076e4a16 Binary files /dev/null and b/examples/imagenet/multiedge_inference_bench/images/onnx_version.png differ diff --git a/examples/imagenet/multiedge_inference_bench/images/plot.png b/examples/imagenet/multiedge_inference_bench/images/plot.png new file mode 100644 index 00000000..c3f93796 Binary files /dev/null and b/examples/imagenet/multiedge_inference_bench/images/plot.png differ diff --git a/examples/imagenet/multiedge_inference_bench/images/profiler_results.png b/examples/imagenet/multiedge_inference_bench/images/profiler_results.png new file mode 100644 index 00000000..9b46bb37 Binary files /dev/null and b/examples/imagenet/multiedge_inference_bench/images/profiler_results.png differ diff --git a/examples/imagenet/multiedge_inference_bench/images/result.png b/examples/imagenet/multiedge_inference_bench/images/result.png new file mode 100644 index 00000000..9a4ef272 Binary files /dev/null and b/examples/imagenet/multiedge_inference_bench/images/result.png differ diff --git a/examples/imagenet/multiedge_inference_bench/requirements.txt b/examples/imagenet/multiedge_inference_bench/requirements.txt new file mode 100644 index 00000000..18b77692 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/requirements.txt @@ -0,0 +1,113 @@ +absl-py==2.1.0 +addict==2.4.0 +asgiref==3.7.2 +cachetools==5.3.3 +certifi==2024.2.2 +charset-normalizer==3.3.2 +click==8.1.7 +coloredlogs==15.0.1 +colorlog==4.7.2 +cycler==0.11.0 +Cython==3.0.10 +cython-bbox==0.1.5 +fastapi==0.68.2 +filelock==3.12.2 +filterpy==1.4.5 +flatbuffers==24.3.25 +fonttools==4.38.0 +fpdf==1.7.2 +fsspec==2023.1.0 +google-auth==2.29.0 +google-auth-oauthlib==0.4.6 +grpcio==1.62.2 +h11==0.14.0 +h5py==3.8.0 +huggingface-hub==0.16.4 +humanfriendly==10.0 +ianvs==0.1.0 +idna==3.7 +imageio==2.31.2 +importlib-metadata==6.7.0 +joblib==1.2.0 +kiwisolver==1.4.5 +lap==0.4.0 +loguru==0.7.2 +Markdown==3.4.4 +markdown-it-py==2.2.0 +MarkupSafe==2.1.5 +matplotlib==3.5.3 +mdurl==0.1.2 +minio==7.0.4 +mmcv==1.5.0 +mmengine==0.10.4 +motmetrics==1.4.0 +mpmath==1.3.0 +networkx==2.6.3 +ninja==1.11.1.1 +numpy==1.21.6 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +oauthlib==3.2.2 +onnx==1.14.1 +onnx-simplifier==0.3.5 +onnxoptimizer==0.3.13 +onnxruntime==1.14.1 +onnxruntime-gpu==1.14.1 +opencv-python==4.9.0.80 +packaging==24.0 +pandas==1.3.5 +Pillow==9.5.0 +platformdirs==4.0.0 +prettytable==2.5.0 +protobuf==3.20.3 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pycocotools==2.0.7 +pydantic==1.10.15 +Pygments==2.17.2 +pynvml==11.5.3 +pyparsing==3.1.2 +python-dateutil==2.9.0.post0 +pytz==2024.1 +PyWavelets==1.3.0 +PyYAML==6.0.1 +regex==2024.4.16 +requests==2.31.0 +requests-oauthlib==2.0.0 +rich==13.7.1 +rsa==4.9 +safetensors==0.4.5 +scikit-image==0.19.3 +scikit-learn==1.0.2 +scipy==1.7.3 +seaborn==0.12.2 +six==1.15.0 +starlette==0.14.2 +sympy==1.10.1 +tabulate==0.9.0 +tenacity==8.0.1 +tensorboard==2.11.2 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +termcolor==2.3.0 +thop==0.1.1.post2209072238 +threadpoolctl==3.1.0 +tifffile==2021.11.2 +tokenizers==0.13.3 +tomli==2.0.1 +torch==1.13.1 +torchvision==0.14.1 +tqdm==4.66.4 +transformers==4.30.2 +typing_extensions==4.7.1 +urllib3==2.0.7 +uvicorn==0.14.0 +wcwidth==0.2.13 +websockets==9.1 +Werkzeug==2.2.3 +xmltodict==0.13.0 +yapf==0.40.2 +-e git+https://github.com/ifzhang/ByteTrack.git@d1bf0191adff59bc8fcfeaa0b33d3d1642552a99#egg=yolox +zipp==3.15.0 diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/basemodel.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/basemodel.py new file mode 100644 index 00000000..a35e5483 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/basemodel.py @@ -0,0 +1,205 @@ +# Modified Copyright 2022 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import glob +import os +from collections import OrderedDict +from pathlib import Path +from collections import defaultdict +import time + +from sedna.common.class_factory import ClassType, ClassFactory +from dataset import load_dataset +import model_cfg + +import yaml +import onnxruntime as ort +from torch.utils.data import DataLoader +import torch +import numpy as np +from tqdm import tqdm +import pynvml + + +__all__ = ["BaseModel"] + +# set backend +os.environ["BACKEND_TYPE"] = "ONNX" + + +def make_parser(): + parser = argparse.ArgumentParser("ViT Eval") + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument("--devices_info", default="./devices.yaml", type=str, help="devices conf") + parser.add_argument("--profiler_info", default="./profiler_results.yml", type=str, help="profiler results") + parser.add_argument("--model_parallel", default=True, action="store_true") + parser.add_argument("--split", default="val", type=str, help="split of dataset") + parser.add_argument("--indices", default=None, type=str, help="indices of dataset") + parser.add_argument("--shuffle", default=False, action="store_true", help="shuffle data") + parser.add_argument("--model_name", default="google/vit-base-patch16-224", type=str, help="model name") + parser.add_argument("--dataset_name", default="ImageNet", type=str, help="dataset name") + parser.add_argument("--data_size", default=1000, type=int, help="data size to inference") + # remove conflict with ianvs + parser.add_argument("-f") + return parser + + +@ClassFactory.register(ClassType.GENERAL, alias="Classification") +class BaseModel: + + def __init__(self, **kwargs) -> None: + self.args = make_parser().parse_args() + self.model_parallel = self.args.model_parallel + self.models = [] + self.devices_info_url = str(Path(Path(__file__).parent.resolve(), self.args.devices_info)) + self.device_info = self._parse_yaml(self.devices_info_url) + self.profiler_info_url = str(Path(Path(__file__).parent.resolve(), self.args.profiler_info)) + self.profiler_info = self._parse_yaml(self.profiler_info_url) + self.partition_point_list = [] + return + + ## auto partition by memory usage + def partition(self, initial_model): + map_info = {} + def _partition_model(pre, cur, flag): + print("========= Sub Model {} Partition =========".format(flag)) + model = model_cfg.module_shard_factory(self.args.model_name, initial_model, pre+1, cur+1, 1) + dummy_input = torch.randn(1, *self.profiler_info.get('profile_data')[pre].get("shape_in")[0]) + torch.onnx.export(model, + dummy_input, + str(Path(Path(initial_model).parent.resolve())) + "/sub_model_" + str(flag) + ".onnx", + export_params=True, + opset_version=16, + do_constant_folding=True, + input_names=['input_' + str(pre+1)], + output_names=['output_' + str(cur+1)]) + self.partition_point_list.append({ + 'input_names': ['input_' + str(pre+1)], + 'output_names': ['output_' + str(cur+1)] + }) + map_info["sub_model_" + str(flag) + ".onnx"] = self.device_info.get('devices')[flag-1].get("name") + + layer_list = [(layer.get("memory"), len(layer.get("shape_out"))) for layer in self.profiler_info.get('profile_data')] + total_model_memory = sum([layer[0] for layer in layer_list]) + devices_memory = [int(device.get('memory')) for device in self.device_info.get('devices')] + total_devices_memory = sum(devices_memory) + devices_memory = [per_mem * total_model_memory / total_devices_memory for per_mem in devices_memory] + + flag = 0 + sum_ = 0 + pre = 0 + for cur, layer in enumerate(layer_list): + if flag == len(devices_memory)-1: + cur = len(layer_list) + _partition_model(pre, cur-1, flag+1) + break + elif layer[1] == 1 and sum_ >= devices_memory[flag]: + sum_ = 0 + flag += 1 + _partition_model(pre, cur, flag) + pre = cur + 1 + else: + sum_ += layer[0] + continue + return str(Path(Path(initial_model).parent.resolve())), map_info + + + def load(self, models_dir=None, map_info=None) -> None: + cnt = 0 + for model_name, device in map_info.items(): + model = models_dir + '/' + model_name + if not os.path.exists(model): + raise ValueError("=> No modle found at '{}'".format(model)) + if device == 'cpu': + session = ort.InferenceSession(model, providers=['CPUExecutionProvider']) + elif 'gpu' in device: + device_id = int(device.split('-')[-1]) + session = ort.InferenceSession(model, providers=[('CUDAExecutionProvider', {'device_id': device_id})]) + else: + raise ValueError("Error device info: '{}'".format(device)) + self.models.append({ + 'session': session, + 'name': model_name, + 'device': device, + 'input_names': self.partition_point_list[cnt]['input_names'], + 'output_names': self.partition_point_list[cnt]['output_names'], + }) + cnt += 1 + print("=> Loaded onnx model: '{}'".format(model)) + return + + def predict(self, data, input_shape=None, **kwargs): + pynvml.nvmlInit() + root = str(Path(data[0]).parents[2]) + dataset_cfg = { + 'name': self.args.dataset_name, + 'root': root, + 'split': self.args.split, + 'indices': self.args.indices, + 'shuffle': self.args.shuffle + } + data_loader, ids = self._get_eval_loader(dataset_cfg) + data_loader = tqdm(data_loader, desc='Evaluating', unit='batch') + pred = [] + inference_time_per_device = defaultdict(int) + power_usage_per_device = defaultdict(list) + mem_usage_per_device = defaultdict(list) + cnt = 0 + for data, id in zip(data_loader, ids): + outputs = data[0].numpy() + for model in self.models: + start_time = time.time() + outputs = model['session'].run(None, {model['input_names'][0]: outputs})[0] + end_time = time.time() + device = model.get('device') + inference_time_per_device[device] += end_time - start_time + if 'gpu' in device and cnt % 100 == 0: + handle = pynvml.nvmlDeviceGetHandleByIndex(int(device.split('-')[-1])) + power_usage = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 + memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle).used / (1024**2) + power_usage_per_device[device] += [power_usage] + mem_usage_per_device[device] += [memory_info] + max_ids = np.argmax(outputs) + pred.append((max_ids, id)) + cnt += 1 + data_loader.close() + result = dict({}) + result["pred"] = pred + result["inference_time_per_device"] = inference_time_per_device + result["power_usage_per_device"] = power_usage_per_device + result["mem_usage_per_device"] = mem_usage_per_device + return result + + + def _get_eval_loader(self, dataset_cfg): + model_name = self.args.model_name + data_size = self.args.data_size + dataset, _, ids = load_dataset(dataset_cfg, model_name, data_size) + data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) + return data_loader, ids + + def _parse_yaml(self, url): + """Convert yaml file to the dict.""" + if url.endswith('.yaml') or url.endswith('.yml'): + with open(url, "rb") as file: + info_dict = yaml.load(file, Loader=yaml.SafeLoader) + return info_dict + else: + raise RuntimeError('config file must be the yaml format') \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/classification_algorithm.yaml b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/classification_algorithm.yaml new file mode 100644 index 00000000..cea77f57 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/classification_algorithm.yaml @@ -0,0 +1,27 @@ +algorithm: + # paradigm name; string type; + # currently the options of value are as follows: + # 1> "singletasklearning" + # 2> "incrementallearning" + paradigm_type: "multiedgeinference" + # the url address of initial model; string type; optional; + initial_model_url: "./initial_model/ViT-B_16-224.npz" + + # algorithm module configuration in the paradigm; list type; + modules: + # kind of algorithm module; string type; + # currently the options of value are as follows: + # 1> "basemodel" + - type: "basemodel" + # name of python module; string type; + # example: basemodel.py has BaseModel module that the alias is "FPN" for this benchmarking; + name: "Classification" + # the url address of python module; string type; + url: "./examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/basemodel.py" + + # hyperparameters configuration for the python module; list type; + hyperparameters: + # name of the hyperparameter; string type; + - batch_size: + values: + - 1 diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/dataset.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/dataset.py new file mode 100644 index 00000000..9b4ee16c --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/dataset.py @@ -0,0 +1,71 @@ +import logging +import random +from typing import Callable, Optional, Sequence +import os + +from torch.utils.data import DataLoader, Dataset, Subset +from transformers import ViTFeatureExtractor +from torchvision.datasets import ImageNet + + +def load_dataset_imagenet(feature_extractor: Callable, root: str, split: str='train') -> Dataset: + """Get the ImageNet dataset.""" + + def transform(img): + pixels = feature_extractor(images=img.convert('RGB'), return_tensors='pt')['pixel_values'] + return pixels[0] + return ImageNet(root, split=split, transform=transform) + +def load_dataset_subset(dataset: Dataset, indices: Optional[Sequence[int]]=None, + max_size: Optional[int]=None, shuffle: bool=False) -> Dataset: + """Get a Dataset subset.""" + if indices is None: + indices = list(range(len(dataset))) + if shuffle: + random.shuffle(indices) + if max_size is not None: + indices = indices[:max_size] + image_paths = [] + for index in indices: + image_paths.append(dataset.imgs[index][0]) + return Subset(dataset, indices), image_paths, indices + +def load_dataset(dataset_cfg: dict, model_name: str, batch_size: int) -> Dataset: + """Load inputs based on model.""" + def _get_feature_extractor(): + feature_extractor = ViTFeatureExtractor.from_pretrained(model_name) + return feature_extractor + dataset_name = dataset_cfg['name'] + dataset_root = dataset_cfg['root'] + dataset_split = dataset_cfg['split'] + indices = dataset_cfg['indices'] + dataset_shuffle = dataset_cfg['shuffle'] + if dataset_name == 'ImageNet': + if dataset_root is None: + dataset_root = 'ImageNet' + logging.info("Dataset root not set, assuming: %s", dataset_root) + feature_extractor = _get_feature_extractor() + dataset = load_dataset_imagenet(feature_extractor, dataset_root, split=dataset_split) + dataset, paths, ids = load_dataset_subset(dataset, indices=indices, max_size=batch_size, + shuffle=dataset_shuffle) + return dataset, paths, ids + +if __name__ == '__main__': + dataset_cfg = { + 'name': "ImageNet", + 'root': './dataset', + 'split': 'val', + 'indices': None, + 'shuffle': False, + } + model_name = "google/vit-base-patch16-224" + ## Total images to be inferenced. + data_size = 1000 + dataset, paths, _ = load_dataset(dataset_cfg, model_name, data_size) + data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) + with open('./dataset/train.txt', 'w') as f: + for i, (image, label) in enumerate(data_loader): + original_path = paths[i].replace('/dataset', '') + f.write(f"{original_path} {label.item()}\n") + f.close() + os.popen('cp ./dataset/train.txt ./dataset/test.txt') \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/devices.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/devices.py new file mode 100644 index 00000000..6093e150 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/devices.py @@ -0,0 +1,24 @@ +"""Common device configuration.""" +from typing import Tuple, Union +import torch + +# The torch.device to use for computation +DEVICE = None + +def forward_pre_hook_to_device(_module, inputs) \ + -> Union[Tuple[torch.tensor], Tuple[Tuple[torch.Tensor]]]: + """Move tensors to the compute device (e.g., GPU), if needed.""" + assert isinstance(inputs, tuple) + assert len(inputs) == 1 + if isinstance(inputs[0], torch.Tensor): + inputs = (inputs,) + tensors_dev = tuple(t.to(device=DEVICE) for t in inputs[0]) + return tensors_dev if len(tensors_dev) == 1 else (tensors_dev,) + +def forward_hook_to_cpu(_module, _inputs, outputs) -> Union[torch.tensor, Tuple[torch.Tensor]]: + """Move tensors to the CPU, if needed.""" + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + assert isinstance(outputs, tuple) + tensors_cpu = tuple(t.cpu() for t in outputs) + return tensors_cpu[0] if len(tensors_cpu) == 1 else tensors_cpu diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/devices.yaml b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/devices.yaml new file mode 100644 index 00000000..1317c4ac --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/devices.yaml @@ -0,0 +1,16 @@ +devices: + - name: "gpu-0" + type: "gpu" + memory: "1024" + freq: "2.6" + bandwith: "100" + - name: "gpu-1" + type: "gpu" + memory: "1024" + freq: "2.6" + bandwith: "80" + - name: "gpu-2" + type: "gpu" + memory: "1024" + freq: "2.6" + bandwith: "90" \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/model_cfg.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/model_cfg.py new file mode 100644 index 00000000..df41a426 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/model_cfg.py @@ -0,0 +1,93 @@ +"""Model configurations and default parameters.""" +import logging +from typing import Any, Callable, List, Optional, Tuple +from transformers import AutoConfig +from models import ModuleShard, ModuleShardConfig +from models.transformers import bert, deit, vit +import devices + +_logger = logging.getLogger(__name__) + +_MODEL_CONFIGS = {} + +def _model_cfg_add(name, layers, weights_file, shard_module): + _MODEL_CONFIGS[name] = { + 'name': name, + 'layers': layers, + 'weights_file': weights_file, + 'shard_module': shard_module, + } + +# Transformer blocks can be split 4 ways, e.g., where ViT-Base has 12 layers, we specify 12*4=48 +_model_cfg_add('google/vit-base-patch16-224', 48, './initial_model/ViT-B_16-224.npz', + vit.ViTShardForImageClassification) +_model_cfg_add('google/vit-large-patch16-224', 96, 'ViT-L_16-224.npz', + vit.ViTShardForImageClassification) +# NOTE: This ViT-Huge model doesn't include classification, so the config must be extended +_model_cfg_add('google/vit-huge-patch14-224-in21k', 128, 'ViT-H_14.npz', + vit.ViTShardForImageClassification) +# NOTE: BertModelShard alone doesn't do classification +_model_cfg_add('bert-base-uncased', 48, 'BERT-B.npz', + bert.BertModelShard) +_model_cfg_add('bert-large-uncased', 96, 'BERT-L.npz', + bert.BertModelShard) +_model_cfg_add('textattack/bert-base-uncased-CoLA', 48, 'BERT-B-CoLA.npz', + bert.BertShardForSequenceClassification) +_model_cfg_add('facebook/deit-base-distilled-patch16-224', 48, 'DeiT_B_distilled.npz', + deit.DeiTShardForImageClassification) +_model_cfg_add('facebook/deit-small-distilled-patch16-224', 48, 'DeiT_S_distilled.npz', + deit.DeiTShardForImageClassification) +_model_cfg_add('facebook/deit-tiny-distilled-patch16-224', 48, 'DeiT_T_distilled.npz', + deit.DeiTShardForImageClassification) + +def get_model_names() -> List[str]: + """Get a list of available model names.""" + return list(_MODEL_CONFIGS.keys()) + +def get_model_dict(model_name: str) -> dict: + """Get a model's key/value properties - modify at your own risk.""" + return _MODEL_CONFIGS[model_name] + +def get_model_layers(model_name: str) -> int: + """Get a model's layer count.""" + return _MODEL_CONFIGS[model_name]['layers'] + +def get_model_config(model_name: str) -> Any: + """Get a model's config.""" + # We'll need more complexity if/when we add support for models not from `transformers` + config = AutoConfig.from_pretrained(model_name) + # Config overrides + if model_name == 'google/vit-huge-patch14-224-in21k': + # ViT-Huge doesn't include classification, so we have to set this ourselves + # NOTE: not setting 'id2label' or 'label2id' + config.num_labels = 21843 + return config + +def get_model_default_weights_file(model_name: str) -> str: + """Get a model's default weights file name.""" + return _MODEL_CONFIGS[model_name]['weights_file'] + +def save_model_weights_file(model_name: str, model_file: Optional[str]=None) -> None: + """Save a model's weights file.""" + if model_file is None: + model_file = get_model_default_weights_file(model_name) + # This works b/c all shard implementations have the same save_weights interface + module = _MODEL_CONFIGS[model_name]['shard_module'] + module.save_weights(model_name, model_file) + +def module_shard_factory(model_name: str, model_file: Optional[str], layer_start: int, + layer_end: int, stage: int) -> ModuleShard: + """Get a shard instance on the globally-configured `devices.DEVICE`.""" + # This works b/c all shard implementations have the same constructor interface + if model_file is None: + model_file = get_model_default_weights_file(model_name) + config = get_model_config(model_name) + is_first = layer_start == 1 + is_last = layer_end == get_model_layers(model_name) + shard_config = ModuleShardConfig(layer_start=layer_start, layer_end=layer_end, + is_first=is_first, is_last=is_last) + module = _MODEL_CONFIGS[model_name]['shard_module'] + shard = module(config, shard_config, model_file) + _logger.info("======= %s Stage %d =======", module.__name__, stage) + shard.to(device=devices.DEVICE) + return shard \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/__init__.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/__init__.py new file mode 100644 index 00000000..bbb74188 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/__init__.py @@ -0,0 +1,49 @@ +"""Models module.""" +from typing import Any, Tuple, Type, Union +from torch import nn, Tensor + +ModuleShardData: Type = Union[Tensor, Tuple[Tensor, ...]] +"""A module shard input/output type.""" + + +class ModuleShardConfig: + """Base class for shard configurations (distinct from model configurations).""" + # pylint: disable=too-few-public-methods + + def __init__(self, **kwargs: dict): + # Attributes with default values + self.layer_start: int = kwargs.pop('layer_start', 0) + self.layer_end: int = kwargs.pop('layer_end', 0) + self.is_first: bool = kwargs.pop('is_first', False) + self.is_last: bool = kwargs.pop('is_last', False) + + # Attributes without default values + for key, value in kwargs.items(): + setattr(self, key, value) + + +class ModuleShard(nn.Module): + """Abstract parent class for module shards.""" + # pylint: disable=abstract-method + + def __init__(self, config: Any, shard_config: ModuleShardConfig): + super().__init__() + self.config = config + self.shard_config = shard_config + + def has_layer(self, layer: int) -> bool: + """Check if shard has the specified layer.""" + return layer in range(self.shard_config.layer_start, self.shard_config.layer_end + 1) + + +def get_microbatch_size(shard_data: ModuleShardData, verify: bool=False): + """Get the microbatch size from shard data.""" + if isinstance(shard_data, Tensor): + shard_data = (shard_data,) + ubatch_size = 0 if len(shard_data) == 0 else len(shard_data[0]) + if verify: + # Sanity check that tensors are the same length + for tensor in shard_data: + assert isinstance(tensor, Tensor) + assert len(tensor) == ubatch_size + return ubatch_size diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/__init__.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/__init__.py new file mode 100644 index 00000000..532c96da --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/__init__.py @@ -0,0 +1,6 @@ +"""Transformers module.""" +from typing import Tuple, Type, Union +from torch import Tensor + +TransformerShardData: Type = Union[Tensor, Tuple[Tensor, Tensor]] +"""A transformer shard input/output type.""" diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/bert.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/bert.py new file mode 100644 index 00000000..e33fe989 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/bert.py @@ -0,0 +1,219 @@ +"""BERT transformers.""" +from collections.abc import Mapping +import logging +import math +from typing import Union +import numpy as np +import torch +from torch import nn +from transformers import BertConfig, BertForSequenceClassification, BertModel +from transformers.models.bert.modeling_bert import ( + BertEmbeddings, BertIntermediate, BertOutput, BertPooler, BertSelfAttention, BertSelfOutput +) +from .. import ModuleShard, ModuleShardConfig +from . import TransformerShardData + + +logger = logging.getLogger(__name__) + + +class BertLayerShard(ModuleShard): + """Module shard based on `BertLayer`.""" + + def __init__(self, config: BertConfig, shard_config: ModuleShardConfig): + super().__init__(config, shard_config) + self.self_attention = None + self.self_output = None + self.intermediate = None + self.output = None + self._build_shard() + + def _build_shard(self): + if self.has_layer(0): + self.self_attention = BertSelfAttention(self.config) + if self.has_layer(1): + self.self_output = BertSelfOutput(self.config) + if self.has_layer(2): + self.intermediate = BertIntermediate(self.config) + if self.has_layer(3): + self.output = BertOutput(self.config) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute layer shard.""" + if self.has_layer(0): + data = (self.self_attention(data)[0], data) + if self.has_layer(1): + data = self.self_output(data[0], data[1]) + if self.has_layer(2): + data = (self.intermediate(data), data) + if self.has_layer(3): + data = self.output(data[0], data[1]) + return data + + +class BertModelShard(ModuleShard): + """Module shard based on `BertModel`.""" + + def __init__(self, config: BertConfig, shard_config: ModuleShardConfig, + model_weights: Union[str, Mapping]): + super().__init__(config, shard_config) + self.embeddings = None + # BertModel uses an encoder here, but we'll just add the layers here instead. + # Since we just do inference, a BertEncoderShard class wouldn't provide real benefit. + self.layers = nn.ModuleList() + self.pooler = None + + logger.debug(">>>> Model name: %s", self.config.name_or_path) + if isinstance(model_weights, str): + logger.debug(">>>> Load weight file: %s", model_weights) + with np.load(model_weights) as weights: + self._build_shard(weights) + else: + self._build_shard(model_weights) + + def _build_shard(self, weights): + if self.shard_config.is_first: + logger.debug(">>>> Load embeddings layer for the first shard") + self.embeddings = BertEmbeddings(self.config) + self.embeddings.eval() + self._load_weights_first(weights) + + layer_curr = self.shard_config.layer_start + while layer_curr <= self.shard_config.layer_end: + layer_id = math.ceil(layer_curr / 4) - 1 + sublayer_start = (layer_curr - 1) % 4 + if layer_id == math.ceil(self.shard_config.layer_end / 4) - 1: + sublayer_end = (self.shard_config.layer_end - 1) % 4 + else: + sublayer_end = 3 + logger.debug(">>>> Load layer %d, sublayers %d-%d", + layer_id, sublayer_start, sublayer_end) + layer_config = ModuleShardConfig(layer_start=sublayer_start, layer_end=sublayer_end) + layer = BertLayerShard(self.config, layer_config) + self._load_weights_layer(weights, layer_id, layer) + self.layers.append(layer) + layer_curr += sublayer_end - sublayer_start + 1 + + if self.shard_config.is_last: + logger.debug(">>>> Load pooler for the last shard") + self.pooler = BertPooler(self.config) + self.pooler.eval() + self._load_weights_last(weights) + + @torch.no_grad() + def _load_weights_first(self, weights): + self.embeddings.position_ids.copy_(torch.from_numpy((weights["embeddings.position_ids"]))) + self.embeddings.word_embeddings.weight.copy_(torch.from_numpy(weights['embeddings.word_embeddings.weight'])) + self.embeddings.position_embeddings.weight.copy_(torch.from_numpy(weights['embeddings.position_embeddings.weight'])) + self.embeddings.token_type_embeddings.weight.copy_(torch.from_numpy(weights['embeddings.token_type_embeddings.weight'])) + self.embeddings.LayerNorm.weight.copy_(torch.from_numpy(weights['embeddings.LayerNorm.weight'])) + self.embeddings.LayerNorm.bias.copy_(torch.from_numpy(weights['embeddings.LayerNorm.bias'])) + + @torch.no_grad() + def _load_weights_last(self, weights): + self.pooler.dense.weight.copy_(torch.from_numpy(weights["pooler.dense.weight"])) + self.pooler.dense.bias.copy_(torch.from_numpy(weights['pooler.dense.bias'])) + + @torch.no_grad() + def _load_weights_layer(self, weights, layer_id, layer): + root = f"encoder.layer.{layer_id}." + if layer.has_layer(0): + layer.self_attention.query.weight.copy_(torch.from_numpy(weights[root + "attention.self.query.weight"])) + layer.self_attention.key.weight.copy_(torch.from_numpy(weights[root + "attention.self.key.weight"])) + layer.self_attention.value.weight.copy_(torch.from_numpy(weights[root + "attention.self.value.weight"])) + layer.self_attention.query.bias.copy_(torch.from_numpy(weights[root + "attention.self.query.bias"])) + layer.self_attention.key.bias.copy_(torch.from_numpy(weights[root + "attention.self.key.bias"])) + layer.self_attention.value.bias.copy_(torch.from_numpy(weights[root + "attention.self.value.bias"])) + if layer.has_layer(1): + layer.self_output.dense.weight.copy_(torch.from_numpy(weights[root + "attention.output.dense.weight"])) + layer.self_output.LayerNorm.weight.copy_(torch.from_numpy(weights[root + "attention.output.LayerNorm.weight"])) + layer.self_output.dense.bias.copy_(torch.from_numpy(weights[root + "attention.output.dense.bias"])) + layer.self_output.LayerNorm.bias.copy_(torch.from_numpy(weights[root + "attention.output.LayerNorm.bias"])) + if layer.has_layer(2): + layer.intermediate.dense.weight.copy_(torch.from_numpy(weights[root + "intermediate.dense.weight"])) + layer.intermediate.dense.bias.copy_(torch.from_numpy(weights[root + "intermediate.dense.bias"])) + if layer.has_layer(3): + layer.output.dense.weight.copy_(torch.from_numpy(weights[root + "output.dense.weight"])) + layer.output.dense.bias.copy_(torch.from_numpy(weights[root + "output.dense.bias"])) + layer.output.LayerNorm.weight.copy_(torch.from_numpy(weights[root + "output.LayerNorm.weight"])) + layer.output.LayerNorm.bias.copy_(torch.from_numpy(weights[root + "output.LayerNorm.bias"])) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute shard layers.""" + if self.shard_config.is_first: + data = self.embeddings(data) + for layer in self.layers: + data = layer(data) + if self.shard_config.is_last: + data = self.pooler(data) + return data + + @staticmethod + def save_weights(model_name: str, model_file: str) -> None: + """Save the model weights file.""" + model = BertModel.from_pretrained(model_name) + state_dict = model.state_dict() + weights = {} + for key, val in state_dict.items(): + weights[key] = val + np.savez(model_file, **weights) + + +class BertShardForSequenceClassification(ModuleShard): + """Module shard based on `BertForSequenceClassification`.""" + + def __init__(self, config: BertConfig, shard_config: ModuleShardConfig, + model_weights: Union[str, Mapping]): + super().__init__(config, shard_config) + self.bert = None + self.classifier = None + + logger.debug(">>>> Model name: %s", self.config.name_or_path) + if isinstance(model_weights, str): + logger.debug(">>>> Load weight file: %s", model_weights) + with np.load(model_weights) as weights: + self._build_shard(weights) + else: + self._build_shard(model_weights) + + def _build_shard(self, weights): + ## all shards use the inner BERT model + self.bert = BertModelShard(self.config, self.shard_config, + self._extract_weights_bert(weights)) + + if self.shard_config.is_last: + logger.debug(">>>> Load classifier for the last shard") + self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels) + self._load_weights_last(weights) + + def _extract_weights_bert(self, weights): + bert_weights = {} + for key, val in weights.items(): + if key.startswith('bert.'): + bert_weights[key[len('bert.'):]] = val + return bert_weights + + @torch.no_grad() + def _load_weights_last(self, weights): + self.classifier.weight.copy_(torch.from_numpy(weights['classifier.weight'])) + self.classifier.bias.copy_(torch.from_numpy(weights['classifier.bias'])) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute shard layers.""" + data = self.bert(data) + if self.shard_config.is_last: + data = self.classifier(data) + return data + + @staticmethod + def save_weights(model_name: str, model_file: str) -> None: + """Save the model weights file.""" + model = BertForSequenceClassification.from_pretrained(model_name) + state_dict = model.state_dict() + weights = {} + for key, val in state_dict.items(): + weights[key] = val + np.savez(model_file, **weights) diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/deit.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/deit.py new file mode 100644 index 00000000..dc6b6144 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/deit.py @@ -0,0 +1,233 @@ +"""DeiT Transformers.""" +from collections.abc import Mapping +import logging +import math +from typing import Optional, Union +import numpy as np +import torch +from torch import nn +from transformers import DeiTConfig +from transformers.models.deit.modeling_deit import DeiTEmbeddings +from transformers.models.vit.modeling_vit import ( + ViTIntermediate, ViTOutput, ViTSelfAttention, ViTSelfOutput +) +from .. import ModuleShard, ModuleShardConfig +from . import TransformerShardData + + +logger = logging.getLogger(__name__) + +_HUB_MODEL_NAMES = { + 'facebook/deit-base-distilled-patch16-224': 'deit_base_distilled_patch16_224', + 'facebook/deit-small-distilled-patch16-224': 'deit_small_distilled_patch16_224', + 'facebook/deit-tiny-distilled-patch16-224': 'deit_tiny_distilled_patch16_224', +} + + +class DeiTLayerShard(ModuleShard): + """Module shard based on `DeiTLayer` (copied from `.vit.ViTLayerShard`).""" + + def __init__(self, config: DeiTConfig, shard_config: ModuleShardConfig): + super().__init__(config, shard_config) + self.layernorm_before = None + self.self_attention = None + self.self_output = None + self.layernorm_after = None + self.intermediate = None + self.output = None + self._build_shard() + + def _build_shard(self): + if self.has_layer(0): + self.layernorm_before = nn.LayerNorm(self.config.hidden_size, + eps=self.config.layer_norm_eps) + self.self_attention = ViTSelfAttention(self.config) + if self.has_layer(1): + self.self_output = ViTSelfOutput(self.config) + if self.has_layer(2): + self.layernorm_after = nn.LayerNorm(self.config.hidden_size, + eps=self.config.layer_norm_eps) + self.intermediate = ViTIntermediate(self.config) + if self.has_layer(3): + self.output = ViTOutput(self.config) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute layer shard.""" + if self.has_layer(0): + data_norm = self.layernorm_before(data) + data = (self.self_attention(data_norm)[0], data) + if self.has_layer(1): + skip = data[1] + data = self.self_output(data[0], skip) + data += skip + if self.has_layer(2): + data_norm = self.layernorm_after(data) + data = (self.intermediate(data_norm), data) + if self.has_layer(3): + data = self.output(data[0], data[1]) + return data + + +class DeiTModelShard(ModuleShard): + """Module shard based on `DeiTModel`.""" + + def __init__(self, config: DeiTConfig, shard_config: ModuleShardConfig, + model_weights: Union[str, Mapping]): + super().__init__(config, shard_config) + self.embeddings = None + # DeiTModel uses an encoder here, but we'll just add the layers here instead. + # Since we just do inference, a DeiTEncoderShard class wouldn't provide real benefit. + self.layers = nn.ModuleList() + self.layernorm = None + + logger.debug(">>>> Model name: %s", self.config.name_or_path) + if isinstance(model_weights, str): + logger.debug(">>>> Load weight file: %s", model_weights) + with np.load(model_weights) as weights: + self._build_shard(weights) + else: + self._build_shard(model_weights) + + def _build_shard(self, weights): + if self.shard_config.is_first: + logger.debug(">>>> Load embeddings layer for the first shard") + self.embeddings = DeiTEmbeddings(self.config) + self._load_weights_first(weights) + + layer_curr = self.shard_config.layer_start + while layer_curr <= self.shard_config.layer_end: + layer_id = math.ceil(layer_curr / 4) - 1 + sublayer_start = (layer_curr - 1) % 4 + if layer_id == math.ceil(self.shard_config.layer_end / 4) - 1: + sublayer_end = (self.shard_config.layer_end - 1) % 4 + else: + sublayer_end = 3 + logger.debug(">>>> Load layer %d, sublayers %d-%d", + layer_id, sublayer_start, sublayer_end) + layer_config = ModuleShardConfig(layer_start=sublayer_start, layer_end=sublayer_end) + layer = DeiTLayerShard(self.config, layer_config) + self._load_weights_layer(weights, layer_id, layer) + self.layers.append(layer) + layer_curr += sublayer_end - sublayer_start + 1 + + if self.shard_config.is_last: + logger.debug(">>>> Load layernorm for the last shard") + self.layernorm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps) + self._load_weights_last(weights) + + @torch.no_grad() + def _load_weights_first(self, weights): + self.embeddings.cls_token.copy_(torch.from_numpy(weights["cls_token"])) + self.embeddings.position_embeddings.copy_(torch.from_numpy((weights["pos_embed"]))) + self.embeddings.patch_embeddings.projection.weight.copy_(torch.from_numpy(weights["patch_embed.proj.weight"])) + self.embeddings.patch_embeddings.projection.bias.copy_(torch.from_numpy(weights["patch_embed.proj.bias"])) + + @torch.no_grad() + def _load_weights_last(self, weights): + self.layernorm.weight.copy_(torch.from_numpy(weights["norm.weight"])) + self.layernorm.bias.copy_(torch.from_numpy(weights["norm.bias"])) + + @torch.no_grad() + def _load_weights_layer(self, weights, layer_id, layer): + root = f"blocks.{layer_id}." + embed_dim = self.config.hidden_size + if layer.has_layer(0): + layer.layernorm_before.weight.copy_(torch.from_numpy(weights[root + "norm1.weight"])) + layer.layernorm_before.bias.copy_(torch.from_numpy(weights[root + "norm1.bias"])) + qkv_weight = weights[root + "attn.qkv.weight"] + layer.self_attention.query.weight.copy_(torch.from_numpy(qkv_weight[0:embed_dim,:])) + layer.self_attention.key.weight.copy_(torch.from_numpy(qkv_weight[embed_dim:embed_dim*2,:])) + layer.self_attention.value.weight.copy_(torch.from_numpy(qkv_weight[embed_dim*2:embed_dim*3,:])) + qkv_bias = weights[root + "attn.qkv.bias"] + layer.self_attention.query.bias.copy_(torch.from_numpy(qkv_bias[0:embed_dim,])) + layer.self_attention.key.bias.copy_(torch.from_numpy(qkv_bias[embed_dim:embed_dim*2])) + layer.self_attention.value.bias.copy_(torch.from_numpy(qkv_bias[embed_dim*2:embed_dim*3])) + if layer.has_layer(1): + layer.self_output.dense.weight.copy_(torch.from_numpy(weights[root + "attn.proj.weight"])) + layer.self_output.dense.bias.copy_(torch.from_numpy(weights[root + "attn.proj.bias"])) + if layer.has_layer(2): + layer.layernorm_after.weight.copy_(torch.from_numpy(weights[root + "norm2.weight"])) + layer.layernorm_after.bias.copy_(torch.from_numpy(weights[root + "norm2.bias"])) + layer.intermediate.dense.weight.copy_(torch.from_numpy(weights[root + "mlp.fc1.weight"])) + layer.intermediate.dense.bias.copy_(torch.from_numpy(weights[root + "mlp.fc1.bias"])) + if layer.has_layer(3): + layer.output.dense.weight.copy_(torch.from_numpy(weights[root + "mlp.fc2.weight"])) + layer.output.dense.bias.copy_(torch.from_numpy(weights[root + "mlp.fc2.bias"])) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute shard layers.""" + if self.shard_config.is_first: + data = self.embeddings(data) + for layer in self.layers: + data = layer(data) + if self.shard_config.is_last: + data = self.layernorm(data) + return data + + # NOTE: repo has a dependency on the timm package, which isn't an automatic torch dependency + @staticmethod + def save_weights(model_name: str, model_file: str, hub_repo: str='facebookresearch/deit:main', + hub_model_name: Optional[str]=None) -> None: + """Save the model weights file.""" + if hub_model_name is None: + if model_name in _HUB_MODEL_NAMES: + hub_model_name = _HUB_MODEL_NAMES[model_name] + logger.debug("Mapping model name to torch hub equivalent: %s: %s", model_name, + hub_model_name) + else: + hub_model_name = model_name + model = torch.hub.load(hub_repo, hub_model_name, pretrained=True) + state_dict = model.state_dict() + weights = {} + for key, val in state_dict.items(): + weights[key] = val + np.savez(model_file, **weights) + + +class DeiTShardForImageClassification(ModuleShard): + """Module shard based on `DeiTForImageClassification`.""" + + def __init__(self, config: DeiTConfig, shard_config: ModuleShardConfig, + model_weights: Union[str, Mapping]): + super().__init__(config, shard_config) + self.deit = None + self.classifier = None + + logger.debug(">>>> Model name: %s", self.config.name_or_path) + if isinstance(model_weights, str): + logger.debug(">>>> Load weight file: %s", model_weights) + with np.load(model_weights) as weights: + self._build_shard(weights) + else: + self._build_shard(model_weights) + + def _build_shard(self, weights): + ## all shards use the inner DeiT model + self.deit = DeiTModelShard(self.config, self.shard_config, weights) + + if self.shard_config.is_last: + logger.debug(">>>> Load classifier for the last shard") + self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels) if self.config.num_labels > 0 else nn.Identity() + self._load_weights_last(weights) + + @torch.no_grad() + def _load_weights_last(self, weights): + self.classifier.weight.copy_(torch.from_numpy(weights["head.weight"])) + self.classifier.bias.copy_(torch.from_numpy(weights["head.bias"])) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute shard layers.""" + data = self.deit(data) + if self.shard_config.is_last: + data = self.classifier(data[:, 0, :]) + return data + + @staticmethod + def save_weights(model_name: str, model_file: str, hub_repo: str='facebookresearch/deit:main', + hub_model_name: Optional[str]=None) -> None: + """Save the model weights file.""" + DeiTModelShard.save_weights(model_name, model_file, hub_repo=hub_repo, + hub_model_name=hub_model_name) diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/vit.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/vit.py new file mode 100644 index 00000000..7760f5e6 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/vit.py @@ -0,0 +1,232 @@ +"""ViT Transformers.""" +from collections.abc import Mapping +import logging +import math +import os +from typing import Optional, Union +import numpy as np +import requests +import torch +from torch import nn +from transformers import ViTConfig +from transformers.models.vit.modeling_vit import ( + ViTEmbeddings, ViTIntermediate, ViTOutput, ViTSelfAttention, ViTSelfOutput +) +from .. import ModuleShard, ModuleShardConfig +from . import TransformerShardData + + +logger = logging.getLogger(__name__) + +_WEIGHTS_URLS = { + 'google/vit-base-patch16-224': 'https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-B_16-224.npz', + 'google/vit-large-patch16-224': 'https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-L_16-224.npz', + 'google/vit-huge-patch14-224-in21k': 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', +} + + +class ViTLayerShard(ModuleShard): + """Module shard based on `ViTLayer`.""" + + def __init__(self, config: ViTConfig, shard_config: ModuleShardConfig): + super().__init__(config, shard_config) + self.layernorm_before = None + self.self_attention = None + self.self_output = None + self.layernorm_after = None + self.intermediate = None + self.output = None + self._build_shard() + + def _build_shard(self): + if self.has_layer(0): + self.layernorm_before = nn.LayerNorm(self.config.hidden_size, + eps=self.config.layer_norm_eps) + self.self_attention = ViTSelfAttention(self.config) + if self.has_layer(1): + self.self_output = ViTSelfOutput(self.config) + if self.has_layer(2): + self.layernorm_after = nn.LayerNorm(self.config.hidden_size, + eps=self.config.layer_norm_eps) + self.intermediate = ViTIntermediate(self.config) + if self.has_layer(3): + self.output = ViTOutput(self.config) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute layer shard.""" + if self.has_layer(0): + data_norm = self.layernorm_before(data) + data = (self.self_attention(data_norm)[0], data) + if self.has_layer(1): + skip = data[1] + data = self.self_output(data[0], skip) + data += skip + if self.has_layer(2): + data_norm = self.layernorm_after(data) + data = (self.intermediate(data_norm), data) + if self.has_layer(3): + data = self.output(data[0], data[1]) + return data + + +class ViTModelShard(ModuleShard): + """Module shard based on `ViTModel` (no pooling layer).""" + + def __init__(self, config: ViTConfig, shard_config: ModuleShardConfig, + model_weights: Union[str, Mapping]): + super().__init__(config, shard_config) + self.embeddings = None + # ViTModel uses an encoder here, but we'll just add the layers here instead. + # Since we just do inference, a ViTEncoderShard class wouldn't provide real benefit. + self.layers = nn.ModuleList() + self.layernorm = None + + logger.debug(">>>> Model name: %s", self.config.name_or_path) + if isinstance(model_weights, str): + logger.debug(">>>> Load weight file: %s", model_weights) + with np.load(model_weights) as weights: + self._build_shard(weights) + else: + self._build_shard(model_weights) + + def _build_shard(self, weights): + if self.shard_config.is_first: + logger.debug(">>>> Load embeddings layer for the first shard") + self.embeddings = ViTEmbeddings(self.config) + self._load_weights_first(weights) + + layer_curr = self.shard_config.layer_start + while layer_curr <= self.shard_config.layer_end: + layer_id = math.ceil(layer_curr / 4) - 1 + sublayer_start = (layer_curr - 1) % 4 + if layer_id == math.ceil(self.shard_config.layer_end / 4) - 1: + sublayer_end = (self.shard_config.layer_end - 1) % 4 + else: + sublayer_end = 3 + logger.debug(">>>> Load layer %d, sublayers %d-%d", + layer_id, sublayer_start, sublayer_end) + layer_config = ModuleShardConfig(layer_start=sublayer_start, layer_end=sublayer_end) + layer = ViTLayerShard(self.config, layer_config) + self._load_weights_layer(weights, layer_id, layer) + self.layers.append(layer) + layer_curr += sublayer_end - sublayer_start + 1 + + if self.shard_config.is_last: + logger.debug(">>>> Load layernorm for the last shard") + self.layernorm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps) + self._load_weights_last(weights) + + @torch.no_grad() + def _load_weights_first(self, weights): + self.embeddings.cls_token.copy_(torch.from_numpy(weights["cls"])) + self.embeddings.position_embeddings.copy_(torch.from_numpy((weights["Transformer/posembed_input/pos_embedding"]))) + conv_weight = weights["embedding/kernel"] + # O, I, J, K = conv_weight.shape + # conv_weight = conv_weight.reshape(K,J,O,I) + conv_weight = conv_weight.transpose([3, 2, 0, 1]) + self.embeddings.patch_embeddings.projection.weight.copy_(torch.from_numpy(conv_weight)) + self.embeddings.patch_embeddings.projection.bias.copy_(torch.from_numpy(weights["embedding/bias"])) + + @torch.no_grad() + def _load_weights_last(self, weights): + self.layernorm.weight.copy_(torch.from_numpy(weights["Transformer/encoder_norm/scale"])) + self.layernorm.bias.copy_(torch.from_numpy(weights["Transformer/encoder_norm/bias"])) + + @torch.no_grad() + def _load_weights_layer(self, weights, layer_id, layer): + root = f"Transformer/encoderblock_{layer_id}/" + hidden_size = self.config.hidden_size + if layer.has_layer(0): + layer.layernorm_before.weight.copy_(torch.from_numpy(weights[root + "LayerNorm_0/scale"])) + layer.layernorm_before.bias.copy_(torch.from_numpy(weights[root + "LayerNorm_0/bias"])) + layer.self_attention.query.weight.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/query/kernel"]).view(hidden_size, hidden_size).t()) + layer.self_attention.key.weight.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/key/kernel"]).view(hidden_size, hidden_size).t()) + layer.self_attention.value.weight.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/value/kernel"]).view(hidden_size, hidden_size).t()) + layer.self_attention.query.bias.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/query/bias"]).view(-1)) + layer.self_attention.key.bias.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/key/bias"]).view(-1)) + layer.self_attention.value.bias.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/value/bias"]).view(-1)) + if layer.has_layer(1): + layer.self_output.dense.weight.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/out/kernel"]).view(hidden_size, hidden_size).t()) + layer.self_output.dense.bias.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/out/bias"]).view(-1)) + if layer.has_layer(2): + layer.layernorm_after.weight.copy_(torch.from_numpy(weights[root + "LayerNorm_2/scale"])) + layer.layernorm_after.bias.copy_(torch.from_numpy(weights[root + "LayerNorm_2/bias"])) + layer.intermediate.dense.weight.copy_(torch.from_numpy(weights[root + "MlpBlock_3/Dense_0/kernel"]).t()) + layer.intermediate.dense.bias.copy_(torch.from_numpy(weights[root + "MlpBlock_3/Dense_0/bias"]).t()) + if layer.has_layer(3): + layer.output.dense.weight.copy_(torch.from_numpy(weights[root + "MlpBlock_3/Dense_1/kernel"]).t()) + layer.output.dense.bias.copy_(torch.from_numpy(weights[root + "MlpBlock_3/Dense_1/bias"]).t()) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute shard layers.""" + if self.shard_config.is_first: + data = self.embeddings(data) + for layer in self.layers: + data = layer(data) + if self.shard_config.is_last: + data = self.layernorm(data) + return data + + @staticmethod + def save_weights(model_name: str, model_file: str, url: Optional[str]=None, + timeout_sec: Optional[float]=None) -> None: + """Save the model weights file.""" + if url is None: + url = _WEIGHTS_URLS[model_name] + logger.info('Downloading model: %s: %s', model_name, url) + req = requests.get(url, stream=True, timeout=timeout_sec) + req.raise_for_status() + with open(model_file, 'wb') as file: + for chunk in req.iter_content(chunk_size=8192): + if chunk: + file.write(chunk) + file.flush() + os.fsync(file.fileno()) + + +class ViTShardForImageClassification(ModuleShard): + """Module shard based on `ViTForImageClassification`.""" + + def __init__(self, config: ViTConfig, shard_config: ModuleShardConfig, + model_weights: Union[str, Mapping]): + super().__init__(config, shard_config) + self.vit = None + self.classifier = None + + logger.debug(">>>> Model name: %s", self.config.name_or_path) + if isinstance(model_weights, str): + logger.debug(">>>> Load weight file: %s", model_weights) + with np.load(model_weights) as weights: + self._build_shard(weights) + else: + self._build_shard(model_weights) + + def _build_shard(self, weights): + ## all shards use the inner ViT model + self.vit = ViTModelShard(self.config, self.shard_config, weights) + + if self.shard_config.is_last: + logger.debug(">>>> Load classifier for the last shard") + self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels) if self.config.num_labels > 0 else nn.Identity() + self._load_weights_last(weights) + + @torch.no_grad() + def _load_weights_last(self, weights): + self.classifier.weight.copy_(torch.from_numpy(np.transpose(weights["head/kernel"]))) + self.classifier.bias.copy_(torch.from_numpy(weights["head/bias"])) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute shard layers.""" + data = self.vit(data) + if self.shard_config.is_last: + data = self.classifier(data[:, 0, :]) + return data + + @staticmethod + def save_weights(model_name: str, model_file: str, url: Optional[str]=None, + timeout_sec: Optional[float]=None) -> None: + """Save the model weights file.""" + ViTModelShard.save_weights(model_name, model_file, url=url, timeout_sec=timeout_sec) diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/profiler.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/profiler.py new file mode 100644 index 00000000..61fdf33f --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/profiler.py @@ -0,0 +1,263 @@ +"""Module shard profiler.""" +import argparse +import gc +import os +import time +import numpy as np +import psutil +import torch +import torch.multiprocessing as mp +import yaml +from transformers import BertTokenizer +import devices +import model_cfg + + +def get_shapes(tensors): + """Get the tensor shapes, excluding the outer dimension (microbatch size).""" + if isinstance(tensors, tuple): + shape = [] + for tensor in tensors: + shape.append(tuple(tensor.shape[1:])) + else: + shape = [tuple(tensors.shape[1:])] + return shape + + +def create_module_shard(module_cfg, stage_cfg): + """Create a module shard.""" + model_name = module_cfg['name'] + model_file = module_cfg['file'] + stage = stage_cfg['stage'] + layer_start = stage_cfg['layer_start'] + layer_end = stage_cfg['layer_end'] + return model_cfg.module_shard_factory(model_name, model_file, layer_start, layer_end, stage) + + +def profile_module_shard(module_cfg, stage_cfg, stage_inputs, warmup, iterations): + """Profile a module shard.""" + process = psutil.Process(os.getpid()) + + # Measure memory (create shard) on the CPU. + # This avoids capturing additional memory overhead when using other devices, like GPUs. + # It's OK if the model fits in DRAM but not on the "device" - we'll just fail later. + # We consider memory requirements to be a property of the model, not the device/platform. + assert devices.DEVICE is None + # Capturing memory behavior in Python is extremely difficult and results are subject to many + # factors beyond our ability to control or reliably detect/infer. + # This works best when run once per process execution with only minimal work done beforehand. + gc.collect() + stage_start_mem = process.memory_info().rss / 1000000 + module = create_module_shard(module_cfg, stage_cfg) + gc.collect() + stage_end_mem = process.memory_info().rss / 1000000 + + # Now move the module to the specified device + device = module_cfg['device'] + if device is not None: + devices.DEVICE = torch.device(device) + if devices.DEVICE is not None and devices.DEVICE.type == 'cuda': + torch.cuda.init() + module.to(device=device) + module.register_forward_pre_hook(devices.forward_pre_hook_to_device) + module.register_forward_hook(devices.forward_hook_to_cpu) + + # Measure data input + shape_in = get_shapes(stage_inputs) + + # Optional warmup + if warmup: + module(stage_inputs) + + # Measure timing (execute shard) - includes data movement overhead (performed in hooks) + stage_times = [] + for _ in range(iterations): + stage_start_time = time.time() + stage_outputs = module(stage_inputs) + stage_end_time = time.time() + stage_times.append(stage_end_time - stage_start_time) + stage_time_avg = sum(stage_times) / len(stage_times) + + # Measure data output + shape_out = get_shapes(stage_outputs) + + results = { + 'shape_in': shape_in, + 'shape_out': shape_out, + 'memory': stage_end_mem - stage_start_mem, + 'time': stage_time_avg, + } + return (stage_outputs, results) + + +def profile_module_shard_mp_queue(queue, evt_done, args): + """Multiprocessing target function for `profile_module_shard` which adds output to queue.""" + queue.put(profile_module_shard(*args)) + evt_done.wait() + + +def profile_module_shard_mp(args): + """Run `profile_module_shard` with multiprocessing (for more accurate memory results).""" + # First, a non-optional module warmup in case PyTorch needs to fetch/cache models on first use + print("Performing module warmup...") + proc = mp.Process(target=create_module_shard, args=(args[0], args[1])) + proc.start() + proc.join() + + # Now, the actual profiling + print("Performing module profiling...") + queue = mp.Queue() + # The child process sometimes exits before we read the queue items, even though it should have + # flushed all data to the underlying pipe before that, so use an event to keep it alive. + evt_done = mp.Event() + proc = mp.Process(target=profile_module_shard_mp_queue, args=(queue, evt_done, args)) + proc.start() + tensors, prof_dict = queue.get() + evt_done.set() + proc.join() + return (tensors, prof_dict) + + +def profile_layers(module_cfg, tensors, layer_start, layer_end, warmup, iterations): + """Profile a shard with layer_start through layer_end.""" + shard = { + 'stage': 0, + 'layer_start': layer_start, + 'layer_end': layer_end, + } + _, prof_dict = profile_module_shard_mp(args=(module_cfg, shard, tensors, warmup, iterations)) + prof_dict['layer'] = 0 + return [prof_dict] + + +def profile_layers_individually(module_cfg, tensors, layer_start, layer_end, warmup, iterations): + """Profile module shards for each layer individually.""" + results = [] + for layer in range(layer_start, layer_end + 1): + shard = { + 'stage': layer, + 'layer_start': layer, + 'layer_end': layer, + } + tensors, prof_dict = profile_module_shard_mp(args=(module_cfg, shard, tensors, warmup, iterations)) + prof_dict['layer'] = layer + results.append(prof_dict) + return results + + +def profile_layers_cumulatively(module_cfg, tensors, layer_start, layer_end, warmup, iterations): + """Profile module shards with increasing numbers of layers.""" + results = [] + for layer in range(1, layer_end + 1): + shard = { + 'stage': layer, + 'layer_start': layer_start, + 'layer_end': layer, + } + _, prof_dict = profile_module_shard_mp(args=(module_cfg, shard, tensors, warmup, iterations)) + prof_dict['layer'] = layer + results.append(prof_dict) + return results + + +def validate_profile_results(profile_results, args, inputs, model_layers, layer_end): + """Validate that we can work with existing profiling results""" + assert profile_results['model_name'] == args.model_name, "model name mismatch with existing results" + dtype = inputs[0].dtype if isinstance(inputs, tuple) else inputs.dtype + assert profile_results['dtype'] == str(dtype), "dtype mismatch with existing results" + assert profile_results['batch_size'] == args.batch_size, "batch size mismatch with existing results" + assert profile_results['layers'] == model_layers, "layer count mismatch with existing results" + # check for overlap with existing results data + for _layer in range(args.layer_start, layer_end + 1): + for _pd in profile_results['profile_data']: + assert _layer != _pd['layer'], "layer to be profiled already in existing results" + + +def main(): + """Main function.""" + parser = argparse.ArgumentParser(description="Module Shard Profiler", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("-o", "--results-yml", default="./examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/profiler_results.yml", type=str, + help="output YAML file") + parser.add_argument("-d", "--device", type=str, default=None, + help="compute device type to use, with optional ordinal, " + "e.g.: 'cpu', 'cuda', 'cuda:1'") + parser.add_argument("-m", "--model-name", type=str, default="google/vit-base-patch16-224", + choices=model_cfg.get_model_names(), + help="the neural network model for loading") + parser.add_argument("-M", "--model-file", type=str, + help="the model file, if not in working directory") + parser.add_argument("-l", "--layer-start", default=1, type=int, help="start layer") + parser.add_argument("-L", "--layer-end", type=int, help="end layer; default: last layer in the model") + parser.add_argument("-s", "--shape-input", type=str, action='append', + help="comma-delimited shape input, e.g., '3,224,224' (required for start_layer != 1)") + parser.add_argument("-b", "--batch-size", default=8, type=int, help="batch size") + parser.add_argument("-w", "--warmup", action="store_true", default=True, + help="perform a warmup iteration " + "(strongly recommended, esp. with device='cuda' or iterations>1)") + parser.add_argument("--no-warmup", action="store_false", dest="warmup", + help="don't perform a warmup iteration") + parser.add_argument("-i", "--iterations", default=1, type=int, + help="iterations to average runtime for") + args = parser.parse_args() + + if args.shape_input is not None: + shapes = [] + for shp in args.shape_input: + shapes.append(tuple(int(d) for d in shp.split(','))) + if len(shapes) > 1: + # tuple of tensors + inputs = tuple(torch.randn(args.batch_size, *shp) for shp in shapes) + else: + # single tensor + inputs = torch.randn(args.batch_size, *shapes[0]) + elif args.model_name in ['bert-base-uncased', 'bert-large-uncased']: + with np.load("bert_input.npz") as bert_inputs: + inputs_sentence = list(bert_inputs['input'][0: args.batch_size]) + tokenizer = BertTokenizer.from_pretrained(args.model_name) + inputs = tokenizer(inputs_sentence, padding=True, truncation=True, return_tensors="pt")['input_ids'] + else: + inputs = torch.randn(args.batch_size, 3, 224, 224) + + model_layers = model_cfg.get_model_layers(args.model_name) + layer_end = args.layer_end + if layer_end is None: + layer_end = model_layers + + # get or create profile_results + if os.path.exists(args.results_yml): + print("Using existing results file") + with open(args.results_yml, 'r', encoding='utf-8') as yfile: + profile_results = yaml.safe_load(yfile) + validate_profile_results(profile_results, args, inputs, model_layers, layer_end) + else: + profile_results = { + 'model_name': args.model_name, + 'dtype': str(inputs.dtype), + 'batch_size': args.batch_size, + 'layers': model_layers, + 'profile_data': [], + } + + module_cfg = { + 'device': args.device, + 'name': args.model_name, + 'file': args.model_file, + } + if args.model_file is None: + module_cfg['file'] = model_cfg.get_model_default_weights_file(args.model_name) + # a single shard measurement can be a useful reference + # results = profile_layers(module_cfg, inputs, args.layer_start, layer_end, args.warmup, args.iterations) + # cumulative won't work if the whole model doesn't fit on the device + # results = profile_layers_cumulatively(module_cfg, inputs, args.layer_start, layer_end, args.warmup, args.iterations) + results = profile_layers_individually(module_cfg, inputs, args.layer_start, layer_end, args.warmup, args.iterations) + + # just a dump of the configuration and profiling results + profile_results['profile_data'].extend(results) + profile_results['profile_data'].sort(key=lambda pd: pd['layer']) + with open(args.results_yml, 'w', encoding='utf-8') as yfile: + yaml.safe_dump(profile_results, yfile, default_flow_style=None, encoding='utf-8') + + +if __name__=="__main__": + main() diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/utils/yaml_files.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/utils/yaml_files.py new file mode 100644 index 00000000..b90459a2 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/utils/yaml_files.py @@ -0,0 +1,49 @@ +"""Manage YAML files.""" +import os +import yaml + + +def _yaml_load_map(file): + if os.path.exists(file): + with open(file, 'r', encoding='utf-8') as yfile: + yml = yaml.safe_load(yfile) + else: + yml = {} + return yml + + +def yaml_models_load(file) -> dict: + """Load a YAML models file.""" + # models files are a map of model names to yaml_model values. + return _yaml_load_map(file) + + +def yaml_device_types_load(file) -> dict: + """Load a YAML device types file.""" + # device types files are a map of device type names to yaml_device_type values. + return _yaml_load_map(file) + + +def yaml_devices_load(file) -> dict: + """Load a YAML devices file.""" + # devices files are a map of device type names to lists of hosts. + return _yaml_load_map(file) + + +def yaml_device_neighbors_load(file) -> dict: + """Load a YAML device neighbors file.""" + # device neighbors files are a map of neighbor hostnames to yaml_device_neighbors_type values. + return _yaml_load_map(file) + + +def yaml_device_neighbors_world_load(file) -> dict: + """Load a YAML device neighbors world file.""" + # device neighbors world files are a map of hostnames to a map of neighbor hostnames to + # yaml_device_neighbors_type values. + return _yaml_load_map(file) + + +def yaml_save(yml, file): + """Save a YAML file.""" + with open(file, 'w', encoding='utf-8') as yfile: + yaml.safe_dump(yml, yfile, default_flow_style=None, encoding='utf-8') diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/utils/yaml_types.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/utils/yaml_types.py new file mode 100644 index 00000000..a5639fe5 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/utils/yaml_types.py @@ -0,0 +1,82 @@ +"""YAML types.""" +from typing import List, Optional, Union + + +def _assert_list_type(lst, dtype): + assert isinstance(lst, list) + for var in lst: + assert isinstance(var, dtype) + + +def yaml_model(num_layers: int, parameters_in: int, parameters_out: List[int], + mem_MB: Union[List[int], List[float]]) -> dict: + """Create a YAML model.""" + assert isinstance(num_layers, int) + assert isinstance(parameters_in, int) + _assert_list_type(parameters_out, int) + _assert_list_type(mem_MB, (int, float)) + return { + 'layers': num_layers, + 'parameters_in': parameters_in, + 'parameters_out': parameters_out, + 'mem_MB': mem_MB, + } + + +def yaml_model_profile(dtype: str, batch_size: int, time_s: Union[List[int], List[float]]) -> dict: + """Create a YAML model profile.""" + assert isinstance(dtype, str) + assert isinstance(batch_size, int) + _assert_list_type(time_s, (int, float)) + return { + 'dtype': dtype, + 'batch_size': batch_size, + 'time_s': time_s, + } + + +def _assert_model_profile(model_prof): + assert isinstance(model_prof, dict) + for model_prof_prop in model_prof: + # only 'time_s' is supported + assert model_prof_prop == 'time_s' + _assert_list_type(model_prof['time_s'], (int, float)) + + +def _assert_model_profiles(model_profiles): + assert isinstance(model_profiles, dict) + for model in model_profiles: + assert isinstance(model, str) + _assert_model_profile(model_profiles[model]) + + +def yaml_device_type(mem_MB: Union[int, float], bw_Mbps: Union[int, float], + model_profiles: Optional[dict]) -> dict: + """Create a YAML device type.""" + assert isinstance(mem_MB, (int, float)) + assert isinstance(bw_Mbps, (int, float)) + if model_profiles is None: + model_profiles = {} + _assert_model_profiles(model_profiles) + return { + 'mem_MB': mem_MB, + 'bw_Mbps': bw_Mbps, + 'model_profiles': model_profiles, + } + +def yaml_device_neighbors_type(bw_Mbps: Union[int, float]) -> dict: + """Create a YAML device neighbors type.""" + assert isinstance(bw_Mbps, (int, float)) + return { + 'bw_Mbps': bw_Mbps, + # Currently only one field, but could be extended, e.g., to include latency_{ms,us}. + } + +def yaml_device_neighbors(neighbors: List[str], bws_Mbps: Union[List[int], List[float]]) -> dict: + """Create a YAML device neighbors.""" + _assert_list_type(neighbors, str) + _assert_list_type(bws_Mbps, (int, float)) + return { + neighbor: yaml_device_neighbors_type(bw_Mbps) + for neighbor, bw_Mbps in zip(neighbors, bws_Mbps) + } diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/basemodel.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/basemodel.py new file mode 100644 index 00000000..2efcb901 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/basemodel.py @@ -0,0 +1,154 @@ +# Modified Copyright 2022 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import glob +import os +from collections import OrderedDict +from pathlib import Path +from collections import defaultdict +import time + +from sedna.common.class_factory import ClassType, ClassFactory +from dataset import load_dataset + +import yaml +import onnxruntime as ort +from torch.utils.data import DataLoader +import numpy as np +from tqdm import tqdm +import pynvml + + +__all__ = ["BaseModel"] + +# set backend +os.environ["BACKEND_TYPE"] = "ONNX" + + +def make_parser(): + parser = argparse.ArgumentParser("ViT Eval") + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument("--devices_info", default="./devices.yaml", type=str, help="devices conf") + parser.add_argument("--model_parallel", default=True, action="store_true") + parser.add_argument("--split", default="val", type=str, help="split of dataset") + parser.add_argument("--indices", default=None, type=str, help="indices of dataset") + parser.add_argument("--shuffle", default=False, action="store_true", help="shuffle data") + parser.add_argument("--model_name", default="google/vit-base-patch16-224", type=str, help="model name") + parser.add_argument("--dataset_name", default="ImageNet", type=str, help="dataset name") + parser.add_argument("--data_size", default=1000, type=int, help="data size to inference") + # remove conflict with ianvs + parser.add_argument("-f") + return parser + + +@ClassFactory.register(ClassType.GENERAL, alias="Classification") +class BaseModel: + + def __init__(self, **kwargs) -> None: + self.args = make_parser().parse_args() + self.devices_info_url = str(Path(Path(__file__).parent.resolve(), self.args.devices_info)) + self.model_parallel = self.args.model_parallel + self.partition_point_list = self._parse_devices_info(self.devices_info_url).get('partition_points') + self.models = [] + return + + + def load(self, models_dir=None, map_info=None) -> None: + cnt = 0 + for model_name, device in map_info.items(): + model = models_dir + '/' + model_name + if not os.path.exists(model): + raise ValueError("=> No modle found at '{}'".format(model)) + if device == 'cpu': + session = ort.InferenceSession(model, providers=['CPUExecutionProvider']) + elif 'gpu' in device: + device_id = int(device.split('-')[-1]) + session = ort.InferenceSession(model, providers=[('CUDAExecutionProvider', {'device_id': device_id})]) + else: + raise ValueError("Error device info: '{}'".format(device)) + self.models.append({ + 'session': session, + 'name': model_name, + 'device': device, + 'input_names': self.partition_point_list[cnt]['input_names'], + 'output_names': self.partition_point_list[cnt]['output_names'], + }) + cnt += 1 + print("=> Loaded onnx model: '{}'".format(model)) + return + + def predict(self, data, input_shape=None, **kwargs): + pynvml.nvmlInit() + root = str(Path(data[0]).parents[2]) + dataset_cfg = { + 'name': self.args.dataset_name, + 'root': root, + 'split': self.args.split, + 'indices': self.args.indices, + 'shuffle': self.args.shuffle + } + data_loader, ids = self._get_eval_loader(dataset_cfg) + data_loader = tqdm(data_loader, desc='Evaluating', unit='batch') + pred = [] + inference_time_per_device = defaultdict(int) + power_usage_per_device = defaultdict(list) + mem_usage_per_device = defaultdict(list) + cnt = 0 + for data, id in zip(data_loader, ids): + outputs = data[0].numpy() + for model in self.models: + start_time = time.time() + outputs = model['session'].run(None, {model['input_names'][0]: outputs})[0] + end_time = time.time() + device = model.get('device') + inference_time_per_device[device] += end_time - start_time + if 'gpu' in device and cnt % 100 == 0: + handle = pynvml.nvmlDeviceGetHandleByIndex(int(device.split('-')[-1])) + power_usage = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 + memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle).used / (1024**2) + power_usage_per_device[device] += [power_usage] + mem_usage_per_device[device] += [memory_info] + max_ids = np.argmax(outputs) + pred.append((max_ids, id)) + cnt += 1 + data_loader.close() + result = dict({}) + result["pred"] = pred + result["inference_time_per_device"] = inference_time_per_device + result["power_usage_per_device"] = power_usage_per_device + result["mem_usage_per_device"] = mem_usage_per_device + return result + + + def _get_eval_loader(self, dataset_cfg): + model_name = self.args.model_name + data_size = self.args.data_size + dataset, _, ids = load_dataset(dataset_cfg, model_name, data_size) + data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) + return data_loader, ids + + def _parse_devices_info(self, url): + """Convert yaml file to the dict.""" + if url.endswith('.yaml') or url.endswith('.yml'): + with open(url, "rb") as file: + devices_info_dict = yaml.load(file, Loader=yaml.SafeLoader) + return devices_info_dict + else: + raise RuntimeError('config file must be the yaml format') \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/classification_algorithm.yaml b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/classification_algorithm.yaml new file mode 100644 index 00000000..c8d212f7 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/classification_algorithm.yaml @@ -0,0 +1,27 @@ +algorithm: + # paradigm name; string type; + # currently the options of value are as follows: + # 1> "singletasklearning" + # 2> "incrementallearning" + paradigm_type: "multiedgeinference" + # the url address of initial model; string type; optional; + initial_model_url: "./initial_model/vit-base-patch16-224.onnx" + + # algorithm module configuration in the paradigm; list type; + modules: + # kind of algorithm module; string type; + # currently the options of value are as follows: + # 1> "basemodel" + - type: "basemodel" + # name of python module; string type; + # example: basemodel.py has BaseModel module that the alias is "FPN" for this benchmarking; + name: "Classification" + # the url address of python module; string type; + url: "./examples/imagenet/multiedge_inference_bench/testalgorithms/manual/basemodel.py" + + # hyperparameters configuration for the python module; list type; + hyperparameters: + # name of the hyperparameter; string type; + - batch_size: + values: + - 1 diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/dataset.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/dataset.py new file mode 100644 index 00000000..9b4ee16c --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/dataset.py @@ -0,0 +1,71 @@ +import logging +import random +from typing import Callable, Optional, Sequence +import os + +from torch.utils.data import DataLoader, Dataset, Subset +from transformers import ViTFeatureExtractor +from torchvision.datasets import ImageNet + + +def load_dataset_imagenet(feature_extractor: Callable, root: str, split: str='train') -> Dataset: + """Get the ImageNet dataset.""" + + def transform(img): + pixels = feature_extractor(images=img.convert('RGB'), return_tensors='pt')['pixel_values'] + return pixels[0] + return ImageNet(root, split=split, transform=transform) + +def load_dataset_subset(dataset: Dataset, indices: Optional[Sequence[int]]=None, + max_size: Optional[int]=None, shuffle: bool=False) -> Dataset: + """Get a Dataset subset.""" + if indices is None: + indices = list(range(len(dataset))) + if shuffle: + random.shuffle(indices) + if max_size is not None: + indices = indices[:max_size] + image_paths = [] + for index in indices: + image_paths.append(dataset.imgs[index][0]) + return Subset(dataset, indices), image_paths, indices + +def load_dataset(dataset_cfg: dict, model_name: str, batch_size: int) -> Dataset: + """Load inputs based on model.""" + def _get_feature_extractor(): + feature_extractor = ViTFeatureExtractor.from_pretrained(model_name) + return feature_extractor + dataset_name = dataset_cfg['name'] + dataset_root = dataset_cfg['root'] + dataset_split = dataset_cfg['split'] + indices = dataset_cfg['indices'] + dataset_shuffle = dataset_cfg['shuffle'] + if dataset_name == 'ImageNet': + if dataset_root is None: + dataset_root = 'ImageNet' + logging.info("Dataset root not set, assuming: %s", dataset_root) + feature_extractor = _get_feature_extractor() + dataset = load_dataset_imagenet(feature_extractor, dataset_root, split=dataset_split) + dataset, paths, ids = load_dataset_subset(dataset, indices=indices, max_size=batch_size, + shuffle=dataset_shuffle) + return dataset, paths, ids + +if __name__ == '__main__': + dataset_cfg = { + 'name': "ImageNet", + 'root': './dataset', + 'split': 'val', + 'indices': None, + 'shuffle': False, + } + model_name = "google/vit-base-patch16-224" + ## Total images to be inferenced. + data_size = 1000 + dataset, paths, _ = load_dataset(dataset_cfg, model_name, data_size) + data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) + with open('./dataset/train.txt', 'w') as f: + for i, (image, label) in enumerate(data_loader): + original_path = paths[i].replace('/dataset', '') + f.write(f"{original_path} {label.item()}\n") + f.close() + os.popen('cp ./dataset/train.txt ./dataset/test.txt') \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/devices.yaml b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/devices.yaml new file mode 100644 index 00000000..82be6c7c --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/devices.yaml @@ -0,0 +1,26 @@ +devices: + - name: "gpu" + type: "gpu-0" + memory: "1024" + freq: "2.6" + bandwith: "100" + - name: "gpu-1" + type: "gpu" + memory: "1024" + freq: "2.6" + bandwith: "80" + - name: "gpu-2" + type: "gpu" + memory: "2048" + freq: "2.6" + bandwith: "90" +partition_points: + - input_names: ["pixel_values"] + output_names: ["input.60"] + device_name: "gpu-0" + - input_names: ["input.60"] + output_names: ["input.160"] + device_name: "gpu-1" + - input_names: ["input.160"] + output_names: ["logits"] + device_name: "gpu-2" \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/devices_one.yaml b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/devices_one.yaml new file mode 100644 index 00000000..2f4a5640 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/devices_one.yaml @@ -0,0 +1,11 @@ +devices: + - name: "gpu-1" + type: gpu" + memory: "1024" + freq: "2.6" + bandwith: "100" + +partition_points: + - input_names: ["pixel_values"] + output_names: ["logits"] + device_name: "gpu-1" \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testenv/accuracy.py b/examples/imagenet/multiedge_inference_bench/testenv/accuracy.py new file mode 100644 index 00000000..86f78d69 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testenv/accuracy.py @@ -0,0 +1,14 @@ +from sedna.common.class_factory import ClassType, ClassFactory + +__all__ = ('accuracy') + +@ClassFactory.register(ClassType.GENERAL, alias="accuracy") +def accuracy(y_true, y_pred, **kwargs): + y_pred = y_pred.get("pred") + total = len(y_pred) + y_true_ = [int(y_true[i].split('/')[-1]) for (_, i) in y_pred] + y_pred_ = [int(i) for (i, _) in y_pred] + correct_predictions = sum(yt == yp for yt, yp in zip(y_true_, y_pred_)) + accuracy = (correct_predictions / total) * 100 + print("Accuracy: {:.2f}%".format(accuracy)) + return accuracy diff --git a/examples/imagenet/multiedge_inference_bench/testenv/fps.py b/examples/imagenet/multiedge_inference_bench/testenv/fps.py new file mode 100644 index 00000000..810e72f6 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testenv/fps.py @@ -0,0 +1,34 @@ +import sys +import os + +from sedna.common.class_factory import ClassType, ClassFactory + +import matplotlib.pyplot as plt + +__all__ = ('fps') + +@ClassFactory.register(ClassType.GENERAL, alias="fps") +def fps(y_true, y_pred, **kwargs): + total = len(y_pred.get("pred")) + inference_time_per_device = y_pred.get("inference_time_per_device") + plt.figure() + min_fps = sys.maxsize + for device, time in inference_time_per_device.items(): + fps = total / time + plt.bar(device, fps, label=f'{device}') + min_fps = min(fps, min_fps) + plt.axhline(y=min_fps, color='red', linewidth=2, label='Min FPS') + + plt.xticks(rotation=45) + plt.ylabel('FPS') + plt.xlabel('Device') + plt.legend() + + dir = './multiedge_inference_bench/workspace/classification_job/images/' + if not os.path.exists(dir): + os.makedirs(dir) + from datetime import datetime + now = datetime.now().strftime("%H_%M_%S") + plt.savefig(dir + 'FPS_per_device' + now + '.png') + + return min_fps diff --git a/examples/imagenet/multiedge_inference_bench/testenv/peak_memory.py b/examples/imagenet/multiedge_inference_bench/testenv/peak_memory.py new file mode 100644 index 00000000..c6e6dccb --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testenv/peak_memory.py @@ -0,0 +1,32 @@ +import sys +import os + +from sedna.common.class_factory import ClassType, ClassFactory + +import matplotlib.pyplot as plt + +__all__ = ('peak_memory') + +@ClassFactory.register(ClassType.GENERAL, alias="peak_memory") +def peak_power(y_true, y_pred, **kwargs): + mem_usage_per_device = y_pred.get("mem_usage_per_device") + plt.figure() + peak_mem = -sys.maxsize + for device, mem_list in mem_usage_per_device.items(): + plt.bar(device, max(mem_list), label=f'{device}') + peak_mem = max(peak_mem, max(mem_list)) + plt.axhline(y=peak_mem, color='red', linewidth=2, label='Peak Memory') + + plt.xticks(rotation=45) + plt.ylabel('Memory') + plt.xlabel('Device') + plt.legend() + + dir = './multiedge_inference_bench/workspace/classification_job/images/' + if not os.path.exists(dir): + os.makedirs(dir) + from datetime import datetime + now = datetime.now().strftime("%H_%M_%S") + plt.savefig(dir + 'peak_mem_per_device' + now + '.png') + + return peak_mem diff --git a/examples/imagenet/multiedge_inference_bench/testenv/peak_power.py b/examples/imagenet/multiedge_inference_bench/testenv/peak_power.py new file mode 100644 index 00000000..7618e714 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testenv/peak_power.py @@ -0,0 +1,32 @@ +import sys +import os + +from sedna.common.class_factory import ClassType, ClassFactory + +import matplotlib.pyplot as plt + +__all__ = ('peak_power') + +@ClassFactory.register(ClassType.GENERAL, alias="peak_power") +def peak_power(y_true, y_pred, **kwargs): + power_usage_per_device = y_pred.get("power_usage_per_device") + plt.figure() + peak_power = -sys.maxsize + for device, power_list in power_usage_per_device.items(): + plt.plot(power_list, label=device) + peak_power = max(peak_power, max(power_list)) + plt.axhline(y=peak_power, color='red', linewidth=2, label='Peak Power') + + plt.xticks(rotation=45) + plt.ylabel('Power') + plt.xlabel('Device') + plt.legend() + + dir = './multiedge_inference_bench/workspace/classification_job/images/' + if not os.path.exists(dir): + os.makedirs(dir) + from datetime import datetime + now = datetime.now().strftime("%H_%M_%S") + plt.savefig(dir + 'power_usage_per_device' + now + '.png') + + return peak_power diff --git a/examples/imagenet/multiedge_inference_bench/testenv/testenv.yaml b/examples/imagenet/multiedge_inference_bench/testenv/testenv.yaml new file mode 100644 index 00000000..f2bb70d2 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testenv/testenv.yaml @@ -0,0 +1,25 @@ +testenv: + # dataset configuration + dataset: + # the url address of train dataset index; string type; + train_url: "./dataset/train.txt" + # the url address of test dataset index; string type; + test_url: "./dataset/test.txt" + + # metrics configuration for test case's evaluation; list type; + metrics: + # metric name; string type; + - name: "accuracy" + # the url address of python file + url: "./examples/imagenet/multiedge_inference_bench/testenv/accuracy.py" + - name: "fps" + # the url address of python file + url: "./examples/imagenet/multiedge_inference_bench/testenv/fps.py" + - name: "peak_memory" + # the url address of python file + url: "./examples/imagenet/multiedge_inference_bench/testenv/peak_memory.py" + - name: "peak_power" + # the url address of python file + url: "./examples/imagenet/multiedge_inference_bench/testenv/peak_power.py" + devices: + - url : "./devices.yaml" \ No newline at end of file