From 5e8ff9a00435c76a5c742c590b04c1ad7a76706e Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 10 Sep 2024 20:52:29 -0400 Subject: [PATCH] Tree: Fix classmethod usage Instead of self, use cls which passes a type of the class. Signed-off-by: kingbri --- common/templating.py | 8 ++++---- common/transformers_utils.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/common/templating.py b/common/templating.py index 2c0e5e2f..1200e5ca 100644 --- a/common/templating.py +++ b/common/templating.py @@ -111,7 +111,7 @@ def __init__(self, name: str, raw_template: str): self.template = self.compile(raw_template) @classmethod - async def from_file(self, template_path: pathlib.Path): + async def from_file(cls, template_path: pathlib.Path): """Get a template from a jinja file.""" # Add the jinja extension if it isn't provided @@ -126,7 +126,7 @@ async def from_file(self, template_path: pathlib.Path): template_path, "r", encoding="utf8" ) as raw_template_stream: contents = await raw_template_stream.read() - return PromptTemplate( + return cls( name=template_name, raw_template=contents, ) @@ -138,7 +138,7 @@ async def from_file(self, template_path: pathlib.Path): @classmethod async def from_model_json( - self, json_path: pathlib.Path, key: str, name: Optional[str] = None + cls, json_path: pathlib.Path, key: str, name: Optional[str] = None ): """Get a template from a JSON file. Requires a key and template name""" if not json_path.exists(): @@ -177,7 +177,7 @@ async def from_model_json( ) else: # Can safely assume the chat template is the old style - return PromptTemplate( + return cls( name="from_tokenizer_config", raw_template=chat_template, ) diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 386f5430..4fd848d3 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -16,7 +16,7 @@ class GenerationConfig(BaseModel): bad_words_ids: Optional[List[List[int]]] = None @classmethod - async def from_file(self, model_directory: pathlib.Path): + async def from_file(cls, model_directory: pathlib.Path): """Create an instance from a generation config file.""" generation_config_path = model_directory / "generation_config.json" @@ -24,7 +24,7 @@ async def from_file(self, model_directory: pathlib.Path): generation_config_path, "r", encoding="utf8" ) as generation_config_json: generation_config_dict = json.load(generation_config_json) - return self.model_validate(generation_config_dict) + return cls.model_validate(generation_config_dict) def eos_tokens(self): """Wrapper method to fetch EOS tokens.""" @@ -44,7 +44,7 @@ class HuggingFaceConfig(BaseModel): badwordsids: Optional[str] = None @classmethod - async def from_file(self, model_directory: pathlib.Path): + async def from_file(cls, model_directory: pathlib.Path): """Create an instance from a generation config file.""" hf_config_path = model_directory / "config.json" @@ -53,7 +53,7 @@ async def from_file(self, model_directory: pathlib.Path): ) as hf_config_json: contents = await hf_config_json.read() hf_config_dict = json.loads(contents) - return self.model_validate(hf_config_dict) + return cls.model_validate(hf_config_dict) def get_badwordsids(self): """Wrapper method to fetch badwordsids."""