Skip to content

Commit

Permalink
chore(wren-ai-service): minor update (#933)
Browse files (browse the repository at this point in the history)
  • Loading branch information
cyyeh authored Nov 20, 2024
1 parent 63f0225 commit 220b0ca
Show file tree
Hide file tree
Showing 17 changed files with 2 additions and 189 deletions.
2 changes: 2 additions & 0 deletions docker/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ services:
EMBEDDER_AZURE_OPENAI_API_KEY: ${EMBEDDER_AZURE_OPENAI_API_KEY}
QDRANT_API_KEY: ${QDRANT_API_KEY}
SHOULD_FORCE_DEPLOY: ${SHOULD_FORCE_DEPLOY}
LANGFUSE_SECRET_KEY: ${LANGFUSE_SECRET_KEY}
LANGFUSE_PUBLIC_KEY: ${LANGFUSE_PUBLIC_KEY}
# sometimes the console won't show print messages,
# using PYTHONUNBUFFERED: 1 can fix this
PYTHONUNBUFFERED: 1
Expand Down
3 changes: 0 additions & 3 deletions wren-ai-service/src/pipelines/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import logging
from datetime import datetime
from pprint import pformat
from typing import Any, Dict, List, Optional

import aiohttp
Expand Down Expand Up @@ -57,8 +56,6 @@ async def run(
}

sql = self._build_cte_query(steps)
logger.debug(f": steps: {pformat(steps)}")
logger.debug(f"SQLBreakdownGenPostProcessor: final sql: {sql}")

if not await self._check_if_sql_executable(sql, project_id=project_id):
return {
Expand Down
7 changes: 0 additions & 7 deletions wren-ai-service/src/pipelines/generation/data_assistance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pathlib import Path
from typing import Any, Optional

import orjson
from hamilton import base
from hamilton.experimental.h_async import AsyncDriver
from haystack.components.builders.prompt_builder import PromptBuilder
Expand Down Expand Up @@ -57,18 +56,12 @@ def prompt(
language: str,
prompt_builder: PromptBuilder,
) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"db_schemas: {db_schemas}")
logger.debug(f"language: {language}")

return prompt_builder.run(query=query, db_schemas=db_schemas, language=language)


@async_timer
@observe(as_type="generation", capture_input=False)
async def data_assistance(prompt: dict, generator: Any, query_id: str) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")

return await generator.run(prompt=prompt.get("prompt"), query_id=query_id)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pathlib import Path
from typing import Any, List

import orjson
from hamilton import base
from hamilton.experimental.h_async import AsyncDriver
from haystack.components.builders.prompt_builder import PromptBuilder
Expand Down Expand Up @@ -119,10 +118,6 @@ def prompt(
configurations: AskConfigurations,
prompt_builder: PromptBuilder,
) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"documents: {documents}")
logger.debug(f"history: {history}")
logger.debug(f"configurations: {configurations}")
return prompt_builder.run(
query=query,
documents=documents,
Expand All @@ -136,7 +131,6 @@ def prompt(
@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql_in_followup(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))


Expand All @@ -147,9 +141,6 @@ async def post_process(
post_processor: SQLGenPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
f"generate_sql_in_followup: {orjson.dumps(generate_sql_in_followup, option=orjson.OPT_INDENT_2).decode()}"
)
return await post_processor.run(
generate_sql_in_followup.get("replies"), project_id=project_id
)
Expand Down
10 changes: 0 additions & 10 deletions wren-ai-service/src/pipelines/generation/intent_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
@async_timer
@observe(capture_input=False, capture_output=False)
async def embedding(query: str, embedder: Any) -> dict:
logger.debug(f"query: {query}")
return await embedder.run(query)


Expand Down Expand Up @@ -165,27 +164,18 @@ def prompt(
construct_db_schemas: list[str],
prompt_builder: PromptBuilder,
) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"db_schemas: {construct_db_schemas}")

return prompt_builder.run(query=query, db_schemas=construct_db_schemas)


@async_timer
@observe(as_type="generation", capture_input=False)
async def classify_intent(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")

return await generator.run(prompt=prompt.get("prompt"))


@timer
@observe(capture_input=False)
def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict:
logger.debug(
f"classify_intent: {orjson.dumps(classify_intent, option=orjson.OPT_INDENT_2).decode()}"
)

try:
intent = orjson.loads(classify_intent.get("replies")[0])["results"]
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def prompt(
prompt_builder: PromptBuilder,
language: str,
) -> dict:
logger.debug(f"User prompt: {user_prompt}")
logger.debug(f"Picked models: {picked_models}")
return prompt_builder.run(
picked_models=picked_models,
user_prompt=user_prompt,
Expand All @@ -68,7 +66,6 @@ def prompt(

@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))


Expand All @@ -85,10 +82,6 @@ def wrapper(text: str) -> str:
logger.error(f"Error decoding JSON: {e}")
return {"models": []} # Return an empty list if JSON decoding fails

logger.debug(
f"generate: {orjson.dumps(generate, option=orjson.OPT_INDENT_2).decode()}"
)

reply = generate.get("replies")[0] # Expecting only one reply
normalized = wrapper(reply)

Expand Down
12 changes: 0 additions & 12 deletions wren-ai-service/src/pipelines/generation/sql_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,6 @@ def run(
async def execute_sql(
sql: str, data_fetcher: DataFetcher, project_id: str | None = None
) -> dict:
logger.debug(f"Executing SQL: {sql}")

return await data_fetcher.run(sql=sql, project_id=project_id)


Expand All @@ -128,10 +126,6 @@ def prompt(
language: str,
prompt_builder: PromptBuilder,
) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"sql: {sql}")
logger.debug(f"sql data: {execute_sql}")
logger.debug(f"language: {language}")
return prompt_builder.run(
query=query,
sql=sql,
Expand All @@ -143,8 +137,6 @@ def prompt(
@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_answer(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")

return await generator.run(prompt=prompt.get("prompt"))


Expand All @@ -153,10 +145,6 @@ async def generate_answer(prompt: dict, generator: Any) -> dict:
def post_process(
generate_answer: dict, post_processor: SQLAnswerGenerationPostProcessor
) -> dict:
logger.debug(
f"generate_answer: {orjson.dumps(generate_answer, option=orjson.OPT_INDENT_2).decode()}"
)

return post_processor.run(generate_answer.get("replies"))


Expand Down
9 changes: 0 additions & 9 deletions wren-ai-service/src/pipelines/generation/sql_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pathlib import Path
from typing import Any

import orjson
from hamilton import base
from hamilton.experimental.h_async import AsyncDriver
from haystack.components.builders.prompt_builder import PromptBuilder
Expand Down Expand Up @@ -125,10 +124,6 @@ def prompt(
text_to_sql_rules: str,
prompt_builder: PromptBuilder,
) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"sql: {sql}")
logger.debug(f"language: {language}")
logger.debug(f"text_to_sql_rules: {text_to_sql_rules}")
return prompt_builder.run(
query=query, sql=sql, language=language, text_to_sql_rules=text_to_sql_rules
)
Expand All @@ -137,7 +132,6 @@ def prompt(
@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql_details(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))


Expand All @@ -148,9 +142,6 @@ async def post_process(
post_processor: SQLBreakdownGenPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
f"generate_sql_details: {orjson.dumps(generate_sql_details, option=orjson.OPT_INDENT_2).decode()}"
)
return await post_processor.run(
generate_sql_details.get("replies"), project_id=project_id
)
Expand Down
15 changes: 0 additions & 15 deletions wren-ai-service/src/pipelines/generation/sql_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pathlib import Path
from typing import Any, Dict, List

import orjson
from hamilton import base
from hamilton.experimental.h_async import AsyncDriver
from haystack import Document
Expand Down Expand Up @@ -65,12 +64,6 @@ def prompts(
alert: str,
prompt_builder: PromptBuilder,
) -> list[dict]:
logger.debug(
f"documents: {orjson.dumps(documents, option=orjson.OPT_INDENT_2).decode()}"
)
logger.debug(
f"invalid_generation_results: {orjson.dumps(invalid_generation_results, option=orjson.OPT_INDENT_2).decode()}"
)
return [
prompt_builder.run(
documents=documents,
Expand All @@ -84,10 +77,6 @@ def prompts(
@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql_corrections(prompts: list[dict], generator: Any) -> list[dict]:
logger.debug(
f"prompts: {orjson.dumps(prompts, option=orjson.OPT_INDENT_2).decode()}"
)

tasks = []
for prompt in prompts:
task = asyncio.ensure_future(generator.run(prompt=prompt.get("prompt")))
Expand All @@ -103,10 +92,6 @@ async def post_process(
post_processor: SQLGenPostProcessor,
project_id: str | None = None,
) -> list[dict]:
logger.debug(
f"generate_sql_corrections: {orjson.dumps(generate_sql_corrections, option=orjson.OPT_INDENT_2).decode()}"
)

return await post_processor.run(generate_sql_corrections, project_id=project_id)


Expand Down
8 changes: 0 additions & 8 deletions wren-ai-service/src/pipelines/generation/sql_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pathlib import Path
from typing import Any, List

import orjson
from hamilton import base
from hamilton.experimental.h_async import AsyncDriver
from haystack.components.builders.prompt_builder import PromptBuilder
Expand Down Expand Up @@ -61,9 +60,6 @@ def prompt(
timezone: AskConfigurations.Timezone,
prompt_builder: PromptBuilder,
) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"documents: {documents}")
logger.debug(f"history: {history}")
return prompt_builder.run(
query=query,
documents=documents,
Expand All @@ -75,7 +71,6 @@ def prompt(
@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql_expansion(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))


Expand All @@ -86,9 +81,6 @@ async def post_process(
post_processor: SQLGenPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
f"generate_sql_expansion: {orjson.dumps(generate_sql_expansion, option=orjson.OPT_INDENT_2).decode()}"
)
return await post_processor.run(
generate_sql_expansion.get("replies"), project_id=project_id
)
Expand Down
29 changes: 0 additions & 29 deletions wren-ai-service/src/pipelines/generation/sql_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,6 @@ def run(
"results"
]

logger.debug(
f"sql_explanation_results: {orjson.dumps(sql_explanation_results, option=orjson.OPT_INDENT_2).decode()}"
)

if preprocessed_sql_analysis_results.get(
"filter", {}
) and sql_explanation_results.get("filter", {}):
Expand Down Expand Up @@ -477,9 +473,6 @@ def run(
def preprocess(
sql_analysis_results: List[dict], pre_processor: SQLAnalysisPreprocessor
) -> dict:
logger.debug(
f"sql_analysis_results: {orjson.dumps(sql_analysis_results, option=orjson.OPT_INDENT_2).decode()}"
)
return pre_processor.run(sql_analysis_results)


Expand All @@ -492,13 +485,6 @@ def prompts(
sql_summary: str,
prompt_builder: PromptBuilder,
) -> List[dict]:
logger.debug(f"question: {question}")
logger.debug(f"sql: {sql}")
logger.debug(
f"preprocess: {orjson.dumps(preprocess, option=orjson.OPT_INDENT_2).decode()}"
)
logger.debug(f"sql_summary: {sql_summary}")

preprocessed_sql_analysis_results_with_values = []
for preprocessed_sql_analysis_result in preprocess[
"preprocessed_sql_analysis_results"
Expand Down Expand Up @@ -534,10 +520,6 @@ def prompts(
}
)

logger.debug(
f"preprocessed_sql_analysis_results_with_values: {orjson.dumps(preprocessed_sql_analysis_results_with_values, option=orjson.OPT_INDENT_2).decode()}"
)

return [
prompt_builder.run(
question=question,
Expand All @@ -552,10 +534,6 @@ def prompts(
@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql_explanation(prompts: List[dict], generator: Any) -> List[dict]:
logger.debug(
f"prompts: {orjson.dumps(prompts, option=orjson.OPT_INDENT_2).decode()}"
)

async def _task(prompt: str, generator: Any):
return await generator.run(prompt=prompt.get("prompt"))

Expand All @@ -570,13 +548,6 @@ def post_process(
preprocess: dict,
post_processor: SQLExplanationGenerationPostProcessor,
) -> dict:
logger.debug(
f"generate_sql_explanation: {orjson.dumps(generate_sql_explanation, option=orjson.OPT_INDENT_2).decode()}"
)
logger.debug(
f"preprocess: {orjson.dumps(preprocess, option=orjson.OPT_INDENT_2).decode()}"
)

return post_processor.run(
generate_sql_explanation,
preprocess["preprocessed_sql_analysis_results"],
Expand Down
Loading

0 comments on commit 220b0ca

Please sign in to comment.