Skip to content

Commit

Permalink
feat: add model validator for model config for env overrides
Browse files Browse the repository at this point in the history
- now GERD_MODEL_ can be used to override model config parameters
  • Loading branch information
aleneum committed Dec 1, 2024
1 parent f12e9d1 commit 507da9e
Showing 1 changed file with 36 additions and 4 deletions.
40 changes: 36 additions & 4 deletions gerd/models/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path
from string import Formatter
from typing import Any, List, Literal, Mapping, Optional, Tuple, TypedDict
Expand All @@ -8,8 +9,8 @@
ConfigDict,
Field,
SecretStr,
ValidationError,
computed_field,
model_validator,
)

ChatRole = Literal["system", "user", "assistant"]
Expand Down Expand Up @@ -40,9 +41,11 @@ def format(
return (
self.template.render(**parameters)
if self.template
else self.text.format(**parameters)
if self.text
else "".join(str(parameters.values()))
else (
self.text.format(**parameters)
if self.text
else "".join(str(parameters.values()))
)
)

def model_post_init(self, __context: Any) -> None: # noqa: ANN401
Expand Down Expand Up @@ -128,3 +131,32 @@ class ModelConfig(BaseModel):
context_length: int = 0 # Currently only LLaMA, MPT and Falcon
gpu_layers: int = 0
torch_dtype: Optional[str] = None

@model_validator(mode="after")
@classmethod
def validate_field(cls, data: Any) -> Any: # noqa: ANN401
for field in cls.model_fields:
env_name = f"GERD_MODEL_{field.upper()}"
# Special handling of endpoint field override
if (
field == "endpoint"
and f"{env_name}_URL" in os.environ
and f"{env_name}_TYPE" in os.environ
):
setattr(
data,
field,
ModelEndpoint(
url=os.environ[f"{env_name}_URL"],
type=os.environ[f"{env_name}_TYPE"],
key=(
SecretStr(os.environ.get(f"{env_name}_KEY"))
if f"{env_name}_KEY" in os.environ
else None
),
),
)
elif env_val := os.environ.get(env_name):
setattr(data, field, type(getattr(data, field))(env_val))

return data

0 comments on commit 507da9e

Please sign in to comment.