From 6194a1013f161dbef5f5aad67ed20d3ac2b12e57 Mon Sep 17 00:00:00 2001 From: Chih-Yu Yeh Date: Tue, 9 Jul 2024 12:04:28 +0800 Subject: [PATCH] minor update for ai-service (#493) * remove trailing / in urls of env vars * update * fix * fix bug * simplify codebase --- wren-ai-service/src/force_deploy.py | 9 ++++---- .../src/providers/embedder/azure_openai.py | 15 ++++++++----- .../src/providers/embedder/ollama.py | 9 ++++---- .../src/providers/embedder/openai.py | 22 +++++++++++-------- .../src/providers/llm/azure_openai.py | 15 ++++++++----- wren-ai-service/src/providers/llm/ollama.py | 8 ++++--- wren-ai-service/src/providers/llm/openai.py | 18 ++++++++------- wren-ai-service/src/utils.py | 4 ++++ 8 files changed, 60 insertions(+), 40 deletions(-) diff --git a/wren-ai-service/src/force_deploy.py b/wren-ai-service/src/force_deploy.py index 3cba215a1..ee5c4381f 100644 --- a/wren-ai-service/src/force_deploy.py +++ b/wren-ai-service/src/force_deploy.py @@ -3,11 +3,10 @@ import aiohttp import backoff -from dotenv import load_dotenv -load_dotenv(override=True) -if is_dev_env := os.getenv("ENV") and os.getenv("ENV").lower() == "dev": - load_dotenv(".env.dev", override=True) +from src.utils import load_env_vars + +load_env_vars() @backoff.on_exception(backoff.expo, aiohttp.ClientError, max_time=60, max_tries=3) @@ -25,5 +24,5 @@ async def force_deploy(): print(f"Forcing deployment: {res}") -if os.getenv("ENGINE", "wren-ui") == "wren-ui": +if os.getenv("ENGINE", "wren_ui") == "wren_ui": asyncio.run(force_deploy()) diff --git a/wren-ai-service/src/providers/embedder/azure_openai.py b/wren-ai-service/src/providers/embedder/azure_openai.py index ca0c7e7f4..ba9506caf 100644 --- a/wren-ai-service/src/providers/embedder/azure_openai.py +++ b/wren-ai-service/src/providers/embedder/azure_openai.py @@ -15,6 +15,7 @@ from src.core.provider import EmbedderProvider from src.providers.loader import provider +from src.utils import remove_trailing_slash logger = logging.getLogger("wren-ai-service") @@ -200,16 +201,20 @@ def __init__( ) or EMBEDDING_MODEL_DIMENSION, ): - logger.info(f"Using Azure OpenAI Embedding Model: {embedding_model}") - logger.info(f"Using Azure OpenAI Embedding API Base: {embed_api_base}") - logger.info(f"Using Azure OpenAI Embedding API Version: {embed_api_version}") - - self._embedding_api_base = embed_api_base + self._embedding_api_base = remove_trailing_slash(embed_api_base) self._embedding_api_key = embed_api_key self._embedding_api_version = embed_api_version self._embedding_model = embedding_model self._embedding_model_dim = embedding_model_dim + logger.info(f"Using Azure OpenAI Embedding Model: {self._embedding_model}") + logger.info( + f"Using Azure OpenAI Embedding API Base: {self._embedding_api_base}" + ) + logger.info( + f"Using Azure OpenAI Embedding API Version: {self._embedding_api_version}" + ) + def get_text_embedder(self): return AsyncTextEmbedder( api_key=self._embedding_api_key, diff --git a/wren-ai-service/src/providers/embedder/ollama.py b/wren-ai-service/src/providers/embedder/ollama.py index 38207a3ac..de9f19e5d 100644 --- a/wren-ai-service/src/providers/embedder/ollama.py +++ b/wren-ai-service/src/providers/embedder/ollama.py @@ -13,6 +13,7 @@ from src.core.provider import EmbedderProvider from src.providers.loader import provider +from src.utils import remove_trailing_slash logger = logging.getLogger("wren-ai-service") @@ -163,12 +164,12 @@ def __init__( url: str = os.getenv("EMBEDDER_OLLAMA_URL") or EMBEDDER_OLLAMA_URL, embedding_model: str = os.getenv("EMBEDDING_MODEL") or EMBEDDING_MODEL, ): - logger.info(f"Using Ollama Embedding Model: {embedding_model}") - logger.info(f"Using Ollama URL: {url}") - - self._url = url + self._url = remove_trailing_slash(url) self._embedding_model = embedding_model + logger.info(f"Using Ollama Embedding Model: {self._embedding_model}") + logger.info(f"Using Ollama URL: {self._url}") + def get_text_embedder( self, model_kwargs: Optional[Dict[str, Any]] = None, diff --git a/wren-ai-service/src/providers/embedder/openai.py b/wren-ai-service/src/providers/embedder/openai.py index 936e50628..0941171f6 100644 --- a/wren-ai-service/src/providers/embedder/openai.py +++ b/wren-ai-service/src/providers/embedder/openai.py @@ -12,6 +12,7 @@ from src.core.provider import EmbedderProvider from src.providers.loader import provider +from src.utils import remove_trailing_slash logger = logging.getLogger("wren-ai-service") @@ -195,19 +196,22 @@ def _verify_api_key(api_key: str, api_base: str) -> None: """ OpenAI(api_key=api_key, base_url=api_base).models.list() - logger.info(f"Initializing OpenAIEmbedder provider with API base: {api_base}") + self._api_key = api_key + self._api_base = remove_trailing_slash(api_base) + self._embedding_model = embedding_model + self._embedding_model_dim = embedding_model_dim + + logger.info( + f"Initializing OpenAIEmbedder provider with API base: {self._api_base}" + ) # TODO: currently only OpenAI api key can be verified - if api_base == EMBEDDER_OPENAI_API_BASE: - _verify_api_key(api_key.resolve_value(), api_base) - logger.info(f"Using OpenAI Embedding Model: {embedding_model}") + if self._api_base == EMBEDDER_OPENAI_API_BASE: + _verify_api_key(self._api_key.resolve_value(), self._api_base) + logger.info(f"Using OpenAI Embedding Model: {self._embedding_model}") else: logger.info( - f"Using OpenAI API-compatible Embedding Model: {embedding_model}" + f"Using OpenAI API-compatible Embedding Model: {self._embedding_model}" ) - self._api_key = api_key - self._api_base = api_base - self._embedding_model = embedding_model - self._embedding_model_dim = embedding_model_dim def get_text_embedder(self): return AsyncTextEmbedder( diff --git a/wren-ai-service/src/providers/llm/azure_openai.py b/wren-ai-service/src/providers/llm/azure_openai.py index a0c3f3240..d560e7321 100644 --- a/wren-ai-service/src/providers/llm/azure_openai.py +++ b/wren-ai-service/src/providers/llm/azure_openai.py @@ -14,6 +14,7 @@ from src.core.provider import LLMProvider from src.providers.loader import provider +from src.utils import remove_trailing_slash logger = logging.getLogger("wren-ai-service") @@ -76,7 +77,7 @@ async def run( completion: Union[ Stream[ChatCompletionChunk], ChatCompletion ] = await self.client.chat.completions.create( - model=self.model, + model=self.azure_deployment, messages=openai_formatted_messages, stream=self.streaming_callback is not None, **generation_kwargs, @@ -122,15 +123,17 @@ def __init__( chat_api_version: str = os.getenv("LLM_AZURE_OPENAI_VERSION"), generation_model: str = os.getenv("GENERATION_MODEL") or GENERATION_MODEL, ): - logger.info(f"Using AzureOpenAI LLM: {generation_model}") - logger.info(f"Using AzureOpenAI LLM with API base: {chat_api_base}") - logger.info(f"Using AzureOpenAI LLM with API version: {chat_api_version}") - self._generation_api_key = chat_api_key - self._generation_api_base = chat_api_base + self._generation_api_base = remove_trailing_slash(chat_api_base) self._generation_api_version = chat_api_version self._generation_model = generation_model + logger.info(f"Using AzureOpenAI LLM: {self._generation_model}") + logger.info(f"Using AzureOpenAI LLM with API base: {self._generation_api_base}") + logger.info( + f"Using AzureOpenAI LLM with API version: {self._generation_api_version}" + ) + def get_generator( self, model_kwargs: Dict[str, Any] = ( diff --git a/wren-ai-service/src/providers/llm/ollama.py b/wren-ai-service/src/providers/llm/ollama.py index ab6a5ca11..28b8b01f8 100644 --- a/wren-ai-service/src/providers/llm/ollama.py +++ b/wren-ai-service/src/providers/llm/ollama.py @@ -10,6 +10,7 @@ from src.core.provider import LLMProvider from src.providers.loader import provider +from src.utils import remove_trailing_slash logger = logging.getLogger("wren-ai-service") @@ -126,11 +127,12 @@ def __init__( url: str = os.getenv("LLM_OLLAMA_URL") or LLM_OLLAMA_URL, generation_model: str = os.getenv("GENERATION_MODEL") or GENERATION_MODEL, ): - logger.info(f"Using Ollama LLM: {generation_model}") - logger.info(f"Using Ollama URL: {url}") - self._url = url + self._url = remove_trailing_slash(url) self._generation_model = generation_model + logger.info(f"Using Ollama LLM: {self._generation_model}") + logger.info(f"Using Ollama URL: {self._url}") + def get_generator( self, model_kwargs: Dict[str, Any] = ( diff --git a/wren-ai-service/src/providers/llm/openai.py b/wren-ai-service/src/providers/llm/openai.py index 0712277c4..cf3ab864b 100644 --- a/wren-ai-service/src/providers/llm/openai.py +++ b/wren-ai-service/src/providers/llm/openai.py @@ -14,6 +14,7 @@ from src.core.provider import LLMProvider from src.providers.loader import provider +from src.utils import remove_trailing_slash logger = logging.getLogger("wren-ai-service") @@ -127,17 +128,18 @@ def _verify_api_key(api_key: str, api_base: str) -> None: """ OpenAI(api_key=api_key, base_url=api_base).models.list() - logger.info(f"Using OpenAILLM provider with API base: {api_base}") - # TODO: currently only OpenAI api key can be verified - if api_base == LLM_OPENAI_API_BASE: - _verify_api_key(api_key.resolve_value(), api_base) - logger.info(f"Using OpenAI LLM: {generation_model}") - else: - logger.info(f"Using OpenAI API-compatible LLM: {generation_model}") self._api_key = api_key - self._api_base = api_base + self._api_base = remove_trailing_slash(api_base) self._generation_model = generation_model + logger.info(f"Using OpenAILLM provider with API base: {self._api_base}") + # TODO: currently only OpenAI api key can be verified + if self._api_base == LLM_OPENAI_API_BASE: + _verify_api_key(self._api_key.resolve_value(), self._api_base) + logger.info(f"Using OpenAI LLM: {self._generation_model}") + else: + logger.info(f"Using OpenAI API-compatible LLM: {self._generation_model}") + def get_generator( self, model_kwargs: Dict[str, Any] = ( diff --git a/wren-ai-service/src/utils.py b/wren-ai-service/src/utils.py index 7ba9a5f9d..deeb1f8d6 100644 --- a/wren-ai-service/src/utils.py +++ b/wren-ai-service/src/utils.py @@ -117,3 +117,7 @@ async def wrapper_timer(*args, **kwargs): return await process(func, *args, **kwargs) return wrapper_timer + + +def remove_trailing_slash(endpoint: str) -> str: + return endpoint.rstrip("/") if endpoint.endswith("/") else endpoint