Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
Added model_max_length as a tok parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
aittalam committed Aug 5, 2024
1 parent 4bcc24c commit 16b447a
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/lm_buddy/configs/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class AutoTokenizerConfig(LMBuddyConfig):
path: AssetPath
trust_remote_code: bool | None = None
use_fast: bool | None = None
mod_max_length: int | None = None


class DatasetConfig(LMBuddyConfig):
Expand Down
1 change: 1 addition & 0 deletions src/lm_buddy/jobs/asset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def load_pretrained_tokenizer(self, config: AutoTokenizerConfig) -> PreTrainedTo
pretrained_model_name_or_path=tokenizer_path,
trust_remote_code=config.trust_remote_code,
use_fast=config.use_fast,
model_max_length=config.mod_max_length,
)
if tokenizer.pad_token_id is None:
# Pad token required for generating consistent batch sizes
Expand Down
2 changes: 1 addition & 1 deletion src/lm_buddy/jobs/evaluation/hf_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def run_eval(config: HuggingFaceEvalJobConfig) -> Path:
# for inference
if config.evaluation.use_pipeline:
logger.info(f"Using summarization pipeline. Model: {model_name}")
model_client = SummarizationPipelineModelClient(model_name, config.model)
model_client = SummarizationPipelineModelClient(model_name, config)
else:
logger.info(f"Using direct HF model invocation. Model: {model_name}")
model_client = HuggingFaceModelClient(model_name, config)
Expand Down
8 changes: 7 additions & 1 deletion src/lm_buddy/jobs/model_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,16 @@ class SummarizationPipelineModelClient(BaseModelClient):
(model is loaded locally).
"""

def __init__(self, model: str, config: AutoModelConfig):
def __init__(self, model: str, config: HuggingFaceEvalJobConfig):
    """Build a summarization pipeline client for a locally-loaded model.

    The tokenizer is loaded through the shared HuggingFace tokenizer
    loader so that tokenizer settings from the job config (trust_remote_code,
    use_fast, model max length, ...) are applied consistently.
    """
    self._config = config

    # Resolve the tokenizer from the job config rather than letting the
    # pipeline pick a default one for the model.
    tokenizer_loader = HuggingFaceTokenizerLoader()
    self._tokenizer = tokenizer_loader.load_pretrained_tokenizer(config.tokenizer)

    # Prefer the first CUDA device when available; -1 selects CPU.
    device_id = 0 if torch.cuda.is_available() else -1
    self._summarizer = pipeline(
        "summarization",
        model=model,
        tokenizer=self._tokenizer,
        truncation=True,
        device=device_id,
    )
Expand Down

0 comments on commit 16b447a

Please sign in to comment.