Skip to content

Commit

Permalink
minor update for ai-service (#493)
Browse files Browse the repository at this point in the history
* remove trailing / in urls of env vars

* update

* fix

* fix bug

* simplify codebase
  • Loading branch information
cyyeh authored Jul 9, 2024
1 parent 7047dce commit 6194a10
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 40 deletions.
9 changes: 4 additions & 5 deletions wren-ai-service/src/force_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

import aiohttp
import backoff
from dotenv import load_dotenv

load_dotenv(override=True)
if is_dev_env := os.getenv("ENV") and os.getenv("ENV").lower() == "dev":
load_dotenv(".env.dev", override=True)
from src.utils import load_env_vars

load_env_vars()


@backoff.on_exception(backoff.expo, aiohttp.ClientError, max_time=60, max_tries=3)
Expand All @@ -25,5 +24,5 @@ async def force_deploy():
print(f"Forcing deployment: {res}")


if os.getenv("ENGINE", "wren-ui") == "wren-ui":
if os.getenv("ENGINE", "wren_ui") == "wren_ui":
asyncio.run(force_deploy())
15 changes: 10 additions & 5 deletions wren-ai-service/src/providers/embedder/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from src.core.provider import EmbedderProvider
from src.providers.loader import provider
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")

Expand Down Expand Up @@ -200,16 +201,20 @@ def __init__(
)
or EMBEDDING_MODEL_DIMENSION,
):
logger.info(f"Using Azure OpenAI Embedding Model: {embedding_model}")
logger.info(f"Using Azure OpenAI Embedding API Base: {embed_api_base}")
logger.info(f"Using Azure OpenAI Embedding API Version: {embed_api_version}")

self._embedding_api_base = embed_api_base
self._embedding_api_base = remove_trailing_slash(embed_api_base)
self._embedding_api_key = embed_api_key
self._embedding_api_version = embed_api_version
self._embedding_model = embedding_model
self._embedding_model_dim = embedding_model_dim

logger.info(f"Using Azure OpenAI Embedding Model: {self._embedding_model}")
logger.info(
f"Using Azure OpenAI Embedding API Base: {self._embedding_api_base}"
)
logger.info(
f"Using Azure OpenAI Embedding API Version: {self._embedding_api_version}"
)

def get_text_embedder(self):
return AsyncTextEmbedder(
api_key=self._embedding_api_key,
Expand Down
9 changes: 5 additions & 4 deletions wren-ai-service/src/providers/embedder/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from src.core.provider import EmbedderProvider
from src.providers.loader import provider
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")

Expand Down Expand Up @@ -163,12 +164,12 @@ def __init__(
url: str = os.getenv("EMBEDDER_OLLAMA_URL") or EMBEDDER_OLLAMA_URL,
embedding_model: str = os.getenv("EMBEDDING_MODEL") or EMBEDDING_MODEL,
):
logger.info(f"Using Ollama Embedding Model: {embedding_model}")
logger.info(f"Using Ollama URL: {url}")

self._url = url
self._url = remove_trailing_slash(url)
self._embedding_model = embedding_model

logger.info(f"Using Ollama Embedding Model: {self._embedding_model}")
logger.info(f"Using Ollama URL: {self._url}")

def get_text_embedder(
self,
model_kwargs: Optional[Dict[str, Any]] = None,
Expand Down
22 changes: 13 additions & 9 deletions wren-ai-service/src/providers/embedder/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from src.core.provider import EmbedderProvider
from src.providers.loader import provider
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")

Expand Down Expand Up @@ -195,19 +196,22 @@ def _verify_api_key(api_key: str, api_base: str) -> None:
"""
OpenAI(api_key=api_key, base_url=api_base).models.list()

logger.info(f"Initializing OpenAIEmbedder provider with API base: {api_base}")
self._api_key = api_key
self._api_base = remove_trailing_slash(api_base)
self._embedding_model = embedding_model
self._embedding_model_dim = embedding_model_dim

logger.info(
f"Initializing OpenAIEmbedder provider with API base: {self._api_base}"
)
# TODO: currently only OpenAI api key can be verified
if api_base == EMBEDDER_OPENAI_API_BASE:
_verify_api_key(api_key.resolve_value(), api_base)
logger.info(f"Using OpenAI Embedding Model: {embedding_model}")
if self._api_base == EMBEDDER_OPENAI_API_BASE:
_verify_api_key(self._api_key.resolve_value(), self._api_base)
logger.info(f"Using OpenAI Embedding Model: {self._embedding_model}")
else:
logger.info(
f"Using OpenAI API-compatible Embedding Model: {embedding_model}"
f"Using OpenAI API-compatible Embedding Model: {self._embedding_model}"
)
self._api_key = api_key
self._api_base = api_base
self._embedding_model = embedding_model
self._embedding_model_dim = embedding_model_dim

def get_text_embedder(self):
return AsyncTextEmbedder(
Expand Down
15 changes: 9 additions & 6 deletions wren-ai-service/src/providers/llm/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from src.core.provider import LLMProvider
from src.providers.loader import provider
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")

Expand Down Expand Up @@ -76,7 +77,7 @@ async def run(
completion: Union[
Stream[ChatCompletionChunk], ChatCompletion
] = await self.client.chat.completions.create(
model=self.model,
model=self.azure_deployment,
messages=openai_formatted_messages,
stream=self.streaming_callback is not None,
**generation_kwargs,
Expand Down Expand Up @@ -122,15 +123,17 @@ def __init__(
chat_api_version: str = os.getenv("LLM_AZURE_OPENAI_VERSION"),
generation_model: str = os.getenv("GENERATION_MODEL") or GENERATION_MODEL,
):
logger.info(f"Using AzureOpenAI LLM: {generation_model}")
logger.info(f"Using AzureOpenAI LLM with API base: {chat_api_base}")
logger.info(f"Using AzureOpenAI LLM with API version: {chat_api_version}")

self._generation_api_key = chat_api_key
self._generation_api_base = chat_api_base
self._generation_api_base = remove_trailing_slash(chat_api_base)
self._generation_api_version = chat_api_version
self._generation_model = generation_model

logger.info(f"Using AzureOpenAI LLM: {self._generation_model}")
logger.info(f"Using AzureOpenAI LLM with API base: {self._generation_api_base}")
logger.info(
f"Using AzureOpenAI LLM with API version: {self._generation_api_version}"
)

def get_generator(
self,
model_kwargs: Dict[str, Any] = (
Expand Down
8 changes: 5 additions & 3 deletions wren-ai-service/src/providers/llm/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from src.core.provider import LLMProvider
from src.providers.loader import provider
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")

Expand Down Expand Up @@ -126,11 +127,12 @@ def __init__(
url: str = os.getenv("LLM_OLLAMA_URL") or LLM_OLLAMA_URL,
generation_model: str = os.getenv("GENERATION_MODEL") or GENERATION_MODEL,
):
logger.info(f"Using Ollama LLM: {generation_model}")
logger.info(f"Using Ollama URL: {url}")
self._url = url
self._url = remove_trailing_slash(url)
self._generation_model = generation_model

logger.info(f"Using Ollama LLM: {self._generation_model}")
logger.info(f"Using Ollama URL: {self._url}")

def get_generator(
self,
model_kwargs: Dict[str, Any] = (
Expand Down
18 changes: 10 additions & 8 deletions wren-ai-service/src/providers/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from src.core.provider import LLMProvider
from src.providers.loader import provider
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")

Expand Down Expand Up @@ -127,17 +128,18 @@ def _verify_api_key(api_key: str, api_base: str) -> None:
"""
OpenAI(api_key=api_key, base_url=api_base).models.list()

logger.info(f"Using OpenAILLM provider with API base: {api_base}")
# TODO: currently only OpenAI api key can be verified
if api_base == LLM_OPENAI_API_BASE:
_verify_api_key(api_key.resolve_value(), api_base)
logger.info(f"Using OpenAI LLM: {generation_model}")
else:
logger.info(f"Using OpenAI API-compatible LLM: {generation_model}")
self._api_key = api_key
self._api_base = api_base
self._api_base = remove_trailing_slash(api_base)
self._generation_model = generation_model

logger.info(f"Using OpenAILLM provider with API base: {self._api_base}")
# TODO: currently only OpenAI api key can be verified
if self._api_base == LLM_OPENAI_API_BASE:
_verify_api_key(self._api_key.resolve_value(), self._api_base)
logger.info(f"Using OpenAI LLM: {self._generation_model}")
else:
logger.info(f"Using OpenAI API-compatible LLM: {self._generation_model}")

def get_generator(
self,
model_kwargs: Dict[str, Any] = (
Expand Down
4 changes: 4 additions & 0 deletions wren-ai-service/src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,7 @@ async def wrapper_timer(*args, **kwargs):
return await process(func, *args, **kwargs)

return wrapper_timer


def remove_trailing_slash(endpoint: str) -> str:
    """Return *endpoint* with any trailing ``/`` characters removed.

    Used to normalize endpoint URLs read from environment variables so that
    values like ``"http://host:8080/"`` and ``"http://host:8080"`` are treated
    identically by the provider classes.

    Note: like the original implementation, this strips *all* trailing
    slashes (``"http://x//"`` -> ``"http://x"``), not just one.

    :param endpoint: endpoint URL, possibly with trailing slash(es).
    :return: the URL without trailing slashes; unchanged if there were none.
    """
    # str.rstrip already returns the string unchanged when no trailing "/"
    # is present, so the original endswith("/") guard was redundant.
    return endpoint.rstrip("/")

0 comments on commit 6194a10

Please sign in to comment.