Skip to content

Commit

Permalink
Make tokenizer not optional
Browse files Browse the repository at this point in the history
  • Loading branch information
dyastremsky committed Nov 4, 2024
1 parent 49fee42 commit 91cd63b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import random
from typing import Any, Dict, cast
from typing import Any, Dict

from genai_perf.exceptions import GenAIPerfException
from genai_perf.inputs.converters.base_converter import BaseConverter
Expand All @@ -36,7 +36,6 @@
)
from genai_perf.inputs.inputs_config import InputsConfig
from genai_perf.inputs.retrievers.generic_dataset import GenericDataset
from genai_perf.tokenizer import Tokenizer


class TensorRTLLMEngineConverter(BaseConverter):
Expand All @@ -53,7 +52,7 @@ def convert(

for file_data in generic_dataset.files_data.values():
for row in file_data.rows:
token_ids = cast(Tokenizer, config.tokenizer).encode(row.texts[0])
token_ids = config.tokenizer.encode(row.texts[0])
payload = {
"input_ids": {
"content": token_ids,
Expand Down
6 changes: 3 additions & 3 deletions genai-perf/genai_perf/inputs/inputs_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class InputsConfig:
# General Parameters
####################

# The tokenizer to use when generating synthetic prompts
tokenizer: Tokenizer

# If true, adds a stream field to each payload
add_stream: bool = False

Expand Down Expand Up @@ -139,6 +142,3 @@ class InputsConfig:

# Seed used to generate random values
random_seed: int = DEFAULT_RANDOM_SEED

# The tokenizer to use when generating synthetic prompts
tokenizer: Optional[Tokenizer] = None
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


from typing import List, cast
from typing import List

from genai_perf.inputs.input_constants import DEFAULT_SYNTHETIC_FILENAME
from genai_perf.inputs.retrievers.base_input_retriever import BaseInputRetriever
Expand All @@ -40,7 +40,6 @@
from genai_perf.inputs.retrievers.synthetic_prompt_generator import (
SyntheticPromptGenerator,
)
from genai_perf.tokenizer import Tokenizer


class SyntheticDataRetriever(BaseInputRetriever):
Expand All @@ -58,7 +57,7 @@ def retrieve_data(self) -> GenericDataset:
for _ in range(self.config.num_prompts):
row = DataRow(texts=[], images=[])
prompt = SyntheticPromptGenerator.create_synthetic_prompt(
cast(Tokenizer, self.config.tokenizer),
self.config.tokenizer,
self.config.prompt_tokens_mean,
self.config.prompt_tokens_stddev,
)
Expand Down

0 comments on commit 91cd63b

Please sign in to comment.