diff --git a/common/config_models.py b/common/config_models.py index 286057ec..5e5b5a21 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -4,13 +4,13 @@ from common.utils import unwrap -class config_config_model(BaseModel): +class ConfigConfig(BaseModel): config: Optional[str] = Field( None, description=("Path to an overriding config.yml file") ) -class network_config_model(BaseModel): +class NetworkConfig(BaseModel): host: Optional[str] = Field("127.0.0.1", description=("The IP to host on")) port: Optional[int] = Field(5000, description=("The port to host on")) disable_auth: Optional[bool] = Field( @@ -28,7 +28,7 @@ class network_config_model(BaseModel): ) -class logging_config_model(BaseModel): +class LoggingConfig(BaseModel): log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging")) log_generation_params: Optional[bool] = Field( False, description=("Enable generation parameter logging") @@ -36,7 +36,7 @@ class logging_config_model(BaseModel): log_requests: Optional[bool] = Field(False, description=("Enable request logging")) -class model_config_model(BaseModel): +class ModelConfig(BaseModel): model_dir: str = Field( "models", description=( @@ -171,8 +171,10 @@ class model_config_model(BaseModel): ), ) + model_config = ConfigDict(protected_namespaces=()) -class draft_model_config_model(BaseModel): + +class DraftModelConfig(BaseModel): draft_model_dir: Optional[str] = Field( "models", description=( @@ -209,18 +211,18 @@ class draft_model_config_model(BaseModel): ) -class lora_instance_model(BaseModel): +class LoraInstanceModel(BaseModel): name: str = Field(..., description=("Name of the LoRA model")) scaling: float = Field( 1.0, description=("Scaling factor for the LoRA model (default: 1.0)") ) -class lora_config_model(BaseModel): +class LoraConfig(BaseModel): lora_dir: Optional[str] = Field( "loras", description=("Directory to look for LoRAs (default: 'loras')") ) - loras: Optional[List[lora_instance_model]] = Field( + loras: Optional[List[LoraInstanceModel]] = Field( None, description=( "List of LoRAs to load and associated scaling factors (default scaling:" @@ -229,13 +231,13 @@ class lora_config_model(BaseModel): ) -class sampling_config_model(BaseModel): +class SamplingConfig(BaseModel): override_preset: Optional[str] = Field( None, description=("Select a sampler override preset") ) -class developer_config_model(BaseModel): +class DeveloperConfig(BaseModel): unsafe_launch: Optional[bool] = Field( False, description=("Skip Exllamav2 version check") ) @@ -257,7 +259,7 @@ class developer_config_model(BaseModel): ) -class embeddings_config_model(BaseModel): +class EmbeddingsConfig(BaseModel): embedding_model_dir: Optional[str] = Field( "models", description=( @@ -276,18 +278,20 @@ class embeddings_config_model(BaseModel): ) -class tabby_config_model(BaseModel): - config: config_config_model = Field(default_factory=config_config_model) - network: network_config_model = Field(default_factory=network_config_model) - logging: logging_config_model = Field(default_factory=logging_config_model) - model: model_config_model = Field(default_factory=model_config_model) - draft_model: draft_model_config_model = Field( - default_factory=draft_model_config_model +class TabbyConfigModel(BaseModel): + config: ConfigConfig = Field(default_factory=ConfigConfig.model_construct) + network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct) + logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct) + model: ModelConfig = Field(default_factory=ModelConfig.model_construct) + draft_model: DraftModelConfig = Field( + default_factory=DraftModelConfig.model_construct + ) + lora: LoraConfig = Field(default_factory=LoraConfig.model_construct) + sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct) + developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct) + embeddings: EmbeddingsConfig = Field( + default_factory=EmbeddingsConfig.model_construct ) - lora: lora_config_model = Field(default_factory=lora_config_model) - sampling: sampling_config_model = Field(default_factory=sampling_config_model) - developer: developer_config_model = Field(default_factory=developer_config_model) - embeddings: embeddings_config_model = Field(default_factory=embeddings_config_model) @model_validator(mode="before") def set_defaults(cls, values): @@ -297,11 +301,11 @@ def set_defaults(cls, values): values[field_name] = cls.__annotations__[field_name](**default_instance) return values - model_config = ConfigDict(validate_assignment=True) + model_config = ConfigDict(validate_assignment=True, protected_namespaces=()) def generate_config_file(filename="config_sample.yml", indentation=2): - schema = tabby_config_model.model_json_schema() + schema = TabbyConfigModel.model_json_schema() def dump_def(id: str, indent=2): yaml = "" diff --git a/common/tabby_config.py b/common/tabby_config.py index 9738c120..fd952a20 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -5,11 +5,10 @@ from os import getenv from common.utils import unwrap, merge_dicts -from common.config_models import tabby_config_model -import common.config_models +from common.config_models import TabbyConfigModel -class TabbyConfig(tabby_config_model): +class TabbyConfig(TabbyConfigModel): # Persistent defaults # TODO: make this pydantic? model_defaults: dict = {} @@ -26,11 +25,11 @@ def load(self, arguments: Optional[dict] = None): merged_config = merge_dicts(*configs) - for field in tabby_config_model.model_fields.keys(): - value = unwrap(merged_config.get(field), {}) - model = getattr(common.config_models, f"{field}_config_model") - - setattr(self, field, model.parse_obj(value)) + # validate and update config + merged_config_model = TabbyConfigModel.model_validate(merged_config) + for field in TabbyConfigModel.model_fields.keys(): + value = getattr(merged_config_model, field) + setattr(self, field, value) # Set model defaults dict once to prevent on-demand reconstruction # TODO: clean this up a bit @@ -71,7 +70,7 @@ def _from_args(self, args: dict): config = self._from_file(pathlib.Path(config_override)) return config # Return early if loading from file - for key in tabby_config_model.model_fields.keys(): + for key in TabbyConfigModel.model_fields.keys(): override = args.get(key) if override: if key == "logging": @@ -86,10 +85,10 @@ def _from_environment(self): config = {} - for field_name in tabby_config_model.model_fields.keys(): + for field_name in TabbyConfigModel.model_fields.keys(): section_config = {} for sub_field_name in getattr( - tabby_config_model(), field_name + TabbyConfigModel(), field_name ).model_fields.keys(): setting = getenv(f"TABBY_{field_name}_{sub_field_name}".upper(), None) if setting is not None: