pull ollama models if not exist automatically (#512)
* pull Ollama models automatically if they do not exist yet

* add ollama_url to pull_ollama_model

* fix bug

* update message for the Custom LLM provider and the timeout for the AI service

* update message

* fix intro for WREN_IBIS_CONNECTION_INFO
cyyeh authored Jul 15, 2024
1 parent 36506cd commit 14c396d
Showing 9 changed files with 230 additions and 184 deletions.
2 changes: 1 addition & 1 deletion wren-ai-service/.env.dev.example
@@ -54,7 +54,7 @@ WREN_UI_ENDPOINT=http://localhost:3000
WREN_IBIS_ENDPOINT=http://localhost:8000
WREN_IBIS_SOURCE=bigquery
WREN_IBIS_MANIFEST= # this is a base64 encoded string of the MDL
WREN_IBIS_CONNECTION_INFO={"project_id": "", "dataset_id": "", "credentials":""}
WREN_IBIS_CONNECTION_INFO= # this is a base64 encoded string of the connection info

# evaluation related
DATASET_NAME=book_2
366 changes: 190 additions & 176 deletions wren-ai-service/poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions wren-ai-service/pyproject.toml
@@ -24,6 +24,7 @@ sf-hamilton = {version = "==1.63.0", extras = ["visualization"]}
aiohttp = "==3.9.5"
ollama-haystack = "==0.0.6"
langfuse = "==2.35.0"
ollama = "==0.2.1"

[tool.poetry.group.dev.dependencies]
pytest = "==8.2.0"
3 changes: 2 additions & 1 deletion wren-ai-service/src/__main__.py
@@ -90,7 +90,8 @@ def health():
host=server_host,
port=server_port,
reload=should_reload,
reload_dirs=["src"] if should_reload else None,
reload_dirs=["src"],
reload_includes=[".env.dev"],
workers=1,
loop="uvloop",
http="httptools",
4 changes: 3 additions & 1 deletion wren-ai-service/src/providers/embedder/ollama.py
@@ -12,7 +12,7 @@
from tqdm import tqdm

from src.core.provider import EmbedderProvider
from src.providers.loader import provider
from src.providers.loader import provider, pull_ollama_model
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")
@@ -167,6 +167,8 @@ def __init__(
self._url = remove_trailing_slash(url)
self._embedding_model = embedding_model

pull_ollama_model(self._url, self._embedding_model)

logger.info(f"Using Ollama Embedding Model: {self._embedding_model}")
logger.info(f"Using Ollama URL: {self._url}")

7 changes: 5 additions & 2 deletions wren-ai-service/src/providers/engine/wren.py
@@ -1,3 +1,4 @@
import base64
import logging
import os
from typing import Any, Dict, Optional, Tuple
@@ -54,8 +55,10 @@ async def dry_run_sql(
"source": os.getenv("WREN_IBIS_SOURCE"),
"manifest": os.getenv("WREN_IBIS_MANIFEST"),
"connection_info": orjson.loads(
os.getenv("WREN_IBIS_CONNECTION_INFO", "{}")
),
base64.b64decode(os.getenv("WREN_IBIS_CONNECTION_INFO"))
)
if os.getenv("WREN_IBIS_CONNECTION_INFO")
else {},
},
) -> Tuple[bool, Optional[Dict[str, Any]]]:
async with session.post(
4 changes: 3 additions & 1 deletion wren-ai-service/src/providers/llm/ollama.py
@@ -9,7 +9,7 @@
from haystack_integrations.components.generators.ollama import OllamaGenerator

from src.core.provider import LLMProvider
from src.providers.loader import provider
from src.providers.loader import provider, pull_ollama_model
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")
@@ -130,6 +130,8 @@ def __init__(
self._url = remove_trailing_slash(url)
self._generation_model = generation_model

pull_ollama_model(self._url, self._generation_model)

logger.info(f"Using Ollama LLM: {self._generation_model}")
logger.info(f"Using Ollama URL: {self._url}")

18 changes: 18 additions & 0 deletions wren-ai-service/src/providers/loader.py
@@ -2,6 +2,8 @@
import logging
import pkgutil

from ollama import Client

logger = logging.getLogger("wren-ai-service")


@@ -98,3 +100,19 @@ def get_default_embedding_model_dim(embedder_provider: str):
return importlib.import_module(
f"src.providers.embedder.{file_name}"
).EMBEDDING_MODEL_DIMENSION


def pull_ollama_model(url: str, model_name: str):
    client = Client(host=url)
    # client.list() returns {"models": [...]}; each entry is a dict, so compare
    # against the "name" field rather than the dicts themselves
    models = [model["name"] for model in client.list()["models"]]
    if model_name not in models:
        logger.info(f"Pulling Ollama model {model_name}")
        percentage = 0
        for progress in client.pull(model_name, stream=True):
            # streamed chunks report "completed"/"total" byte counts while downloading
            if "completed" in progress and "total" in progress:
                new_percentage = int(progress["completed"] / progress["total"] * 100)
                if new_percentage > percentage:
                    percentage = new_percentage
                    logger.info(f"Pulling Ollama model {model_name}: {percentage}%")
    else:
        logger.info(f"Ollama model {model_name} already exists")
9 changes: 7 additions & 2 deletions wren-launcher/commands/launch.go
@@ -210,10 +210,12 @@ func Launch() {
panic(err)
}

// wait for 10 seconds
pterm.Info.Println("Wren AI is starting, please wait for a moment...")
if llmProvider == "Custom" {
pterm.Info.Println("If you choose Ollama as LLM provider, please make sure you have started the Ollama service first. Also, Wren AI will automatically pull your chosen models if you have not done so. You can check the progress by executing `docker logs -f wrenai-wren-ai-service-1` in the terminal.")
}
url := fmt.Sprintf("http://localhost:%d", uiPort)
// wait until CheckWrenAIStarted returns without error
// wait until CheckUIServiceStarted returns without error
// if it times out after 2 minutes, panic
timeoutTime := time.Now().Add(2 * time.Minute)
for {
@@ -230,6 +232,9 @@
time.Sleep(5 * time.Second)
}

// wait until CheckWrenAIStarted returns without error
// if it times out after 30 minutes, panic
timeoutTime = time.Now().Add(30 * time.Minute)
for {
if time.Now().After(timeoutTime) {
panic("Timeout")
