Support profile data parsing for tensorrtllm engine service kind (#33)
* support parsing tensorrtllm engine profile response

* add test

* refactor the test

* update types and names

* fix pre-commit

* run PA with triton c api

* more clean up on the tests

* fix codeql

* address feedback
nv-hwoo authored Aug 9, 2024
1 parent e258b28 commit 5f288a0
Showing 5 changed files with 293 additions and 11 deletions.
2 changes: 1 addition & 1 deletion genai-perf/genai_perf/metrics/llm_metrics.py
@@ -54,7 +54,7 @@ def __init__(
time_to_first_tokens: List[int] = [],
inter_token_latencies: List[int] = [],
output_token_throughputs: List[float] = [],
- output_token_throughputs_per_request: List[int] = [],
+ output_token_throughputs_per_request: List[float] = [],
output_sequence_lengths: List[int] = [],
input_sequence_lengths: List[int] = [],
chunked_inter_token_latencies: List[List[int]] = [[]],
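
A note on the type change above: per-request output token throughput is an integer token count divided by the request latency in seconds, which is generally fractional, hence List[float]. A minimal sketch of the arithmetic, with values borrowed from the test fixture added later in this commit (the variable names are illustrative, not genai-perf API):

# Illustrative arithmetic only; not genai-perf code.
request_start_ns = 1          # request timestamp (ns)
last_response_ns = 8          # final response timestamp (ns)
output_sequence_length = 3    # tokens produced by the request

latency_sec = (last_response_ns - request_start_ns) / 1e9
throughput = output_sequence_length / latency_sec
print(throughput)             # ~4.29e8 tokens/s -- fractional, so it needs a float
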
@@ -204,6 +204,8 @@ def _get_input_token_count(self, req_inputs: dict) -> int:
"""Deserialize the request input and return tokenized inputs."""
if self._service_kind == "triton":
input_text = req_inputs["text_input"]
elif self._service_kind == "triton_c_api":
return len(req_inputs["input_ids"]) # no tokenizer required
elif self._service_kind == "openai":
input_text = self._get_openai_input_text(req_inputs)
else:
@@ -232,6 +234,9 @@ def _get_output_token_counts(
"""Return response-level token counts and total token count."""
if self._service_kind == "triton":
output_texts = self._get_triton_output_tokens(res_outputs)
elif self._service_kind == "triton_c_api":
# No tokenizer is needed to get the token counts.
return self._get_tensorrtllm_engine_token_counts(res_outputs)
elif self._service_kind == "openai":
output_texts = self._get_openai_output_tokens(res_outputs)
else:
@@ -243,6 +248,17 @@
output_token_counts = list(map(len, output_tokens))
return output_token_counts, full_text_token_count

def _get_tensorrtllm_engine_token_counts(
self, res_outputs: List[Dict]
) -> Tuple[List[int], int]:
token_ids = []
for r in res_outputs:
if isinstance(r["output_ids"], list):
token_ids += r["output_ids"]
else:
token_ids.append(r["output_ids"])
return token_ids, len(token_ids)

def _get_triton_output_tokens(self, res_outputs: List[Dict]) -> List[str]:
"""Return a list of Triton response texts."""
return [r["text_output"] for r in res_outputs]
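
For the triton_c_api (TensorRT-LLM engine) service kind, no tokenizer is involved on either side: the request already carries input_ids and each response carries output_ids, which may be a scalar or a list. A small, self-contained sketch of the counting logic above, using made-up records shaped like the test fixture added below (not a verbatim call into LLMProfileDataParser):

# Illustrative only; record shapes mirror the tensorrtllm engine profile export.
from typing import Dict, List

req_inputs = {"input_ids": [111, 222, 333]}   # token IDs; no tokenizer needed
res_outputs: List[Dict] = [
    {"output_ids": 123},                      # scalar form
    {"output_ids": [456, 789]},               # list form
]

input_token_count = len(req_inputs["input_ids"])          # -> 3

token_ids: List[int] = []
for r in res_outputs:
    ids = r["output_ids"]
    token_ids += ids if isinstance(ids, list) else [ids]  # flatten scalars and lists

output_token_count = len(token_ids)                       # -> 3
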
@@ -98,6 +98,8 @@ def _get_profile_metadata(self, data: dict) -> None:

elif self._service_kind == "triton":
self._response_format = ResponseFormat.TRITON
elif self._service_kind == "triton_c_api":
pass # ignore
else:
raise ValueError(f"Unknown service kind: {self._service_kind}")

15 changes: 6 additions & 9 deletions genai-perf/genai_perf/wrapper.py
@@ -110,6 +110,9 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s
f"--input-data",
f"{args.artifact_dir / DEFAULT_INPUT_DATA_JSON}",
]
+ cmd += Profiler.add_protocol_args(args)
+ cmd += Profiler.add_inference_load_args(args)

for arg, value in vars(args).items():
if arg in skip_args:
pass
@@ -122,23 +125,17 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s
cmd += [f"-{arg}"]
else:
cmd += [f"--{arg}"]

- # (TPA-237) GAP needs to call PA using triton_c_api service kind.
- # Currently, it just calls using triton service kind to verify that
- # it runs.
+ # GAP needs to call PA using triton_c_api service kind when running
+ # against tensorrtllm engine.
elif arg == "service_kind" and value == "tensorrtllm_engine":
- cmd += ["--service-kind", "triton"]
- args.service_kind = "triton"
+ cmd += ["--service-kind", "triton_c_api", "--streaming"]
else:
if len(arg) == 1:
cmd += [f"-{arg}", f"{value}"]
else:
arg = utils.convert_option_name(arg)
cmd += [f"--{arg}", f"{value}"]

- cmd += Profiler.add_protocol_args(args)
- cmd += Profiler.add_inference_load_args(args)

if extra_args is not None:
for arg in extra_args:
cmd += [f"{arg}"]
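
Before this change, a tensorrtllm_engine run was forwarded to Perf Analyzer with the placeholder --service-kind triton (and args.service_kind was mutated to match); it now maps directly onto Triton's C API backend with streaming enabled. A hedged sketch of that mapping follows; the perf_analyzer binary name and the shape of the final command are assumptions for illustration, not the exact command genai-perf builds:

# Hypothetical illustration of the service-kind mapping; not actual genai-perf output.
service_kind = "tensorrtllm_engine"   # what the user asked genai-perf for

cmd = ["perf_analyzer"]
if service_kind == "tensorrtllm_engine":
    # profile the TensorRT-LLM engine through Triton's C API, with streaming
    cmd += ["--service-kind", "triton_c_api", "--streaming"]
else:
    cmd += ["--service-kind", service_kind]

print(" ".join(cmd))  # perf_analyzer --service-kind triton_c_api --streaming
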
269 changes: 268 additions & 1 deletion genai-perf/tests/test_llm_profile_data_parser.py
@@ -27,11 +27,12 @@
import json
from io import StringIO
from pathlib import Path
- from typing import Any, List, Union
+ from typing import Any, List, Union, cast

import numpy as np
import pytest
from genai_perf.metrics import LLMMetrics
from genai_perf.metrics.statistics import Statistics
from genai_perf.profile_data_parser import LLMProfileDataParser
from genai_perf.tokenizer import DEFAULT_TOKENIZER, get_tokenizer

@@ -41,6 +41,28 @@ def ns_to_sec(ns: int) -> Union[int, float]:
return ns / 1e9


def check_statistics(s1: Statistics, s2: Statistics) -> None:
s1_dict = s1.stats_dict
s2_dict = s2.stats_dict
for metric in s1_dict.keys():
for stat_name, value in s1_dict[metric].items():
if stat_name != "unit":
assert s2_dict[metric][stat_name] == pytest.approx(value)


def check_llm_metrics(m1: LLMMetrics, m2: LLMMetrics) -> None:
assert m1.request_latencies == m2.request_latencies
assert m1.request_throughputs == pytest.approx(m2.request_throughputs)
assert m1.time_to_first_tokens == m2.time_to_first_tokens
assert m1.inter_token_latencies == m2.inter_token_latencies
assert m1.output_token_throughputs_per_request == pytest.approx(
m2.output_token_throughputs_per_request
)
assert m1.output_token_throughputs == pytest.approx(m2.output_token_throughputs)
assert m1.output_sequence_lengths == m2.output_sequence_lengths
assert m1.input_sequence_lengths == m2.input_sequence_lengths


class TestLLMProfileDataParser:
@pytest.fixture
def mock_read_write(self, monkeypatch: pytest.MonkeyPatch) -> List[str]:
@@ -74,6 +97,9 @@ def write(self: Any, content: str) -> int:
elif filename == "openai_vlm_profile_export.json":
tmp_file = StringIO(json.dumps(self.openai_vlm_profile_data))
return tmp_file
elif filename == "tensorrtllm_engine_profile_export.json":
tmp_file = StringIO(json.dumps(self.tensorrtllm_engine_profile_data))
return tmp_file
elif filename == "empty_profile_export.json":
tmp_file = StringIO(json.dumps(self.empty_profile_data))
return tmp_file
@@ -410,6 +436,103 @@ def test_openai_vlm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> N
with pytest.raises(KeyError):
pd.get_statistics(infer_mode="concurrency", load_level="40")

@pytest.mark.parametrize(
"infer_mode, load_level, expected_metrics",
[
(
"concurrency",
"10",
{
"request_latencies": [7, 9],
"request_throughputs": [1 / ns_to_sec(5)],
"time_to_first_tokens": [2, 2],
"inter_token_latencies": [2, 4],
"output_token_throughputs_per_request": [
3 / ns_to_sec(7),
1 / ns_to_sec(3),
],
"output_token_throughputs": [3 / ns_to_sec(5)],
"output_sequence_lengths": [3, 3],
"input_sequence_lengths": [3, 4],
},
),
(
"request_rate",
"2.0",
{
"request_latencies": [13, 8],
"request_throughputs": [2 / ns_to_sec(15)],
"time_to_first_tokens": [2, 3],
"inter_token_latencies": [4, 2],
"output_token_throughputs_per_request": [
4 / ns_to_sec(13),
3 / ns_to_sec(8),
],
"output_token_throughputs": [7 / ns_to_sec(15)],
"output_sequence_lengths": [4, 3],
"input_sequence_lengths": [3, 4],
},
),
],
)
def test_tensorrtllm_engine_llm_profile_data(
self,
mock_read_write: pytest.MonkeyPatch,
infer_mode,
load_level,
expected_metrics,
) -> None:
"""Collect LLM metrics from profile export data and check values.
Metrics
* request_latencies
- experiment 1: [8 - 1, 11 - 2] = [7, 9]
- experiment 2: [18 - 5, 11 - 3] = [13, 8]
* request_throughputs
- experiment 1: [2/(11 - 1)] = [1/5]
- experiment 2: [2/(18 - 3)] = [2/15]
* time to first tokens
- experiment 1: [3 - 1, 4 - 2] = [2, 2]
- experiment 2: [7 - 5, 6 - 3] = [2, 3]
* inter token latencies
- experiment 1: [((8 - 1) - 2)/(3 - 1), ((11 - 2) - 2)/(3 - 1)]
: [2.5, 3.5]
: [2, 4] # rounded
- experiment 2: [((18 - 5) - 2)/(4 - 1), ((11 - 3) - 3)/(3 - 1)]
: [11/3, 2.5]
: [4, 2] # rounded
* output token throughputs per request
- experiment 1: [3/(8 - 1), 3/(11 - 2)] = [3/7, 1/3]
- experiment 2: [4/(18 - 5), 3/(11 - 3)] = [4/13, 3/8]
* output token throughputs
- experiment 1: [(3 + 3)/(11 - 1)] = [3/5]
- experiment 2: [(4 + 3)/(18 - 3)] = [7/15]
* output sequence lengths
- experiment 1: [3, 3]
- experiment 2: [4, 3]
* input sequence lengths
- experiment 1: [3, 4]
- experiment 2: [3, 4]
"""
tokenizer = get_tokenizer(DEFAULT_TOKENIZER)
pd = LLMProfileDataParser(
filename=Path("tensorrtllm_engine_profile_export.json"),
tokenizer=tokenizer,
)

statistics = pd.get_statistics(infer_mode=infer_mode, load_level=load_level)
metrics = cast(LLMMetrics, statistics.metrics)

expected_metrics = LLMMetrics(**expected_metrics)
expected_statistics = Statistics(expected_metrics)

check_llm_metrics(metrics, expected_metrics)
check_statistics(statistics, expected_statistics)

# check non-existing profile data
with pytest.raises(KeyError):
pd.get_statistics(infer_mode="concurrency", load_level="30")

def test_merged_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None:
"""Test merging the multiple sse response."""
res_timestamps = [0, 1, 2, 3]
@@ -740,3 +863,147 @@ def test_empty_response(self, mock_read_write: pytest.MonkeyPatch) -> None:
},
],
}

tensorrtllm_engine_profile_data = {
"service_kind": "triton_c_api",
"endpoint": "",
"experiments": [
{
"experiment": {
"mode": "concurrency",
"value": 10,
},
"requests": [
{
"timestamp": 1,
"request_inputs": {
"streaming": True,
"request_output_len": 3,
"min_length": 3,
"input_lengths": 3,
"input_ids": [
111,
222,
333,
],
},
"response_timestamps": [3, 5, 8],
"response_outputs": [
{
"output_log_probs": [0, 0],
"output_ids": 123,
},
{
"output_log_probs": [0, 0],
"output_ids": 456,
},
{
"output_ids": 789,
},
],
},
{
"timestamp": 2,
"request_inputs": {
"streaming": True,
"request_output_len": 3,
"min_length": 3,
"input_lengths": 4,
"input_ids": [
111,
222,
333,
444,
],
},
"response_timestamps": [4, 7, 11],
"response_outputs": [
{
"output_log_probs": [0, 0],
"output_ids": 123,
},
{
"output_log_probs": [0, 0],
"output_ids": 456,
},
{
"output_log_probs": [0, 0],
"output_ids": 789,
},
],
},
],
},
{
"experiment": {
"mode": "request_rate",
"value": 2.0,
},
"requests": [
{
"timestamp": 5,
"request_inputs": {
"streaming": True,
"request_output_len": 4,
"min_length": 4,
"input_lengths": 3,
"input_ids": [
111,
222,
333,
],
},
"response_timestamps": [7, 8, 13, 18],
"response_outputs": [
{
"output_log_probs": [0, 0],
"output_ids": 123,
},
{
"output_log_probs": [0, 0],
"output_ids": 456,
},
{
"output_log_probs": [0, 0],
"output_ids": 789,
},
{
"output_log_probs": [0, 0],
"output_ids": 1011,
},
],
},
{
"timestamp": 3,
"request_inputs": {
"streaming": True,
"request_output_len": 3,
"min_length": 3,
"input_lengths": 4,
"input_ids": [
111,
222,
333,
444,
],
},
"response_timestamps": [6, 8, 11],
"response_outputs": [
{
"output_log_probs": [0, 0],
"output_ids": 123,
},
{
"output_log_probs": [0, 0],
"output_ids": 456,
},
{
"output_log_probs": [0, 0],
"output_ids": 789,
},
],
},
],
},
],
}
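
For reference, an export shaped like the fixture above can be summarized without any tokenizer. The helper below is purely illustrative (it is not part of genai-perf) and simply counts responses per request, which in this fixture equals the output sequence length because every response carries a single scalar output_ids:

# Illustrative helper; operates on a dict shaped like tensorrtllm_engine_profile_data.
from typing import Dict, List

def summarize_output_lengths(profile: Dict) -> List[List[int]]:
    """Return per-request response counts for each experiment."""
    return [
        [len(req["response_outputs"]) for req in experiment["requests"]]
        for experiment in profile["experiments"]
    ]

# summarize_output_lengths(tensorrtllm_engine_profile_data) -> [[3, 3], [4, 3]],
# matching the expected output_sequence_lengths in the parametrized test above.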
