Skip to content

Commit

Permalink
use the latest version of wren-engine
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Jul 10, 2024
1 parent 73cba65 commit d6a6f70
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 77 deletions.
3 changes: 3 additions & 0 deletions wren-ai-service/.env.dev.example
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ WREN_IBIS_SOURCE=bigquery
WREN_IBIS_MANIFEST= # this is a base64 encoded string of the MDL
WREN_IBIS_CONNECTION_INFO={"project_id": "", "dataset_id": "", "credentials":""}

## when using wren_engine as the engine
WREN_ENGINE_MANIFEST=

# evaluation related
DATASET_NAME=book_2

Expand Down
8 changes: 3 additions & 5 deletions wren-ai-service/demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
ask_details,
get_mdl_json,
get_new_mdl_json,
prepare_duckdb,
prepare_semantics,
rerun_wren_engine,
save_mdl_json_file,
Expand Down Expand Up @@ -147,11 +146,10 @@ def onchange_demo_dataset():
)
# Semantics preparation
if deploy_ok:
if st.session_state["dataset_type"] == "duckdb":
prepare_duckdb(st.session_state["chosen_dataset"])

rerun_wren_engine(
st.session_state["mdl_json"], st.session_state["dataset_type"]
st.session_state["mdl_json"],
st.session_state["dataset_type"],
st.session_state["chosen_dataset"],
)
prepare_semantics(st.session_state["mdl_json"])

Expand Down
148 changes: 82 additions & 66 deletions wren-ai-service/demo/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import json
import os
import re
Expand All @@ -13,6 +14,7 @@
from dotenv import load_dotenv

WREN_AI_SERVICE_BASE_URL = "http://localhost:5556"
WREN_IBIS_API_URL = "http://localhost:8000"
WREN_ENGINE_API_URL = "http://localhost:8080"
POLLING_INTERVAL = 0.5
DATA_SOURCES = ["duckdb", "bigquery", "postgres"]
Expand All @@ -29,53 +31,79 @@ def _update_wren_engine_configs(configs: list[dict]):
assert response.status_code == 200


def rerun_wren_engine(mdl_json: Dict, dataset_type: str):
def rerun_wren_engine(mdl_json: Dict, dataset_type: str, dataset: str):
assert dataset_type in DATA_SOURCES

if dataset_type == "bigquery":
BIGQUERY_CREDENTIALS = os.getenv("bigquery.credentials-key")
assert (
BIGQUERY_CREDENTIALS is not None
), "bigquery.credentials-key is not set in .env"

SOURCE = dataset_type
MANIFEST = base64.b64encode(orjson.dumps(mdl_json)).decode()
if dataset_type == "duckdb":
_update_wren_engine_configs(
[
{"name": "wren.datasource.type", "value": "bigquery"},
{
"name": "bigquery.project-id",
"value": os.getenv("bigquery.project-id"),
"name": "duckdb.connector.init-sql-path",
"value": "/usr/src/app/etc/duckdb-init.sql",
},
{"name": "bigquery.location", "value": os.getenv("bigquery.location")},
{"name": "bigquery.credentials-key", "value": BIGQUERY_CREDENTIALS},
]
)
elif dataset_type == "duckdb":
_update_wren_engine_configs(
[{"name": "wren.datasource.type", "value": "duckdb"}]
)
elif dataset_type == "postgresql":
_update_wren_engine_configs(
[
{"name": "wren.datasource.type", "value": "postgres"},
{"name": "postgres.user", "value": os.getenv("postgres.user")},
{"name": "postgres.password", "value": os.getenv("postgres.password")},
{"name": "postgres.jdbc.url", "value": os.getenv("postgres.jdbc.url")},
]
)

st.toast("Wren Engine is being re-run", icon="⏳")

response = requests.post(
f"{WREN_ENGINE_API_URL}/v1/mdl/deploy",
json={
"manifest": mdl_json,
"version": "latest",
},
)

assert response.status_code == 202

st.toast("Wren Engine is ready", icon="🎉")
_prepare_duckdb(dataset)

# overwrite the ENGINE and WREN_ENGINE_MANIFEST values in ../.env.dev
with open("../.env.dev", "r") as f:
lines = f.readlines()
for i, line in enumerate(lines):
if line.startswith("ENGINE"):
lines[i] = "ENGINE=wren_engine\n"
elif line.startswith("WREN_ENGINE_MANIFEST"):
lines[i] = f"WREN_ENGINE_MANIFEST={MANIFEST}\n"
with open("../.env.dev", "w") as f:
f.writelines(lines)
else:
if dataset_type == "bigquery":
WREN_IBIS_CONNECTION_INFO = base64.b64encode(
orjson.dumps(
{
"project_id": os.getenv("bigquery.project-id"),
"dataset_id": os.getenv("bigquery.dataset-id"),
"credentials": os.getenv("bigquery.credentials-key"),
}
)
).decode()
elif dataset_type == "postgres":
WREN_IBIS_CONNECTION_INFO = base64.b64encode(
orjson.dumps(
{
"host": os.getenv("postgres.host"),
"port": int(os.getenv("postgres.port")),
"database": os.getenv("postgres.database"),
"user": os.getenv("postgres.user"),
"password": os.getenv("postgres.password"),
}
)
).decode()

# overwrite the ENGINE and WREN_IBIS_xxx values in ../.env.dev
with open("../.env.dev", "r") as f:
lines = f.readlines()
for i, line in enumerate(lines):
if line.startswith("ENGINE"):
lines[i] = "ENGINE=wren_ibis\n"
elif line.startswith("WREN_IBIS_SOURCE"):
lines[i] = f"WREN_IBIS_SOURCE={SOURCE}\n"
elif line.startswith("WREN_IBIS_MANIFEST"):
lines[i] = f"WREN_IBIS_MANIFEST={MANIFEST}\n"
elif (
line.startswith("WREN_IBIS_CONNECTION_INFO")
and dataset_type != "duckdb"
):
lines[
i
] = f"WREN_IBIS_CONNECTION_INFO={WREN_IBIS_CONNECTION_INFO}\n"
with open("../.env.dev", "w") as f:
f.writelines(lines)

# wait for wren-ai-service to restart
time.sleep(5)


def save_mdl_json_file(file_name: str, mdl_json: Dict):
Expand Down Expand Up @@ -115,11 +143,13 @@ def get_new_mdl_json(chosen_models: List[str]):


@st.cache_data
def get_data_from_wren_engine(sql: str):
def get_data_from_wren_engine(sql: str, manifest: Dict):
response = requests.get(
f"{WREN_ENGINE_API_URL}/v1/mdl/preview",
json={
"sql": sql,
"manifest": manifest,
"limit": 100,
},
)

Expand Down Expand Up @@ -312,6 +342,7 @@ def show_asks_details_results(query: str):
st.dataframe(
get_data_from_wren_engine(
st.session_state["preview_sql"],
st.session_state["mdl_json"],
)
)

Expand All @@ -320,7 +351,7 @@ def show_asks_details_results(query: str):
label="SQL Explanation",
key="sql_explanation_btn",
on_click=on_click_sql_explanation_button,
args=[query, sqls, summaries],
args=[query, sqls, summaries, st.session_state["mdl_json"]],
use_container_width=True,
)

Expand All @@ -330,14 +361,15 @@ def on_click_preview_data_button(index: int, full_sqls: List[str]):
st.session_state["preview_sql"] = full_sqls[index]


def get_sql_analysis_results(sqls: List[str]):
def get_sql_analysis_results(sqls: List[str], manifest: Dict):
results = []
for sql in sqls:
print(f"SQL: {sql}")
response = requests.get(
f"{WREN_ENGINE_API_URL}/v1/analysis/sql",
json={
"sql": sql,
"manifest": manifest,
},
)

Expand All @@ -352,8 +384,9 @@ def on_click_sql_explanation_button(
question: str,
sqls: List[str],
summaries: List[str],
manifest: Dict,
):
sql_analysis_results = get_sql_analysis_results(sqls)
sql_analysis_results = get_sql_analysis_results(sqls, manifest)

st.session_state["sql_explanation_question"] = question
st.session_state["sql_analysis_results"] = sql_analysis_results
Expand Down Expand Up @@ -402,7 +435,7 @@ def generate_mdl_metadata(mdl_model_json: dict):
return mdl_model_json


def prepare_duckdb(dataset_name: str):
def _prepare_duckdb(dataset_name: str):
assert dataset_name in ["music", "nba", "ecommerce"]

DATASET_VERSION = "v0.3.0"
Expand Down Expand Up @@ -434,31 +467,14 @@ def prepare_duckdb(dataset_name: str):
""",
}

api_url = "http://localhost:3000/api/graphql"
with open("../tools/dev/etc/duckdb-init.sql", "w") as f:
f.write("")

user_data = {
"properties": {
"displayName": "my-duckdb",
"initSql": init_sqls[dataset_name],
"configurations": {"threads": 8},
"extensions": ["httpfs", "aws"],
},
"type": "DUCKDB",
}

payload = {
"query": """
mutation SaveDataSource($data: DataSourceInput!) {
saveDataSource(data: $data) {
type
properties
}
}
""",
"variables": {"data": user_data},
}
response = requests.put(
f"{WREN_ENGINE_API_URL}/v1/data-source/duckdb/settings/init-sql",
data=init_sqls[dataset_name],
)

response = requests.post(api_url, json=payload)
assert response.status_code == 200


Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def health():
reload=should_reload,
reload_includes=["src/**/*.py", ".env.dev"],
reload_excludes=[
"./demo/*.py"
"./demo/*.py",
], # TODO: add eval folder when evaluation system is ready
workers=1,
loop="uvloop",
Expand Down
13 changes: 11 additions & 2 deletions wren-ai-service/src/providers/engine/wren.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import logging
import os
from typing import Any, Dict, Optional, Tuple
Expand Down Expand Up @@ -73,7 +74,7 @@ async def dry_run_sql(
return False, res


@provider("wren-engine")
@provider("wren_engine")
class WrenEngine(Engine):
def __init__(self, endpoint: str = os.getenv("WREN_ENGINE_ENDPOINT")):
self._endpoint = endpoint
Expand All @@ -82,10 +83,18 @@ async def dry_run_sql(
self,
sql: str,
session: aiohttp.ClientSession,
properties: Dict[str, Any] = {
"manifest": os.getenv("WREN_ENGINE_MANIFEST"),
},
) -> Tuple[bool, Optional[Dict[str, Any]]]:
async with session.get(
f"{self._endpoint}/v1/mdl/dry-run",
json={"sql": remove_limit_statement(add_quotes(sql)), "limit": 1},
json={
"manifest": orjson.loads(
base64.b64decode(properties.get("manifest", ""))
),
"sql": remove_limit_statement(add_quotes(sql)),
},
) as response:
if response.status == 200:
return True, None
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/tests/pytest/providers/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,5 @@ def test_get_provider():
provider = loader.get_provider("wren_ibis")
assert provider.__name__ == "WrenIbis"

provider = loader.get_provider("wren-engine")
provider = loader.get_provider("wren_engine")
assert provider.__name__ == "WrenEngine"
4 changes: 2 additions & 2 deletions wren-ai-service/tools/dev/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ IBIS_SERVER_PORT=8000
# version
# CHANGE THIS TO THE LATEST VERSION
WREN_PRODUCT_VERSION=development
WREN_ENGINE_VERSION=0.5.0
WREN_ENGINE_VERSION=nightly
WREN_AI_SERVICE_VERSION=latest
WREN_UI_VERSION=latest
IBIS_SERVER_VERSION=latest
IBIS_SERVER_VERSION=nightly
WREN_BOOTSTRAP_VERSION=latest

# SQL Protocol
Expand Down

0 comments on commit d6a6f70

Please sign in to comment.