diff --git a/src/benchmark_llm_serving/utils_args.py b/src/benchmark_llm_serving/utils_args.py
index 2b3bfa3..aaabdb7 100644
--- a/src/benchmark_llm_serving/utils_args.py
+++ b/src/benchmark_llm_serving/utils_args.py
@@ -31,6 +31,7 @@ def get_parser_base_arguments() -> argparse.ArgumentParser:
     parser.add_argument("--backend", type=str, default="happy_vllm", help="The backend of the API we query")
     parser.add_argument("--model", type=str, help="The name of the model needed to query the completions API")
     parser.add_argument("--model-name", type=str, help="The name of the model to be displayed in the graphs")
+    parser.add_argument("--gpu-name", type=str, help="The name of the GPU on which the model is")
     parser.add_argument("--max-duration", type=int, default=900, help="The maximal duration (in s) between the beginning of the queries and the end of the queries")
     parser.add_argument("--min-duration", type=int, help="The minimal duration during which the benchmark should run if there are still some prompts available")
     parser.add_argument("--target-queries-nb", type=int, help="If min-duration is reached and this number is reached, stop the benchmark")
diff --git a/tests/test_utils_args.py b/tests/test_utils_args.py
index ea9726e..61e9012 100644
--- a/tests/test_utils_args.py
+++ b/tests/test_utils_args.py
@@ -6,7 +6,7 @@ def test_get_parser_base_arguments():
     parser = utils_args.get_parser_base_arguments()
     base_arguments = {"--dataset-folder", "--base-url", "--host", "--port", "--step-live-metrics", "--max-queries",
                       "--completions-endpoint", "--metrics-endpoint", "--info-endpoint", "--launch-arguments-endpoint",
-                      "--backend", "--model", "--max-duration", "--min-duration", "--target-queries-nb", "--help", "-h", "--model-name"}
+                      "--backend", "--model", "--max-duration", "--min-duration", "--target-queries-nb", "--help", "-h", "--model-name", "--gpu-name"}
     assert set(parser.__dict__["_option_string_actions"]) == base_arguments
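
For context, a minimal sketch of how the new flag surfaces once the parser from the diff is built. It assumes the package is importable as benchmark_llm_serving.utils_args (matching the file path above); the argument values passed on the command line are illustrative only, not taken from the repository.

    # Sketch: exercising the base parser with the new --gpu-name flag.
    # The values "my-model" and "A100" below are hypothetical examples.
    from benchmark_llm_serving import utils_args

    parser = utils_args.get_parser_base_arguments()
    args = parser.parse_args(["--model", "my-model", "--gpu-name", "A100"])

    # argparse maps the option string "--gpu-name" to the attribute gpu_name
    print(args.gpu_name)  # -> "A100"

Because argparse stores every registered option string in _option_string_actions, the updated test simply adds "--gpu-name" to the expected set, which is why only the one set literal changes in tests/test_utils_args.py.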