Skip to content

Commit

Permalink
Tree: Add transformers_utils
Browse files Browse the repository at this point in the history
Part of commit 8824ea0

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Apr 20, 2024
1 parent 8824ea0 commit 67f0618
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions common/transformers_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json
import pathlib
from typing import List, Optional, Union
from pydantic import BaseModel


class GenerationConfig(BaseModel):
"""
An abridged version of HuggingFace's GenerationConfig.
Will be expanded as needed.
"""

eos_token_id: Optional[Union[int, List[int]]] = None

@classmethod
def from_file(self, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""

generation_config_path = model_directory / "generation_config.json"
with open(
generation_config_path, "r", encoding="utf8"
) as generation_config_json:
generation_config_dict = json.load(generation_config_json)
return self.model_validate(generation_config_dict)

def eos_tokens(self):
"""Wrapper method to fetch EOS tokens."""

if isinstance(self.eos_token_id, int):
return [self.eos_token_id]
else:
return self.eos_token_id

0 comments on commit 67f0618

Please sign in to comment.