Skip to content

Commit

Permalink
Support payload input for timing strategy (#209)
Browse files Browse the repository at this point in the history
Replace with correct variable
  • Loading branch information
lkomali authored and debermudez committed Jan 3, 2025
1 parent 8205e12 commit 2f333df
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 9 deletions.
1 change: 1 addition & 0 deletions genai-perf/genai_perf/inputs/input_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ModelSelectionStrategy(Enum):
class PromptSource(Enum):
SYNTHETIC = auto()
FILE = auto()
PAYLOAD = auto()


class OutputFormat(Enum):
Expand Down
19 changes: 16 additions & 3 deletions genai-perf/genai_perf/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,15 +465,28 @@ def parse_goodput(values):

def _infer_prompt_source(args: argparse.Namespace) -> argparse.Namespace:
args.synthetic_input_files = None
args.payload_input_file = None

if args.input_file:
if str(args.input_file).startswith("synthetic:"):
input_file_str = str(args.input_file)
if input_file_str.startswith("synthetic:"):
args.prompt_source = ic.PromptSource.SYNTHETIC
synthetic_input_files_str = str(args.input_file).split(":", 1)[1]
synthetic_input_files_str = input_file_str.split(":", 1)[1]
args.synthetic_input_files = synthetic_input_files_str.split(",")
logger.debug(
f"Input source is synthetic data: {args.synthetic_input_files}"
)
elif input_file_str.startswith("payload:"):
args.prompt_source = ic.PromptSource.PAYLOAD
payload_input_file_str = input_file_str.split(":", 1)[1]
if not payload_input_file_str:
raise ValueError(
f"Invalid payload input: '{input_file_str}' is missing the file path"
)
args.payload_input_file = payload_input_file_str.split(",")
logger.debug(
f"Input source is a payload file with timing information in the following path: {args.payload_input_file}"
)
else:
args.prompt_source = ic.PromptSource.FILE
logger.debug(f"Input source is the following path: {args.input_file}")
Expand All @@ -496,7 +509,7 @@ def _convert_str_to_enum_entry(args, option, enum):


def file_or_directory(value: str) -> Path:
if value.startswith("synthetic:"):
if value.startswith("synthetic:") or value.startswith("payload:"):
return Path(value)
else:
path = Path(value)
Expand Down
1 change: 1 addition & 0 deletions genai-perf/genai_perf/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s
"output_tokens_mean",
"output_tokens_mean_deterministic",
"output_tokens_stddev",
"payload_input_file",
"prompt_source",
"random_seed",
"request_rate",
Expand Down
42 changes: 36 additions & 6 deletions genai-perf/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,25 +844,55 @@ def test_goodput_args_warning(self, monkeypatch, args, expected_error):
assert str(exc_info.value) == expected_error

@pytest.mark.parametrize(
"args, expected_prompt_source",
"args, expected_prompt_source, expected_payload_input_file, expect_error",
[
([], PromptSource.SYNTHETIC),
(["--input-file", "prompt.txt"], PromptSource.FILE),
([], PromptSource.SYNTHETIC, None, False),
(["--input-file", "prompt.txt"], PromptSource.FILE, None, False),
(
["--input-file", "prompt.txt", "--synthetic-input-tokens-mean", "10"],
PromptSource.FILE,
None,
False,
),
(
["--input-file", "payload:test.jsonl"],
PromptSource.PAYLOAD,
["test.jsonl"],
False,
),
(["--input-file", "payload:"], PromptSource.PAYLOAD, [], True),
(
["--input-file", "synthetic:test.jsonl"],
PromptSource.SYNTHETIC,
None,
False,
),
(["--input-file", "invalidinput"], PromptSource.FILE, None, False),
],
)
def test_inferred_prompt_source(
self, monkeypatch, mocker, args, expected_prompt_source
self,
monkeypatch,
mocker,
args,
expected_prompt_source,
expected_payload_input_file,
expect_error,
):
mocker.patch.object(Path, "is_file", return_value=True)
combined_args = ["genai-perf", "profile", "--model", "test_model"] + args
monkeypatch.setattr("sys.argv", combined_args)
args, _ = parser.parse_args()

assert args.prompt_source == expected_prompt_source
if expect_error:
with pytest.raises(ValueError):
parser.parse_args()
else:
args, _ = parser.parse_args()

assert args.prompt_source == expected_prompt_source

if expected_payload_input_file is not None:
assert args.payload_input_file == expected_payload_input_file

@pytest.mark.parametrize(
"args",
Expand Down
1 change: 1 addition & 0 deletions genai-perf/tests/test_exporters/test_json_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ class TestJsonExporter:
"output_tokens_mean": -1,
"output_tokens_mean_deterministic": false,
"output_tokens_stddev": 0,
"payload_input_file": null,
"random_seed": 0,
"request_count": 0,
"synthetic_input_files": null,
Expand Down

0 comments on commit 2f333df

Please sign in to comment.