diff --git a/genai-perf/genai_perf/metrics/llm_metrics.py b/genai-perf/genai_perf/metrics/llm_metrics.py
index 13dff8a6..7dd00ba7 100755
--- a/genai-perf/genai_perf/metrics/llm_metrics.py
+++ b/genai-perf/genai_perf/metrics/llm_metrics.py
@@ -54,7 +54,7 @@ def __init__(
         time_to_first_tokens: List[int] = [],
         inter_token_latencies: List[int] = [],
         output_token_throughputs: List[float] = [],
-        output_token_throughputs_per_request: List[int] = [],
+        output_token_throughputs_per_request: List[float] = [],
         output_sequence_lengths: List[int] = [],
         input_sequence_lengths: List[int] = [],
         chunked_inter_token_latencies: List[List[int]] = [[]],
diff --git a/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py b/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py
index 183f21fd..05735518 100755
--- a/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py
+++ b/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py
@@ -204,6 +204,8 @@ def _get_input_token_count(self, req_inputs: dict) -> int:
         """Deserialize the request input and return tokenized inputs."""
         if self._service_kind == "triton":
             input_text = req_inputs["text_input"]
+        elif self._service_kind == "triton_c_api":
+            return len(req_inputs["input_ids"])  # no tokenizer required
         elif self._service_kind == "openai":
             input_text = self._get_openai_input_text(req_inputs)
         else:
@@ -232,6 +234,9 @@ def _get_output_token_counts(self, res_outputs: List[Dict]) -> Tuple[List[int], int]:
         """Return response-level token counts and total token count."""
         if self._service_kind == "triton":
             output_texts = self._get_triton_output_tokens(res_outputs)
+        elif self._service_kind == "triton_c_api":
+            # No tokenizer is needed to get the token counts.
+            return self._get_tensorrtllm_engine_token_counts(res_outputs)
         elif self._service_kind == "openai":
             output_texts = self._get_openai_output_tokens(res_outputs)
         else:
@@ -243,6 +248,17 @@ def _get_output_token_counts(self, res_outputs: List[Dict]) -> Tuple[List[int], int]:
         output_token_counts = list(map(len, output_tokens))
         return output_token_counts, full_text_token_count
 
+    def _get_tensorrtllm_engine_token_counts(
+        self, res_outputs: List[Dict]
+    ) -> Tuple[List[int], int]:
+        token_ids = []
+        for r in res_outputs:
+            if isinstance(r["output_ids"], list):
+                token_ids += r["output_ids"]
+            else:
+                token_ids.append(r["output_ids"])
+        return token_ids, len(token_ids)
+
     def _get_triton_output_tokens(self, res_outputs: List[Dict]) -> List[str]:
         """Return a list of Triton response texts."""
         return [r["text_output"] for r in res_outputs]
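For the triton_c_api path above, token counts come straight from the exported tensors: the request inputs carry input_ids and each streamed response carries output_ids, which may be a single id or a list of ids, so no tokenizer round-trip is involved. A minimal standalone sketch of the flattening the new helper performs (the function name and the sample dicts are illustrative, not part of the patch):

    from typing import Dict, List, Tuple

    def flatten_output_ids(res_outputs: List[Dict]) -> Tuple[List[int], int]:
        # Collect output_ids across streamed responses, accepting int or list entries.
        token_ids: List[int] = []
        for r in res_outputs:
            if isinstance(r["output_ids"], list):
                token_ids += r["output_ids"]
            else:
                token_ids.append(r["output_ids"])
        return token_ids, len(token_ids)

    # Three streamed responses with one token id each -> 3 output tokens total.
    print(flatten_output_ids([{"output_ids": 123}, {"output_ids": 456}, {"output_ids": 789}]))
    # ([123, 456, 789], 3)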
diff --git a/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py b/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py
index 74eb48a2..245afb2c 100755
--- a/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py
+++ b/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py
@@ -98,6 +98,8 @@ def _get_profile_metadata(self, data: dict) -> None:
 
         elif self._service_kind == "triton":
             self._response_format = ResponseFormat.TRITON
+        elif self._service_kind == "triton_c_api":
+            pass  # ignore
         else:
             raise ValueError(f"Unknown service kind: {self._service_kind}")
 
diff --git a/genai-perf/genai_perf/wrapper.py b/genai-perf/genai_perf/wrapper.py
index fe9abdbb..c7b27a6b 100644
--- a/genai-perf/genai_perf/wrapper.py
+++ b/genai-perf/genai_perf/wrapper.py
@@ -110,6 +110,9 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s
             f"--input-data",
             f"{args.artifact_dir / DEFAULT_INPUT_DATA_JSON}",
         ]
+        cmd += Profiler.add_protocol_args(args)
+        cmd += Profiler.add_inference_load_args(args)
+
         for arg, value in vars(args).items():
             if arg in skip_args:
                 pass
@@ -122,13 +125,10 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s
                     cmd += [f"-{arg}"]
                 else:
                     cmd += [f"--{arg}"]
-
-            # (TPA-237) GAP needs to call PA using triton_c_api service kind.
-            # Currently, it just calls using triton service kind to verify that
-            # it runs.
+            # GAP needs to call PA using the triton_c_api service kind when
+            # running against a tensorrtllm engine.
             elif arg == "service_kind" and value == "tensorrtllm_engine":
-                cmd += ["--service-kind", "triton"]
-                args.service_kind = "triton"
+                cmd += ["--service-kind", "triton_c_api", "--streaming"]
             else:
                 if len(arg) == 1:
                     cmd += [f"-{arg}", f"{value}"]
@@ -136,9 +136,6 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s
                     arg = utils.convert_option_name(arg)
                     cmd += [f"--{arg}", f"{value}"]
 
-        cmd += Profiler.add_protocol_args(args)
-        cmd += Profiler.add_inference_load_args(args)
-
         if extra_args is not None:
            for arg in extra_args:
                cmd += [f"{arg}"]
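The net effect of the wrapper change: when GenAI-Perf is invoked with --service-kind tensorrtllm_engine, the generated Perf Analyzer command now uses the triton_c_api service kind with streaming enabled, and the protocol and inference-load arguments are appended before the per-argument loop rather than after it. A simplified sketch of just the service-kind mapping (the helper name is illustrative; the real logic lives inside build_cmd and also honors skip_args and the boolean/short-option cases):

    from typing import List

    def map_arg(arg: str, value: str) -> List[str]:
        # Mirrors the tensorrtllm_engine branch added to build_cmd.
        if arg == "service_kind" and value == "tensorrtllm_engine":
            return ["--service-kind", "triton_c_api", "--streaming"]
        return [f"--{arg}", f"{value}"]

    print(map_arg("service_kind", "tensorrtllm_engine"))
    # ['--service-kind', 'triton_c_api', '--streaming']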
expected_metrics", + [ + ( + "concurrency", + "10", + { + "request_latencies": [7, 9], + "request_throughputs": [1 / ns_to_sec(5)], + "time_to_first_tokens": [2, 2], + "inter_token_latencies": [2, 4], + "output_token_throughputs_per_request": [ + 3 / ns_to_sec(7), + 1 / ns_to_sec(3), + ], + "output_token_throughputs": [3 / ns_to_sec(5)], + "output_sequence_lengths": [3, 3], + "input_sequence_lengths": [3, 4], + }, + ), + ( + "request_rate", + "2.0", + { + "request_latencies": [13, 8], + "request_throughputs": [2 / ns_to_sec(15)], + "time_to_first_tokens": [2, 3], + "inter_token_latencies": [4, 2], + "output_token_throughputs_per_request": [ + 4 / ns_to_sec(13), + 3 / ns_to_sec(8), + ], + "output_token_throughputs": [7 / ns_to_sec(15)], + "output_sequence_lengths": [4, 3], + "input_sequence_lengths": [3, 4], + }, + ), + ], + ) + def test_tensorrtllm_engine_llm_profile_data( + self, + mock_read_write: pytest.MonkeyPatch, + infer_mode, + load_level, + expected_metrics, + ) -> None: + """Collect LLM metrics from profile export data and check values. + + Metrics + * request_latencies + - experiment 1: [8 - 1, 11 - 2] = [7, 9] + - experiment 2: [18 - 5, 11 -3] = [13, 8] + * request_throughputs + - experiment 1: [2/(11 - 1)] = [1/5] + - experiment 2: [2/(18 - 3)] = [2/15] + * time to first tokens + - experiment 1: [3 - 1, 4 - 2] = [2, 2] + - experiment 2: [7 - 5, 6 - 3] = [2, 3] + * inter token latencies + - experiment 1: [((8 - 1) - 2)/(3 - 1), ((11 - 2) - 2)/(3 - 1)] + : [2.5, 3.5] + : [2, 4] # rounded + - experiment 2: [((18 - 5) - 2)/(4 - 1), ((11 - 3) - 3)/(3 - 1)] + : [11/3, 2.5] + : [4, 2] # rounded + * output token throughputs per request + - experiment 1: [3/(8 - 1), 3/(11 - 2)] = [3/7, 1/3] + - experiment 2: [4/(18 - 5), 3/(11 - 3)] = [4/13, 3/8] + * output token throughputs + - experiment 1: [(3 + 3)/(11 - 1)] = [3/5] + - experiment 2: [(4 + 3)/(18 - 3)] = [7/15] + * output sequence lengths + - experiment 1: [3, 3] + - experiment 2: [4, 3] + * input sequence lengths + - experiment 1: [3, 4] + - experiment 2: [3, 4] + """ + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("tensorrtllm_engine_profile_export.json"), + tokenizer=tokenizer, + ) + + statistics = pd.get_statistics(infer_mode=infer_mode, load_level=load_level) + metrics = cast(LLMMetrics, statistics.metrics) + + expected_metrics = LLMMetrics(**expected_metrics) + expected_statistics = Statistics(expected_metrics) + + check_llm_metrics(metrics, expected_metrics) + check_statistics(statistics, expected_statistics) + + # check non-existing profile data + with pytest.raises(KeyError): + pd.get_statistics(infer_mode="concurrency", load_level="30") + def test_merged_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None: """Test merging the multiple sse response.""" res_timestamps = [0, 1, 2, 3] @@ -740,3 +863,147 @@ def test_empty_response(self, mock_read_write: pytest.MonkeyPatch) -> None: }, ], } + + tensorrtllm_engine_profile_data = { + "service_kind": "triton_c_api", + "endpoint": "", + "experiments": [ + { + "experiment": { + "mode": "concurrency", + "value": 10, + }, + "requests": [ + { + "timestamp": 1, + "request_inputs": { + "streaming": True, + "request_output_len": 3, + "min_length": 3, + "input_lengths": 3, + "input_ids": [ + 111, + 222, + 333, + ], + }, + "response_timestamps": [3, 5, 8], + "response_outputs": [ + { + "output_log_probs": [0, 0], + "output_ids": 123, + }, + { + "output_log_probs": [0, 0], + "output_ids": 456, + }, + { + "output_ids": 789, + }, + ], + }, 
+                    {
+                        "timestamp": 2,
+                        "request_inputs": {
+                            "streaming": True,
+                            "request_output_len": 3,
+                            "min_length": 3,
+                            "input_lengths": 4,
+                            "input_ids": [
+                                111,
+                                222,
+                                333,
+                                444,
+                            ],
+                        },
+                        "response_timestamps": [4, 7, 11],
+                        "response_outputs": [
+                            {
+                                "output_log_probs": [0, 0],
+                                "output_ids": 123,
+                            },
+                            {
+                                "output_log_probs": [0, 0],
+                                "output_ids": 456,
+                            },
+                            {
+                                "output_log_probs": [0, 0],
+                                "output_ids": 789,
+                            },
+                        ],
+                    },
+                ],
+            },
+            {
+                "experiment": {
+                    "mode": "request_rate",
+                    "value": 2.0,
+                },
+                "requests": [
+                    {
+                        "timestamp": 5,
+                        "request_inputs": {
+                            "streaming": True,
+                            "request_output_len": 4,
+                            "min_length": 4,
+                            "input_lengths": 3,
+                            "input_ids": [
+                                111,
+                                222,
+                                333,
+                            ],
+                        },
+                        "response_timestamps": [7, 8, 13, 18],
+                        "response_outputs": [
+                            {
+                                "output_log_probs": [0, 0],
+                                "output_ids": 123,
+                            },
+                            {
+                                "output_log_probs": [0, 0],
+                                "output_ids": 456,
+                            },
+                            {
+                                "output_log_probs": [0, 0],
+                                "output_ids": 789,
+                            },
+                            {
+                                "output_log_probs": [0, 0],
+                                "output_ids": 1011,
+                            },
+                        ],
+                    },
+                    {
+                        "timestamp": 3,
+                        "request_inputs": {
+                            "streaming": True,
+                            "request_output_len": 3,
+                            "min_length": 3,
+                            "input_lengths": 4,
+                            "input_ids": [
+                                111,
+                                222,
+                                333,
+                                444,
+                            ],
+                        },
+                        "response_timestamps": [6, 8, 11],
+                        "response_outputs": [
+                            {
+                                "output_log_probs": [0, 0],
+                                "output_ids": 123,
+                            },
+                            {
+                                "output_log_probs": [0, 0],
+                                "output_ids": 456,
+                            },
+                            {
+                                "output_log_probs": [0, 0],
+                                "output_ids": 789,
+                            },
+                        ],
+                    },
+                ],
+            },
+        ],
+    }
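Assuming a profile export shaped like tensorrtllm_engine_profile_data above has been written to disk, the parser can be exercised outside the test harness in roughly the way the new test does; a minimal usage sketch (the export file name comes from the test fixture, and the load level matches the concurrency-10 experiment):

    from pathlib import Path

    from genai_perf.profile_data_parser import LLMProfileDataParser
    from genai_perf.tokenizer import DEFAULT_TOKENIZER, get_tokenizer

    # The constructor still takes a tokenizer, but the triton_c_api token-count
    # paths added in this patch read input_ids/output_ids directly instead of
    # tokenizing text.
    pd = LLMProfileDataParser(
        filename=Path("tensorrtllm_engine_profile_export.json"),
        tokenizer=get_tokenizer(DEFAULT_TOKENIZER),
    )
    stats = pd.get_statistics(infer_mode="concurrency", load_level="10")
    print(stats.metrics.output_sequence_lengths)  # [3, 3] for the sample data above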