From 67f061859d46e8833a152ec532578fcc4a464c39 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 20 Apr 2024 00:07:39 -0400 Subject: [PATCH] Tree: Add transformers_utils Part of commit 8824ea0205cb93e7e8226474abb053f2613aea4f Signed-off-by: kingbri --- common/transformers_utils.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 common/transformers_utils.py diff --git a/common/transformers_utils.py b/common/transformers_utils.py new file mode 100644 index 00000000..62d46223 --- /dev/null +++ b/common/transformers_utils.py @@ -0,0 +1,32 @@ +import json +import pathlib +from typing import List, Optional, Union +from pydantic import BaseModel + + +class GenerationConfig(BaseModel): + """ + An abridged version of HuggingFace's GenerationConfig. + Will be expanded as needed. + """ + + eos_token_id: Optional[Union[int, List[int]]] = None + + @classmethod + def from_file(self, model_directory: pathlib.Path): + """Create an instance from a generation config file.""" + + generation_config_path = model_directory / "generation_config.json" + with open( + generation_config_path, "r", encoding="utf8" + ) as generation_config_json: + generation_config_dict = json.load(generation_config_json) + return self.model_validate(generation_config_dict) + + def eos_tokens(self): + """Wrapper method to fetch EOS tokens.""" + + if isinstance(self.eos_token_id, int): + return [self.eos_token_id] + else: + return self.eos_token_id