Refactor SQL agent #1010

Merged · 11 commits · Feb 5, 2025
278 changes: 87 additions & 191 deletions lumen/ai/agents.py

Large diffs are not rendered by default.

40 changes: 21 additions & 19 deletions lumen/ai/coordinator.py
@@ -356,7 +356,7 @@ async def _execute_graph_node(self, node: ExecutionNode, messages: list[Message]
mutated_messages = mutate_user_message(custom_message, mutated_messages)
if instruction:
mutate_user_message(
f"-- For context, here's part of the multi-step plan: {instruction!r}, but as the expert, you may need to deviate from it",
f"-- For context, here's part of the multi-step plan: {instruction!r}, but as the expert, you may need to deviate from it if you notice any inconsistencies or issues.",
mutated_messages, suffix=True, wrap=True
)
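
To make the planner's per-step instruction visible to the selected agent, the coordinator appends it to the latest user message via `mutate_user_message` from `lumen.ai.utils`. Below is a rough sketch of the helper's apparent behavior, reconstructed only from the call sites in this diff; the real implementation may differ:

```python
from copy import deepcopy

def mutate_user_message(content: str, messages: list[dict], suffix: bool = False,
                        wrap: bool = False, inplace: bool = True) -> list[dict]:
    # Hypothetical re-implementation for illustration only; the actual helper
    # lives in lumen.ai.utils and may differ in signature and behavior.
    if not inplace:
        messages = deepcopy(messages)
    for message in reversed(messages):
        if message["role"] == "user":
            # wrap=True is assumed to pad the injected text with newlines.
            extra = f"\n{content}\n" if wrap else content
            message["content"] = message["content"] + extra if suffix else extra + message["content"]
            break
    return messages
```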

@@ -651,30 +651,31 @@ async def _lookup_context(
title="Obtaining additional context...",
user="Assistant"
) as istep:
context = await self.llm.invoke(
response = self.llm.stream(
messages=messages,
system=system,
model_spec=model_spec,
response_model=context_model,
max_retries=3,
)
if getattr(context, 'tables', None):
requested = [t for t in context.tables if t not in provided]
loaded = '\n'.join([f'- {table}' for table in requested])
istep.stream(f'Looking up schemas for the following tables:\n\n{loaded}')
table_info += await self._lookup_schemas(tables, requested, provided, cache=schemas)
if getattr(context, 'tools', None):
for tool in context.tools:
tool_messages = list(messages)
if tool.instruction:
mutate_user_message(
f"-- Here are instructions of the context you are to provide: {tool.instruction!r}",
tool_messages, suffix=True, wrap=True, inplace=False
)
response = await tools[tool.name].respond(tool_messages)
if response is not None:
istep.stream(f'{response}\n')
tool_context += f'\n- {response}'
async for output in response:
if getattr(output, 'tables', None):
requested = [t for t in output.tables if t not in provided]
loaded = '\n'.join([f'- {table}' for table in requested])
istep.stream(f'Looking up schemas for the following tables:\n\n{loaded}')
table_info += await self._lookup_schemas(tables, requested, provided, cache=schemas)
if getattr(output, 'tools', None):
for tool in output.tools:
tool_messages = list(messages)
if tool.instruction:
mutate_user_message(
f"-- Here are instructions of the context you are to provide: {tool.instruction!r}",
tool_messages, suffix=True, wrap=True, inplace=False
)
response = await tools[tool.name].respond(tool_messages)
if response is not None:
istep.stream(f'{response}\n')
tool_context += f'\n- {response}'
return table_info, tool_context
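
The substantive change in `_lookup_context` is swapping `llm.invoke` (a single awaited structured response) for `llm.stream`, an async iterator that yields progressively more complete partial instances of `context_model`, so schema lookups and tool calls can start as soon as the relevant fields appear. A minimal consumption sketch under instructor-style partial streaming; the names here are illustrative, not Lumen's exact API:

```python
async def consume_context_stream(llm, messages, system, model_spec, context_model):
    # llm.stream(...) yields partially populated context_model instances;
    # fields remain None/empty until the LLM has actually emitted them.
    seen: set[str] = set()
    response = llm.stream(
        messages=messages,
        system=system,
        model_spec=model_spec,
        response_model=context_model,
    )
    async for output in response:
        for table in getattr(output, "tables", None) or []:
            if table not in seen:
                seen.add(table)
                # Work can begin before the full response has streamed in.
                print(f"can start schema lookup early for {table!r}")
```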

async def _make_plan(
@@ -801,6 +802,7 @@ async def _compute_execution_graph(self, messages: list[Message], agents: dict[s
tool_names = [tool.name for tool in self._tools["__main__"]]
agent_names = [sagent.name[:-5] for sagent in agents.values()]

# provided is already included in table_info
tables, tables_schema_str = await gather_table_sources(self._memory['sources'], include_provided=False)

reason_model, plan_model = self._get_model(
62 changes: 28 additions & 34 deletions lumen/ai/models.py
@@ -2,38 +2,13 @@

from typing import Literal

from instructor.dsl.partial import PartialLiteralMixin
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo


class JoinRequired(BaseModel):

chain_of_thought: str = Field(
description="""
Concisely explain whether a table join is required to answer the user's query, or
if the user is requesting a join or merge.
"""
)

requires_joins: bool = Field(description="Whether a table join is required to answer the user's query.")


class TableJoins(BaseModel):

chain_of_thought: str = Field(
description="""
Concisely consider the tables that need to be joined to answer the user's query.
"""
)

tables_to_join: list[str] = Field(
description=(
"List of tables that need to be joined to answer the user's query. "
"Use table names verbatim; e.g. if table is `read_csv('table.csv')` "
"then use `read_csv('table.csv')` and not `table`, but if the table has "
"no extension, i.e. `table`, then use only `table`."
),
)
class PartialBaseModel(BaseModel, PartialLiteralMixin):
...
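
`PartialBaseModel` exists so that dynamically created response models can opt into instructor's partial streaming; as I understand it, `PartialLiteralMixin` keeps validation of `Literal` fields from failing while a streamed value is still an incomplete prefix of a valid literal. A hypothetical model built on it:

```python
from typing import Literal

class TableChoice(PartialBaseModel):
    # Invented example fields, for illustration only. Literal fields are
    # exactly what PartialLiteralMixin is for: during streaming the raw
    # value may be a prefix of a valid literal, and the mixin keeps
    # partial validation from rejecting it prematurely.
    table: Literal["sales", "customers"] | None = None
    reason: str | None = None
```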


class Sql(BaseModel):
@@ -103,7 +78,7 @@ def make_context_model(tools: list[str], tables: list[str]):
description="A list of tools to call to provide context before launching into the planning stage. Use tools to gather additional context or clarification, tools should NEVER be used to obtain the actual data you will be working with."
)
)
return create_model("Context", **fields)
return create_model("Context", __base__=PartialBaseModel, **fields)


def make_plan_models(agents: list[str], tools: list[str]):
@@ -158,14 +133,33 @@ def make_agent_model(agent_names: list[str], primary: bool = False):
)


def make_table_model(tables):
def make_tables_model(tables):
table_model = create_model(
"Table",
chain_of_thought=(str, FieldInfo(
description="A concise, one sentence decision-tree-style analysis on choosing a table."
description="""
Concisely consider which tables are necessary to answer the user query.
"""
)),
selected_tables=(list[Literal[tuple(tables)]], FieldInfo(
description="""
The most relevant tables based on the user query; if none are relevant,
use the first table. At least one table must be provided.
If a join is necessary, include all the tables that will be used in the join.
"""
)),
relevant_table=(Literal[tuple(tables)], FieldInfo(
description="The most relevant table based on the user query; if none are relevant, select the first. Table names MUST match verbatim including the quotations, apostrophes, periods, or lack thereof."
))
potential_join_issues=(str, FieldInfo(
description="""
If no join is necessary, return an empty string; otherwise
list potential join issues between tables:
- Data type mismatches (int vs string, numeric precision)
- Format differences (case, leading zeros, dates/times, timezones)
- Semantic differences (IDs vs names, codes vs full text)
- Quality issues (nulls, duplicates, validation rules)
Return specific issues found in current tables, and how you plan to address them
in the simplest yet accurate way possible.
"""
)),
__base__=PartialBaseModel
)
return table_model
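
Because the candidate tables are only known at runtime, the model is assembled with `pydantic.create_model`, and `Literal[tuple(tables)]` pins the LLM's structured output to the offered names verbatim. The same pattern in isolation, with table names invented for the example:

```python
from typing import Literal

from pydantic import create_model
from pydantic.fields import FieldInfo

tables = ["read_csv('sales.csv')", "customers"]

# Literal[tuple(tables)] only admits the exact strings we offered,
# including quoting and file extensions.
Table = create_model(
    "Table",
    selected_tables=(list[Literal[tuple(tables)]], FieldInfo(
        description="Tables needed to answer the query."
    )),
)

print(Table(selected_tables=["customers"]))
# Table(selected_tables=["sales"]) would fail validation, since "sales"
# is not one of the offered literals.
```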
2 changes: 2 additions & 0 deletions lumen/ai/prompts/Planner/main.jinja2
@@ -61,8 +61,10 @@ Here's a list of tools:
{{ memory['sql'] }}
```
{%- endif %}

{%- if 'table' in memory %}
- The result of the previous step was the `{{ memory['table'] }}` table. If the user is referencing a previous result this is probably what they're referring to. Consider carefully if it contains all the information you need and only invoke the SQL agent if some other calculation needs to be performed.
- However, if the user requests to see all the columns, they might be referring to the table that `{{ memory['table'] }}` was derived from.
- If you are invoking a SQL agent and reusing the table, tell it to reference that table by name rather than re-stating the query.
{%- endif %}

19 changes: 0 additions & 19 deletions lumen/ai/prompts/SQLAgent/find_joins.jinja2

This file was deleted.

12 changes: 12 additions & 0 deletions lumen/ai/prompts/SQLAgent/find_tables.jinja2
@@ -0,0 +1,12 @@
{% extends 'Actor/main.jinja2' %}

{% block instructions %}
Correctly assess and consider which tables are necessary to answer the user query.

Use table names verbatim, and be sure to include the delimiters {{ separator }}, like '{{ separator }}source{{ separator }}table{{ separator }}'
{% endblock %}

{% block context %}
Available tables and schemas:
{{ tables_schema_str }}
{% endblock %}
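
The `{{ separator }}` placeholder is rendered with Lumen's source/table delimiter (see `SOURCE_TABLE_SEPARATOR` in the `utils.py` diff below). A toy rendering of just that instruction line, with an assumed separator value:

```python
from jinja2 import Template

SOURCE_TABLE_SEPARATOR = "__"  # assumption: the real constant may differ

line = Template(
    "be sure to include the delimiters {{ separator }}, like "
    "'{{ separator }}source{{ separator }}table{{ separator }}'"
).render(separator=SOURCE_TABLE_SEPARATOR)
print(line)
# be sure to include the delimiters __, like '__source__table__'
```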
63 changes: 35 additions & 28 deletions lumen/ai/prompts/SQLAgent/main.jinja2
@@ -3,7 +3,7 @@
{%- block instructions %}
You are an agent responsible for writing a SQL query that will perform the data transformations the user requested.
Try not to take the query too literally, but instead focus on the user's intent and the data transformations required.
Use `SELECT * FROM table` if there is no specific column selection mentioned in the query.
Use `SELECT * FROM table` if there is no specific column selection mentioned in the query; no alias required.
{%- endblock -%}

{% block context -%}
@@ -14,50 +14,56 @@ Here are YAML schemas for currently relevant tables:
```yaml
{{ details.schema }}
```

{% endfor -%}
{%- endif -%}

{%- if not join_required -%}
It was already determined that no join is required, so only use the existing table '{{ table }}' to calculate the
result.
{% else %}
Please perform a join between the necessary tables.
If the join's values do not align based on the min/max lowest common denominator, then perform a join based on the
closest match, or resample and aggregate the data to align the values.
{%- endif -%}

Checklist:
- Use only `{{ dialect }}` SQL syntax.
- Do NOT include inlined comments in the SQL code, e.g. `-- comment`
- Quote column names to ensure they do not clash with valid identifiers.
- Do not include comments in the SQL code
- If it's a date column (excluding individual year/month/day integers), cast to date using appropriate syntax, e.g.
CAST or TO_DATE
- Mention example data enums from the schema to ensure the data type and format if necessary
- Use only `{{ dialect }}` SQL syntax
- Try to pretty print the SQL output with newlines and indentation.
- Specify data types explicitly to avoid type mismatches.
- Handle NULL values using functions like COALESCE or IS NULL.
- Capture only the required numeric values while removing all whitespace, like `(\d+)`, or remove characters like `$`, `%`, `,`, etc, only if needed.
- Use parameterized queries to prevent SQL injection attacks.
- Use Common Table Expressions (CTEs) and subqueries to break down complex queries into manageable parts.
- Be sure to remove suspiciously large or small values that may be invalid, like -9999.
- Ensure robust type conversion using functions like TRY_CAST to avoid query failures due to invalid data.
- Filter and sort data efficiently (e.g., ORDER BY key metrics) and use LIMIT to focus on the most relevant results.
{% if dialect == 'duckdb' %}
- Pretty print the SQL output with newlines and indentation.
{%- if join_required -%}
- Please perform a join between the necessary tables.
- If the join's values do not align based on the min/max lowest common denominator, then perform a join based on the closest match, or resample and aggregate the data to align the values.
- It is very important to transform the values so they align correctly, especially for acronyms and dates.
{%- endif -%}
{%- if dialect == 'duckdb' %}
- If the table name originally did not have `read_*` prefix, use the original table name
- Use table names verbatim; e.g. if table is read_csv('table.csv') then use read_csv('table.csv') and not 'table' or 'table.csv'
- If `read_*` is used, use with alias, e.g. read_parquet('table.parq') as table_parq
- String literals are delimited using single quotes (', apostrophe) and result in STRING_LITERAL values. Note that
double quotes (") cannot be used as string delimiter character: instead, double quotes are used to delimit quoted
identifiers.
{% endif %}
{% if dialect == 'snowflake' %}
{%- if dialect == 'snowflake' %}
- Do not under any circumstances add quotes around the database, schema or table name.
{% endif -%}

Additionally, only if applicable:
- Specify data types explicitly to avoid type mismatches.
- Be sure to remove suspiciously large or small values that may be invalid, like -9999.
- Use Common Table Expressions (CTEs) and subqueries to break complex queries down into manageable parts, only if the query requires more than one transformation.
- Filter and sort data efficiently (e.g., ORDER BY key metrics) and use LIMIT (greater than 1) to focus on the most relevant results.
- If the date columns are separated, e.g. year, month, day, then combine them into a single date column.

{%- if has_errors %}
If there are issues with the query, here are some common fixes:
- Handle NULL values using functions like COALESCE or IS NULL.
- If it's a date column (excluding individual year/month/day integers), cast to date using appropriate syntax, e.g.
CAST or TO_DATE
- Capture only the required numeric values while removing all whitespace, like `(\d+)`, or remove characters like `$`, `%`, `,`, etc, only if needed.
- Ensure robust type conversion using functions like TRY_CAST to avoid query failures due to invalid data.
{% endif %}
{%- endblock -%}

{% if comments is defined -%}
Here's additional guidance:
{{ comments }}
{%- endif -%}

{%- block examples %}
Examples:
{%- if has_errors -%}
Casting Examples:

If the query is "Which five regions have the highest total sales from 2022-02-22?"...

@@ -97,4 +103,5 @@ WHERE sale_date >= '2022-02-22'
GROUP BY region
ORDER BY total_sales DESC;
```
{%- endif -%}
{% endblock -%}
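
The DuckDB-specific checklist items (verbatim `read_*` table names with an alias, single-quoted string literals, double-quoted identifiers, `TRY_CAST` for robust conversion) are easy to sanity-check locally. A hedged sketch, assuming a local `sales.csv` with the columns used below:

```python
import duckdb

# Table function used verbatim and aliased, per the checklist; string
# literals in single quotes, identifiers in double quotes.
query = """
SELECT "region", SUM(TRY_CAST("sales" AS DOUBLE)) AS total_sales
FROM read_csv('sales.csv') AS sales_csv
WHERE CAST("sale_date" AS DATE) >= DATE '2022-02-22'
GROUP BY "region"
ORDER BY total_sales DESC
LIMIT 5
"""
print(duckdb.sql(query).df())
```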
22 changes: 0 additions & 22 deletions lumen/ai/prompts/SQLAgent/require_joins.jinja2

This file was deleted.

12 changes: 0 additions & 12 deletions lumen/ai/prompts/SQLAgent/select_table.jinja2

This file was deleted.

3 changes: 2 additions & 1 deletion lumen/ai/utils.py
@@ -351,6 +351,7 @@ async def gather_table_sources(sources: list[Source], include_provided: bool = T
"""
Get a dictionary of tables to their respective sources
and a markdown string of the tables and their schemas.

Parameters
----------
sources : list[Source]
@@ -370,7 +371,7 @@ async def gather_table_sources(sources: list[Source], include_provided: bool = T
label = f"{SOURCE_TABLE_SEPARATOR}{source}{SOURCE_TABLE_SEPARATOR}{table}" if include_sep else table
if isinstance(source, DuckDBSource) and source.ephemeral or "Provided" in source.name:
sql = source.get_sql_expr(table)
schema = await get_schema(source, table, include_min_max=False, include_enum=True, limit=3)
schema = await get_schema(source, table, include_enum=True, limit=3)
tables_schema_str += f"- {label}\nSchema:\n```yaml\n{yaml.dump(schema)}```\nSQL:\n```sql\n{sql}\n```\n\n"
else:
tables_schema_str += f"- {label}\n\n"
2 changes: 1 addition & 1 deletion pixi.toml
@@ -52,7 +52,7 @@ COVERAGE_CORE = "sysmon"
[feature.ai.dependencies]
duckdb = "*"
griffe = "*"
instructor = ">=1.4.3"
instructor = ">=1.6.4"
markitdown = "*"
nbformat = "*"
openai = "*"
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -47,7 +47,7 @@ HoloViz = "https://holoviz.org/"
[project.optional-dependencies]
tests = ['pytest', 'pytest-rerunfailures', 'pytest-asyncio']
sql = ['duckdb', 'intake-sql', 'sqlalchemy']
ai = ['griffe', 'nbformat', 'duckdb', 'pyarrow', 'instructor >=1.4.3', 'pydantic >=2.8.0', 'pydantic-extra-types', 'panel-graphic-walker[kernel] >=0.5.3', 'markitdown']
ai = ['griffe', 'nbformat', 'duckdb', 'pyarrow', 'instructor >=1.6.4', 'pydantic >=2.8.0', 'pydantic-extra-types', 'panel-graphic-walker[kernel] >=0.5.3', 'markitdown']
ai-local = ['lumen[ai]', 'huggingface_hub']
ai-openai = ['lumen[ai]', 'openai']
ai-mistralai = ['lumen[ai]', 'mistralai']