diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 0c67388..15c51c1 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -56,8 +56,6 @@ from common.transformers_utils import GenerationConfig, HuggingFaceConfig from common.utils import coalesce, unwrap -yaml = YAML() - class ExllamaV2Container: """The model container class for ExLlamaV2 models.""" @@ -381,7 +379,10 @@ async def set_model_overrides(self, **kwargs): override_config_path, "r", encoding="utf8" ) as override_config_file: contents = await override_config_file.read() - override_args = unwrap(yaml.safe_load(contents), {}) + + # Create a temporary YAML parser + yaml = YAML(typ="safe") + override_args = unwrap(yaml.load(contents), {}) # Merge draft overrides beforehand draft_override_args = unwrap(override_args.get("draft"), {}) diff --git a/common/auth.py b/common/auth.py index c822b88..67e393b 100644 --- a/common/auth.py +++ b/common/auth.py @@ -4,6 +4,7 @@ """ import aiofiles +import io import secrets from ruamel.yaml import YAML from fastapi import Header, HTTPException, Request @@ -13,8 +14,6 @@ from common.utils import coalesce -yaml = YAML() - class AuthKeys(BaseModel): """ @@ -59,6 +58,9 @@ async def load_auth_keys(disable_from_config: bool): return + # Create a temporary YAML parser + yaml = YAML(typ=["rt", "safe"]) + try: async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file: contents = await auth_file.read() @@ -71,10 +73,12 @@ async def load_auth_keys(disable_from_config: bool): AUTH_KEYS = new_auth_keys async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file: - new_auth_yaml = yaml.safe_dump( - AUTH_KEYS.model_dump(), default_flow_style=False + string_stream = io.StringIO() + yaml.dump( + AUTH_KEYS.model_dump(), string_stream ) - await auth_file.write(new_auth_yaml) + + await auth_file.write(string_stream.getvalue()) logger.info( f"Your API key is: {AUTH_KEYS.api_key}\n" diff --git a/common/sampling.py b/common/sampling.py index 7005794..e49811d 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -11,8 +11,6 @@ from common.utils import unwrap, prune_dict -yaml = YAML() - # Common class for sampler params class BaseSamplerRequest(BaseModel): @@ -418,7 +416,10 @@ async def overrides_from_file(preset_name: str): overrides_container.selected_preset = preset_path.stem async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset: contents = await raw_preset.read() - preset = yaml.safe_load(contents) + + # Create a temporary YAML parser + yaml = YAML(typ="safe") + preset = yaml.load(contents) overrides_from_dict(preset) logger.info("Applied sampler overrides from file.") diff --git a/common/tabby_config.py b/common/tabby_config.py index 43f227e..219e180 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -13,7 +13,7 @@ from common.config_models import BaseConfigModel, TabbyConfigModel from common.utils import merge_dicts, unwrap -yaml = YAML() +yaml = YAML(typ=["rt", "safe"]) class TabbyConfig(TabbyConfigModel):