Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Only cast if necessary and add min_max
Browse files Browse the repository at this point in the history
ahuang11 committed Jan 31, 2025
1 parent 4dd85dc commit 839e4e5
Showing 3 changed files with 34 additions and 29 deletions.
37 changes: 19 additions & 18 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
@@ -503,9 +503,10 @@ class SQLAgent(LumenBaseAgent):
async def _create_valid_sql(
self,
messages: list[Message],
system: str,
tables_to_source: dict[str, BaseSQLSource],
dialect: str,
comments: str,
title: str,
tables_to_source: dict[str, BaseSQLSource],
errors=None
):
if errors:
@@ -529,9 +530,20 @@ async def _create_valid_sql(
messages = mutate_user_message(content, messages)
log_debug("\033[91mRetry SQLAgent\033[0m")

join_required = len(tables_to_source) > 1
comments = comments if join_required else "" # comments are about joins
system_prompt = await self._render_prompt(
"main",
messages,
join_required=join_required,
tables_sql_schemas=self._memory["tables_sql_schemas"],
dialect=dialect,
comments=comments,
has_errors=bool(errors),
)
with self.interface.add_step(title=title or "SQL query", steps_layout=self._steps_layout) as step:
model_spec = self.prompts["main"].get("llm_spec", "default")
response = self.llm.stream(messages, system=system, model_spec=model_spec, response_model=self._get_model("main"))
response = self.llm.stream(messages, system=system_prompt, model_spec=model_spec, response_model=self._get_model("main"))
sql_query = None
try:
async for output in response:
@@ -671,7 +683,7 @@ async def respond(
for source_table, source in tables_to_source.items():
# Look up underlying table name
source_table = self._drop_source_table_separator(source_table)
table_schema = await get_schema(source, source_table, include_min_max=True, include_count=True)
table_schema = await get_schema(source, source_table, include_count=True)
table_name = source.normalize_table(source_table)
if (
'tables' in source.param and
@@ -687,22 +699,11 @@ async def respond(
self._memory["tables_sql_schemas"] = tables_sql_schemas

dialect = source.dialect
join_required = len(tables_to_source) > 1
comments = comments if join_required else "" # comments are about joins
system_prompt = await self._render_prompt(
"main",
messages,
join_required=join_required,
tables_sql_schemas=tables_sql_schemas,
dialect=dialect,
comments=comments,
)
if join_required:
messages[-1]["content"] = self._drop_source_table_separator(messages[-1]["content"])
try:
sql_query = await self._create_valid_sql(messages, system_prompt, tables_to_source, step_title)
sql_query = await self._create_valid_sql(messages, dialect, comments, step_title, tables_to_source)
pipeline = self._memory['pipeline']
except RetriesExceededError as e:
traceback.print_exception(e)
self._memory["__error__"] = str(e)
return None
self._render_lumen(pipeline, spec=sql_query, messages=messages, render_output=render_output, title=step_title)
@@ -780,7 +781,7 @@ async def respond(
if not pipeline:
raise ValueError("No current pipeline found in memory.")

schema = await get_schema(pipeline, include_min_max=False)
schema = await get_schema(pipeline)
if not schema:
raise ValueError("Failed to retrieve schema for the current pipeline.")

23 changes: 14 additions & 9 deletions lumen/ai/prompts/SQLAgent/main.jinja2
Original file line number Diff line number Diff line change
@@ -20,19 +20,13 @@ Here are YAML schemas for currently relevant tables:
Checklist:
- Quote column names to ensure they do not clash with valid identifiers.
- Do not include comments in the SQL code
- If it's a date column (excluding individual year/month/day integers) date, cast to date using appropriate syntax, e.g.
CAST or TO_DATE
- Mention example data enums from the schema to ensure the data type and format if necessary
- Use only `{{ dialect }}` SQL syntax
- Try to pretty print the SQL output with newlines and indentation.
- Specify data types explicitly to avoid type mismatches.
- Handle NULL values using functions like COALESCE or IS NULL.
- Capture only the required numeric values while removing all whitespace, like `(\d+)`, or remove characters like `$`, `%`, `,`, etc, only if needed.
- Use parameterized queries to prevent SQL injection attacks.
- Use Common Table Expressions (CTEs) and subqueries to break down complex queries into manageable parts.
- Be sure to remove suspiciously large or small values that may be invalid, like -9999.
- Ensure robust type conversion using functions like TRY_CAST to avoid query failures due to invalid data.
- Filter and sort data efficiently (e.g., ORDER BY key metrics) and use LIMIT to focus on the most relevant results.
- Use Common Table Expressions (CTEs) and subqueries to break down complex queries into manageable parts if complexity warrants it.
- Filter and sort data efficiently (e.g., ORDER BY key metrics) and use LIMIT (greater than 1) to focus on the most relevant results.
{%- if join_required -%}
- Please perform a join between the necessary tables.
- If the join's values do not align based on the min/max lowest common denominator, then perform a join based on the closest match, or resample and aggregate the data to align the values.
@@ -53,10 +47,20 @@ identifiers.
Here's additional guidance:
{{ comments }}
{%- endif -%}

If there are issues with the query, here are some common fixes:
{%- if has_errors %}
- Handle NULL values using functions like COALESCE or IS NULL.
- If it's a date column (excluding individual year/month/day integers) date, cast to date using appropriate syntax, e.g.
CAST or TO_DATE
- Capture only the required numeric values while removing all whitespace, like `(\d+)`, or remove characters like `$`, `%`, `,`, etc, only if needed.
- Ensure robust type conversion using functions like TRY_CAST to avoid query failures due to invalid data.
{% endif %}
{%- endblock -%}

{%- block examples %}
Examples:
{%- if has_errors -%}
Casting Examples:

If the query is "Which five regions have the highest total sales from 2022-02-22?"...

@@ -96,4 +100,5 @@ WHERE sale_date >= '2022-02-22'
GROUP BY region
ORDER BY total_sales DESC;
```
{%- endif -%}
{% endblock -%}
3 changes: 1 addition & 2 deletions lumen/ai/utils.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,6 @@
)
from markupsafe import escape

from lumen.ai.config import SOURCE_TABLE_SEPARATOR
from lumen.pipeline import Pipeline
from lumen.sources.base import Source
from lumen.sources.duckdb import DuckDBSource
@@ -370,7 +369,7 @@ async def gather_table_sources(sources: list[Source], include_provided: bool = T
label = f"{SOURCE_TABLE_SEPARATOR}{source}{SOURCE_TABLE_SEPARATOR}{table}" if include_sep else table
if isinstance(source, DuckDBSource) and source.ephemeral or "Provided" in source.name:
sql = source.get_sql_expr(table)
schema = await get_schema(source, table, include_min_max=False, include_enum=True, limit=3)
schema = await get_schema(source, table, include_min_max=True, include_enum=True, limit=3)
tables_schema_str += f"- {label}\nSchema:\n```yaml\n{yaml.dump(schema)}```\nSQL:\n```sql\n{sql}\n```\n\n"
else:
tables_schema_str += f"- {label}\n\n"

0 comments on commit 839e4e5

Please sign in to comment.