Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lkomali committed Jan 14, 2025
1 parent 52020cc commit 482a3b7
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 855 deletions.
11 changes: 7 additions & 4 deletions genai-perf/genai_perf/export_data/json_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,17 @@ def export(self) -> None:
0
]
filename = self._output_dir / f"{prefix}_genai_perf.json"
logger.info(f"Generating {filename}")
with open(str(filename), "w") as f:
f.write(json.dumps(self._stats_and_args, indent=2))

def _exclude_args(self, args_to_exclude) -> None:
for arg in args_to_exclude:
self._args.pop(arg, None)

def _prepare_args_for_export(self) -> None:
self._args.pop("func", None)
self._args.pop("output_format", None)
self._args.pop("input_file", None)
self._args.pop("payload_input_file", None)
args_to_exclude = ["func", "output_format", "input_file", "payload_input_file"]
self._exclude_args(args_to_exclude)
self._args["profile_export_file"] = str(self._args["profile_export_file"])
self._args["artifact_dir"] = str(self._args["artifact_dir"])
for k, v in self._args.items():
Expand Down
5 changes: 2 additions & 3 deletions genai-perf/genai_perf/inputs/converters/base_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,5 @@ def _add_request_params(
def _add_payload_params(
self, payload: Dict[Any, Any], optional_data: Dict[Any, Any]
) -> None:
if optional_data:
for key, value in optional_data.items():
payload[key] = value
for key, value in optional_data.items():
payload[key] = value
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,6 @@ class BaseFileInputRetriever(BaseInputRetriever):
A base input retriever class that defines file input methods.
"""

def retrieve_data(self) -> GenericDataset:
"""
Retrieves the dataset from a file or directory.
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def _get_input_dataset_from_file(self, filename: Path) -> FileData:
"""
Retrieves the dataset from a specific JSONL file.
"""

raise NotImplementedError("This method should be implemented by subclasses.")

def _verify_file(self, filename: Path) -> None:
"""
Verifies that the file exists.
Expand All @@ -77,3 +63,17 @@ def _get_content_from_input_file(self, filename: Path) -> Union[
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def _get_input_dataset_from_file(self, filename: Path) -> FileData:
"""
Retrieves the dataset from a specific JSONL file.
"""

raise NotImplementedError("This method should be implemented by subclasses.")

def retrieve_data(self) -> GenericDataset:
"""
Retrieves the dataset from a file or directory.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
31 changes: 16 additions & 15 deletions genai-perf/genai_perf/inputs/retrievers/payload_input_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,11 @@ def _check_for_optional_data(self, data: Dict[str, Any]) -> Dict[Any, Any]:
"""
Checks if there is any optional data in the file to pass in the payload.
"""
optional_data = {}
for k, v in data.items():
if k not in ["text", "text_input", "timestamp"]:
optional_data[k] = v
optional_data = {
k: v
for k, v in data.items()
if k not in ["text", "text_input", "timestamp"]
}
return optional_data

def _convert_content_to_data_file(
Expand All @@ -154,6 +155,8 @@ def _convert_content_to_data_file(
----------
prompts : List[str]
The list of prompts to convert.
timestamps: str
The timestamp at which the request should be sent.
optional_data : Dict
The optional data included in every payload.
Expand All @@ -162,15 +165,13 @@ def _convert_content_to_data_file(
FileData
The DataFile containing the converted data.
"""
data_rows: List[DataRow] = []

if prompts:
for index, prompt in enumerate(prompts):
data_rows.append(
DataRow(
texts=[prompt],
timestamp=timestamps[index],
optional_data=optional_datas[index],
)
)
data_rows: List[DataRow] = [
DataRow(
texts=[prompt],
timestamp=timestamps[index],
optional_data=optional_datas[index],
)
for index, prompt in enumerate(prompts)
]

return FileData(data_rows)
16 changes: 16 additions & 0 deletions genai-perf/genai_perf/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,21 @@ class EndpointConfig:
}


def _check_payload_input_args(
parser: argparse.ArgumentParser, args: argparse.Namespace
) -> argparse.Namespace:
"""
Raise an error if concurrency or request-range is set
"""

if args.prompt_source == ic.PromptSource.PAYLOAD:
if args.concurrency or args.request_rate:
raise ValueError(
"Concurrency and request_rate cannot be used with payload input."
)
return args


def _check_model_args(
parser: argparse.ArgumentParser, args: argparse.Namespace
) -> argparse.Namespace:
Expand Down Expand Up @@ -1104,6 +1119,7 @@ def refine_args(
parser: argparse.ArgumentParser, args: argparse.Namespace
) -> argparse.Namespace:
if args.subcommand == Subcommand.PROFILE.to_lowercase():
args = _check_payload_input_args(parser, args)
args = _infer_prompt_source(args)
args = _check_model_args(parser, args)
args = _check_conditional_args(parser, args)
Expand Down
2 changes: 0 additions & 2 deletions genai-perf/genai_perf/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def add_protocol_args(args: Namespace) -> List[str]:
@staticmethod
def add_inference_load_args(args: Namespace) -> List[str]:
cmd: list[str] = []
if args.prompt_source == PromptSource.PAYLOAD:
return cmd
if args.concurrency:
cmd += ["--concurrency-range", f"{args.concurrency}"]
elif args.request_rate:
Expand Down
Loading

0 comments on commit 482a3b7

Please sign in to comment.