From 054f0e576b7820f137819fb3dad59dfccd3e9e2b Mon Sep 17 00:00:00 2001
From: anton
Date: Fri, 20 Dec 2024 17:03:16 +0100
Subject: [PATCH] tokenizer param + misc fixes

---
 src/lighteval/models/vllm/vllm_model.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py
index 2d413807..3ee20e68 100644
--- a/src/lighteval/models/vllm/vllm_model.py
+++ b/src/lighteval/models/vllm/vllm_model.py
@@ -36,7 +36,7 @@
     GenerativeResponse,
     LoglikelihoodResponse,
 )
-from lighteval.models.utils import _get_dtype, _simplify_name
+from lighteval.models.utils import _get_dtype, _get_model_sha, _simplify_name
 from lighteval.tasks.requests import (
     GreedyUntilRequest,
     LoglikelihoodRequest,
@@ -89,6 +89,9 @@ class VLLMModelConfig:
     subfolder: Optional[str] = None
     temperature: float = 0.6  # will be used for multi sampling tasks, for tasks requiring no sampling, this will be ignored and set to 0.
 
+    def get_model_sha(self):
+        return _get_model_sha(repo_id=self.pretrained, revision=self.revision)
+
 
 class VLLMModel(LightevalModel):
     def __init__(
@@ -113,10 +116,10 @@ def __init__(
         self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
 
         self.model_name = _simplify_name(config.pretrained)
-        self.model_sha = ""  # config.get_model_sha()
+        self.model_sha = config.get_model_sha()
         self.precision = _get_dtype(config.dtype, config=self._config)
 
-        self.model_info = ModelInfo(model_name=self.model_name, model_sha=self.model_sha)
+        self.model_info = ModelInfo(model_name=self.model_name, model_sha=self.model_sha, model_dtype=config.dtype)
         self.pairwise_tokenization = config.pairwise_tokenization
 
     @property
@@ -191,7 +194,7 @@ def _create_auto_tokenizer(self, config: VLLMModelConfig, env_config: EnvConfig)
             config.pretrained,
             tokenizer_mode="auto",
             trust_remote_code=config.trust_remote_code,
-            tokenizer_revision=config.revision,
+            revision=config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""),
         )
         tokenizer.pad_token = tokenizer.eos_token
         return tokenizer
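
Usage sketch (not part of the patch): with `get_model_sha()` wired in, the config can
resolve the model's commit SHA from the Hub instead of leaving `model_sha` empty. A
minimal example, assuming `pretrained` is the only required field of VLLMModelConfig
and that `revision` defaults to "main"; the model id below is only illustrative:

    from lighteval.models.vllm.vllm_model import VLLMModelConfig

    # Build a config for a Hub model; revision is assumed to default to "main".
    config = VLLMModelConfig(pretrained="HuggingFaceH4/zephyr-7b-beta")

    # Resolves the commit hash of that revision via _get_model_sha (Hub lookup).
    print(config.get_model_sha())  # e.g. a 40-character commit hash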