From c58085136e160cd164013bd263be5a74a3e7999a Mon Sep 17 00:00:00 2001 From: Cindy Jiang <47068112+cindyyuanjiang@users.noreply.github.com> Date: Mon, 15 Jul 2024 11:59:59 -0700 Subject: [PATCH] Add internal CLI to generate instance descriptions for CSPs (#1137) * refactored code for adding new cli to generate instance description files for csps Signed-off-by: cindyyuanjiang * fixed python style Signed-off-by: cindyyuanjiang * addressed review feedback Signed-off-by: cindyyuanjiang * fix python style Signed-off-by: cindyyuanjiang * function return type Signed-off-by: cindyyuanjiang * simplified instance description json structure Signed-off-by: cindyyuanjiang * update json key to VCpuCount Signed-off-by: cindyyuanjiang * fixed case when user give non-exist output folder Signed-off-by: cindyyuanjiang * add gpu info for n1 series dataproc Signed-off-by: cindyyuanjiang * fixed python style Signed-off-by: cindyyuanjiang * cleaned up comments Signed-off-by: cindyyuanjiang * fix issue with databricks azure platform input Signed-off-by: cindyyuanjiang * update gpu count to list for consistency Signed-off-by: cindyyuanjiang * updated comment for gpu count to list Signed-off-by: cindyyuanjiang --------- Signed-off-by: cindyyuanjiang --- user_tools/pyproject.toml | 1 + .../cloud_api/databricks_azure.py | 52 ++++++-------- .../cloud_api/dataproc.py | 69 ++++++++++++++++++- .../src/spark_rapids_pytools/cloud_api/emr.py | 17 +++++ .../cloud_api/sp_types.py | 33 +++++++++ .../common/prop_manager.py | 10 ++- .../rapids/dev/instance_description.py | 64 +++++++++++++++++ .../src/spark_rapids_tools/cmdli/__init__.py | 6 +- .../spark_rapids_tools/cmdli/argprocessor.py | 30 +++++++- .../src/spark_rapids_tools/cmdli/dev_cli.py | 63 +++++++++++++++++ .../src/spark_rapids_tools/utils/util.py | 14 ++-- 11 files changed, 315 insertions(+), 44 deletions(-) create mode 100644 user_tools/src/spark_rapids_pytools/rapids/dev/instance_description.py create mode 100644 user_tools/src/spark_rapids_tools/cmdli/dev_cli.py diff --git a/user_tools/pyproject.toml b/user_tools/pyproject.toml index c6cf81795..0c4b8d19c 100644 --- a/user_tools/pyproject.toml +++ b/user_tools/pyproject.toml @@ -62,6 +62,7 @@ dynamic=["entry-points", "version"] [project.scripts] spark_rapids_user_tools = "spark_rapids_pytools.wrapper:main" spark_rapids = "spark_rapids_tools.cmdli.tools_cli:main" +spark_rapids_dev = "spark_rapids_tools.cmdli.dev_cli:main" [tool.setuptools] package-dir = {"" = "src"} diff --git a/user_tools/src/spark_rapids_pytools/cloud_api/databricks_azure.py b/user_tools/src/spark_rapids_pytools/cloud_api/databricks_azure.py index 0174af5eb..c44d1e367 100644 --- a/user_tools/src/spark_rapids_pytools/cloud_api/databricks_azure.py +++ b/user_tools/src/spark_rapids_pytools/cloud_api/databricks_azure.py @@ -174,38 +174,28 @@ def _build_cmd_scp_from_node(self, node: ClusterNode, src: str, dest: str) -> st dest] return Utils.gen_joined_str(' ', prefix_args) - def process_instances_description(self, raw_instances_description: str) -> dict: - processed_instances_description = {} - instances_description = JSONPropertiesContainer(prop_arg=raw_instances_description, file_load=False) - for instance in instances_description.props: - instance_dict = {} - v_cpus = 0 - memory_gb = 0 - gpus = 0 + def _process_instance_description(self, instance_descriptions: str) -> dict: + processed_instance_descriptions = {} + raw_instance_descriptions = JSONPropertiesContainer(prop_arg=instance_descriptions, file_load=False) + for instance in raw_instance_descriptions.props: if not instance['capabilities']: continue - for item in instance['capabilities']: - if item['name'] == 'vCPUs': - v_cpus = int(item['value']) - elif item['name'] == 'MemoryGB': - memory_gb = int(float(item['value']) * 1024) - elif item['name'] == 'GPUs': - gpus = int(item['value']) - instance_dict['VCpuInfo'] = {'DefaultVCpus': v_cpus} - instance_dict['MemoryInfo'] = {'SizeInMiB': memory_gb} - if gpus > 0: - gpu_list = [{'Name': '', 'Manufacturer': '', 'Count': gpus, 'MemoryInfo': {'SizeInMiB': 0}}] - instance_dict['GpuInfo'] = {'GPUs': gpu_list} - processed_instances_description[instance['name']] = instance_dict - return processed_instances_description - - def generate_instances_description(self, fpath: str): - cmd_params = ['az vm list-skus', - '--location', f'{self.get_region()}'] - raw_instances_description = self.run_sys_cmd(cmd_params) - json_instances_description = self.process_instances_description(raw_instances_description) - with open(fpath, 'w', encoding='UTF-8') as output_file: - json.dump(json_instances_description, output_file, indent=2) + instance_content = {} + gpu_count = 0 + for elem in instance['capabilities']: + if elem['name'] == 'vCPUs': + instance_content['VCpuCount'] = int(elem['value']) + elif elem['name'] == 'MemoryGB': + instance_content['MemoryInMB'] = int(float(elem['value']) * 1024) + elif elem['name'] == 'GPUs': + gpu_count = int(elem['value']) + if gpu_count > 0: + instance_content['GpuInfo'] = [{'Count': [gpu_count]}] + processed_instance_descriptions[instance['name']] = instance_content + return processed_instance_descriptions + + def get_instance_description_cli_params(self): + return ['az vm list-skus', '--location', f'{self.get_region()}'] def _build_platform_describe_node_instance(self, node: ClusterNode) -> list: pass @@ -224,7 +214,7 @@ def init_instances_description(self) -> str: fpath = FSUtil.build_path(cache_dir, 'azure-instances-catalog.json') if self._caches_expired(fpath): self.logger.info('Downloading the Azure instance type descriptions catalog') - self.generate_instances_description(fpath) + self.generate_instance_description(fpath) else: self.logger.info('The Azure instance type descriptions catalog is loaded from the cache') return fpath diff --git a/user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py b/user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py index 76447bfb8..79a73a884 100644 --- a/user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py +++ b/user_tools/src/spark_rapids_pytools/cloud_api/dataproc.py @@ -25,7 +25,7 @@ from spark_rapids_pytools.cloud_api.sp_types import PlatformBase, CMDDriverBase, \ ClusterBase, ClusterNode, SysInfo, GpuHWInfo, SparkNodeType, ClusterState, GpuDevice, \ NodeHWInfo, ClusterGetAccessor -from spark_rapids_pytools.common.prop_manager import JSONPropertiesContainer +from spark_rapids_pytools.common.prop_manager import JSONPropertiesContainer, is_valid_gpu_device from spark_rapids_pytools.common.sys_storage import FSUtil from spark_rapids_pytools.common.utilities import SysCmd, Utils from spark_rapids_pytools.pricing.dataproc_pricing import DataprocPriceProvider @@ -291,6 +291,73 @@ def get_submit_spark_job_cmd_for_cluster(self, cmd.extend(jar_args) return cmd + def _process_instance_description(self, instance_descriptions: str) -> dict: + def extract_gpu_name(gpu_description: str) -> str: + gpu_name = '' + for elem in gpu_description.split('-'): + if is_valid_gpu_device(elem): + gpu_name = elem + break + return gpu_name.upper() + + processed_instance_descriptions = {} + raw_instances_descriptions = JSONPropertiesContainer(prop_arg=instance_descriptions, file_load=False) + for instance in raw_instances_descriptions.props: + instance_content = {} + instance_content['VCpuCount'] = int(instance.get('guestCpus', -1)) + instance_content['MemoryInMB'] = int(instance.get('memoryMb', -1)) + if 'accelerators' in instance: + raw_accelerator_info = instance['accelerators'][0] + gpu_name = extract_gpu_name(raw_accelerator_info.get('guestAcceleratorType')) + if gpu_name != '': + gpu_count = int(raw_accelerator_info.get('guestAcceleratorCount', -1)) + gpu_info = {'Name': gpu_name, 'Count': [gpu_count]} + instance_content['GpuInfo'] = [gpu_info] + processed_instance_descriptions[instance.get('name')] = instance_content + + # for Dataproc, some instance types can attach customized GPU devices + # Ref: https://cloud.google.com/compute/docs/gpus#n1-gpus + for instance_name, instance_info in processed_instance_descriptions.items(): + if instance_name.startswith('n1-standard'): + if 'GpuInfo' not in instance_info: + instance_info['GpuInfo'] = [] + # N1 + T4 GPUs + if 1 <= instance_info['VCpuCount'] <= 48: + t4_gpu_info = {'Name': 'T4', 'Count': [1, 2, 4]} + else: # 48 < VCpuCount <= 96 + t4_gpu_info = {'Name': 'T4', 'Count': [4]} + instance_info['GpuInfo'].append(t4_gpu_info) + # N1 + P4 GPUs + if 1 <= instance_info['VCpuCount'] <= 24: + p4_gpu_info = {'Name': 'P4', 'Count': [1, 2, 4]} + elif 24 < instance_info['VCpuCount'] <= 48: + p4_gpu_info = {'Name': 'P4', 'Count': [2, 4]} + else: # 48 < VCpuCount <= 96 + p4_gpu_info = {'Name': 'P4', 'Count': [4]} + instance_info['GpuInfo'].append(p4_gpu_info) + # N1 + V100 GPUs + if 1 <= instance_info['VCpuCount'] <= 12: + v100_gpu_info = {'Name': 'V100', 'Count': [1, 2, 4, 8]} + elif 12 < instance_info['VCpuCount'] <= 24: + v100_gpu_info = {'Name': 'V100', 'Count': [2, 4, 8]} + elif 24 < instance_info['VCpuCount'] <= 48: + v100_gpu_info = {'Name': 'V100', 'Count': [4, 8]} + else: # 48 < VCpuCount <= 96 + v100_gpu_info = {'Name': 'V100', 'Count': [8]} + instance_info['GpuInfo'].append(v100_gpu_info) + # N1 + P100 GPUs + if 1 <= instance_info['VCpuCount'] <= 16: + p100_gpu_info = {'Name': 'P100', 'Count': [1, 2, 4]} + elif 16 < instance_info['VCpuCount'] <= 32: + p100_gpu_info = {'Name': 'P100', 'Count': [2, 4]} + else: # 32 < VCpuCount <= 96 + p100_gpu_info = {'Name': 'P100', 'Count': [4]} + instance_info['GpuInfo'].append(p100_gpu_info) + return processed_instance_descriptions + + def get_instance_description_cli_params(self) -> list: + return ['gcloud compute machine-types list', '--zones', f'{self.get_zone()}'] + @dataclass class DataprocNode(ClusterNode): diff --git a/user_tools/src/spark_rapids_pytools/cloud_api/emr.py b/user_tools/src/spark_rapids_pytools/cloud_api/emr.py index 32e8314be..dae6bdfba 100644 --- a/user_tools/src/spark_rapids_pytools/cloud_api/emr.py +++ b/user_tools/src/spark_rapids_pytools/cloud_api/emr.py @@ -249,6 +249,23 @@ def _exec_platform_describe_node_instance(self, node: ClusterNode) -> str: def get_submit_spark_job_cmd_for_cluster(self, cluster_name: str, submit_args: dict) -> List[str]: raise NotImplementedError + def _process_instance_description(self, instance_descriptions: str) -> dict: + processed_instance_descriptions = {} + raw_instances_descriptions = JSONPropertiesContainer(prop_arg=instance_descriptions, file_load=False) + for instance in raw_instances_descriptions.get_value('InstanceTypes'): + instance_content = {} + instance_content['VCpuCount'] = int(instance.get('VCpuInfo', {}).get('DefaultVCpus', -1)) + instance_content['MemoryInMB'] = int(instance.get('MemoryInfo', {}).get('SizeInMiB', -1)) + if 'GpuInfo' in instance: + gpu_name = instance['GpuInfo']['Gpus'][0]['Name'] + gpu_count = int(instance['GpuInfo']['Gpus'][0]['Count']) + instance_content['GpuInfo'] = [{'Name': gpu_name, 'Count': [gpu_count]}] + processed_instance_descriptions[instance.get('InstanceType')] = instance_content + return processed_instance_descriptions + + def get_instance_description_cli_params(self): + return ['aws ec2 describe-instance-types', '--region', f'{self.get_region()}'] + @dataclass class InstanceGroup: diff --git a/user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py b/user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py index eb2626122..33e5d2216 100644 --- a/user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py +++ b/user_tools/src/spark_rapids_pytools/cloud_api/sp_types.py @@ -605,6 +605,39 @@ def get_submit_spark_job_cmd_for_cluster(self, submit_args: dict) -> List[str]: raise NotImplementedError + def _process_instance_description(self, instance_descriptions: str) -> dict: + raise NotImplementedError + + def get_instance_description_cli_params(self) -> List[str]: + raise NotImplementedError + + def generate_instance_description(self, fpath: str) -> None: + """ + Generates CSP instance type descriptions and store them in a json file. + Json file entry example ('GpuInfo' is optional): + { + "instance_name": { + "VCpuCount": 000, + "MemoryInMB": 000, + "GpuInfo": [ + { + "Name": gpu_name, + "Count": [ + 000 + ] + } + ] + } + } + :param fpath: the output json file path. + :return: + """ + cmd_params = self.get_instance_description_cli_params() + raw_instance_descriptions = self.run_sys_cmd(cmd_params) + json_instance_descriptions = self._process_instance_description(raw_instance_descriptions) + with open(fpath, 'w', encoding='UTF-8') as output_file: + json.dump(json_instance_descriptions, output_file, indent=2) + def __post_init__(self): self.logger = ToolLogging.get_and_setup_logger('rapids.tools.cmd_driver') diff --git a/user_tools/src/spark_rapids_pytools/common/prop_manager.py b/user_tools/src/spark_rapids_pytools/common/prop_manager.py index fdafcc71f..e8e2b2e60 100644 --- a/user_tools/src/spark_rapids_pytools/common/prop_manager.py +++ b/user_tools/src/spark_rapids_pytools/common/prop_manager.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -46,6 +46,14 @@ def to_camel_case(word: str) -> str: return res +def get_gpu_device_list() -> list: + return ['T4', 'V100', 'K80', 'A100', 'P100', 'A10', 'A10G', 'P4', 'L4', 'H100'] + + +def is_valid_gpu_device(val) -> bool: + return val.upper() in get_gpu_device_list() + + @dataclass class AbstractPropertiesContainer(object): """ diff --git a/user_tools/src/spark_rapids_pytools/rapids/dev/instance_description.py b/user_tools/src/spark_rapids_pytools/rapids/dev/instance_description.py new file mode 100644 index 000000000..6284375a8 --- /dev/null +++ b/user_tools/src/spark_rapids_pytools/rapids/dev/instance_description.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation class representing wrapper around the RAPIDS acceleration Prediction tool.""" + +import os +from dataclasses import dataclass + +from spark_rapids_pytools.rapids.rapids_tool import RapidsTool +from spark_rapids_pytools.common.sys_storage import FSUtil +from spark_rapids_pytools.common.utilities import Utils +from spark_rapids_pytools.cloud_api.sp_types import get_platform +from spark_rapids_pytools.rapids.tool_ctxt import ToolContext + + +@dataclass +class InstanceDescription(RapidsTool): + """Wrapper layer around Generate_Instance_Description Tool.""" + + name = 'instance_description' + instance_file = '' # local absolute path of the instance description file + + def _connect_to_execution_cluster(self) -> None: + pass + + def _collect_result(self) -> None: + pass + + def _archive_phase(self) -> None: + pass + + def _init_ctxt(self) -> None: + """ + Initialize the tool context, reusing qualification configurations. + """ + self.config_path = Utils.resource_path('qualification-conf.yaml') + self.ctxt = ToolContext(platform_cls=get_platform(self.platform_type), + platform_opts=self.wrapper_options.get('platformOpts'), + prop_arg=self.config_path, + name=self.name) + + def _process_output_args(self) -> None: + self.logger.debug('Processing Output Arguments') + if self.output_folder is None: + self.output_folder = Utils.get_rapids_tools_env('OUTPUT_DIRECTORY', os.getcwd()) + # make sure that output_folder is being absolute + self.output_folder = FSUtil.get_abs_path(self.output_folder) + FSUtil.make_dirs(self.output_folder) + self.instance_file = f'{self.output_folder}/{self.platform_type}-instance-catalog.json' + self.logger.debug('Instance description output will be saved in: %s', self.instance_file) + + def _run_rapids_tool(self) -> None: + self.ctxt.platform.cli.generate_instance_description(self.instance_file) diff --git a/user_tools/src/spark_rapids_tools/cmdli/__init__.py b/user_tools/src/spark_rapids_tools/cmdli/__init__.py index 1344d5a37..051a02aec 100644 --- a/user_tools/src/spark_rapids_tools/cmdli/__init__.py +++ b/user_tools/src/spark_rapids_tools/cmdli/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,8 +14,10 @@ """init file of the user CLI used to run the tools""" +from .dev_cli import DevCLI from .tools_cli import ToolsCLI __all__ = [ - 'ToolsCLI' + 'ToolsCLI', + 'DevCLI' ] diff --git a/user_tools/src/spark_rapids_tools/cmdli/argprocessor.py b/user_tools/src/spark_rapids_tools/cmdli/argprocessor.py index 21142eeb0..d3cc5c772 100644 --- a/user_tools/src/spark_rapids_tools/cmdli/argprocessor.py +++ b/user_tools/src/spark_rapids_tools/cmdli/argprocessor.py @@ -109,8 +109,8 @@ class AbsToolUserArgModel: validator_name: ClassVar[str] = None @classmethod - def create_tool_args(cls, validator_arg: Union[str, dict], - *args: Any, **kwargs: Any) -> Optional[dict]: + def create_tool_args(cls, validator_arg: Union[str, dict], *args: Any, cli_class: str = 'ToolsCLI', + cli_name: str = 'spark_rapids', **kwargs: Any) -> Optional[dict]: """ A factory method to create the tool arguments based on the validator argument. :param validator_arg: Union type to accept either a dictionary or a string. This is required @@ -134,7 +134,7 @@ def create_tool_args(cls, validator_arg: Union[str, dict], return new_obj.build_tools_args() except (ValidationError, PydanticCustomError) as e: impl_class.logger.error('Validation err: %s\n', e) - dump_tool_usage(tool_name) + dump_tool_usage(cli_class, cli_name, tool_name) return None def get_eventlogs(self) -> Optional[str]: @@ -756,3 +756,27 @@ def build_tools_args(self) -> dict: 'base_model': self.base_model, 'platformOpts': {}, } + + +@dataclass +@register_tool_arg_validator('generate_instance_description') +class InstanceDescriptionUserArgModel(AbsToolUserArgModel): + """ + Represents the arguments to run the generate_instance_description tool. + """ + target_platform: str = None + accepted_platforms = ['dataproc', 'emr', 'databricks-azure'] + + def validate_platform(self) -> None: + if self.target_platform not in self.accepted_platforms: + raise PydanticCustomError('invalid_argument', + f'Platform \'{self.target_platform}\' is not in ' + + f'accepted platform list: {self.accepted_platforms}.') + + def build_tools_args(self) -> dict: + self.validate_platform() + return { + 'targetPlatform': CspEnv(self.target_platform), + 'output_folder': self.output_folder, + 'platformOpts': {}, + } diff --git a/user_tools/src/spark_rapids_tools/cmdli/dev_cli.py b/user_tools/src/spark_rapids_tools/cmdli/dev_cli.py new file mode 100644 index 000000000..f03064377 --- /dev/null +++ b/user_tools/src/spark_rapids_tools/cmdli/dev_cli.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CLI to run development related tools.""" + + +import fire + +from spark_rapids_tools.cmdli.argprocessor import AbsToolUserArgModel +from spark_rapids_tools.utils.util import gen_app_banner, init_environment +from spark_rapids_pytools.common.utilities import ToolLogging +from spark_rapids_pytools.rapids.dev.instance_description import InstanceDescription + + +class DevCLI(object): # pylint: disable=too-few-public-methods + """CLI to run development related tools (for internal use only).""" + + def generate_instance_description(self, + platform: str = None, + output_folder: str = None): + """The generate_instance_description cmd takes a platform and generates a json file with all the + instance type descriptions for that CSP platform. + + :param platform: defines one of the following "dataproc", "emr", and "databricks-azure". + :param output_folder: local path to store the output. + """ + # Since this is an internal tool, we enable debug mode by default + ToolLogging.enable_debug_mode() + + init_environment('generate_instance_description') + + instance_description_args = AbsToolUserArgModel.create_tool_args('generate_instance_description', + cli_class='DevCLI', + cli_name='spark_rapids_dev', + target_platform=platform, + output_folder=output_folder) + if instance_description_args: + tool_obj = InstanceDescription(platform_type=instance_description_args['targetPlatform'], + output_folder=instance_description_args['output_folder'], + wrapper_options=instance_description_args) + tool_obj.launch() + + +def main(): + # Make Python Fire not use a pager when it prints a help text + fire.core.Display = lambda lines, out: out.write('\n'.join(lines) + '\n') + print(gen_app_banner('Development')) + fire.Fire(DevCLI()) + + +if __name__ == '__main__': + main() diff --git a/user_tools/src/spark_rapids_tools/utils/util.py b/user_tools/src/spark_rapids_tools/utils/util.py index 327e92861..ab90c34ea 100644 --- a/user_tools/src/spark_rapids_tools/utils/util.py +++ b/user_tools/src/spark_rapids_tools/utils/util.py @@ -97,13 +97,13 @@ def to_snake_case(word: str) -> str: return ''.join(['_' + i.lower() if i.isupper() else i for i in word]).lstrip('_') -def dump_tool_usage(tool_name: Optional[str], raise_sys_exit: Optional[bool] = True): - imported_module = __import__('spark_rapids_tools.cmdli', globals(), locals(), ['ToolsCLI']) - wrapper_clzz = getattr(imported_module, 'ToolsCLI') - help_name = 'spark_rapids' +def dump_tool_usage(cli_class: Optional[str], cli_name: Optional[str], tool_name: Optional[str], + raise_sys_exit: Optional[bool] = True) -> None: + imported_module = __import__('spark_rapids_tools.cmdli', globals(), locals(), [cli_class]) + wrapper_clzz = getattr(imported_module, cli_class) usage_cmd = f'{tool_name} -- --help' try: - fire.Fire(wrapper_clzz(), name=help_name, command=usage_cmd) + fire.Fire(wrapper_clzz(), name=cli_name, command=usage_cmd) except fire.core.FireExit: # ignore the sys.exit(0) thrown by the help usage. # ideally we want to exit with error @@ -112,12 +112,13 @@ def dump_tool_usage(tool_name: Optional[str], raise_sys_exit: Optional[bool] = T sys.exit(1) -def gen_app_banner() -> str: +def gen_app_banner(mode: str = '') -> str: """ ASCII Art is generated by an online Test-to-ASCII Art generator tool https://patorjk.com/software/taag :return: a string representing the banner of the user tools including the version """ + tool_mode_note = '' if mode == '' else f'{mode} CMD' c_ver = spark_rapids_pytools.__version__ return rf""" @@ -135,6 +136,7 @@ def gen_app_banner() -> str: * \____/____/\___/_/ /_/ \____/\____/_/____/ * * * * Version. {c_ver} * +*{'':>38}{tool_mode_note:<28}* * * * NVIDIA Corporation * * spark-rapids-support@nvidia.com *