Add table listing (#524)
Co-authored-by: Philipp Rudiger <[email protected]>
ahuang11 and philippjfr authored Mar 13, 2024
1 parent 4cbdca9 commit 5b3aef0
Showing 6 changed files with 104 additions and 36 deletions.
106 changes: 83 additions & 23 deletions lumen/ai/agents.py
@@ -14,13 +14,15 @@
from ..base import Component
from ..dashboard import load_yaml
from ..pipeline import Pipeline
from ..sources import FileSource, InMemorySource
from ..sources import FileSource, InMemorySource, Source
from ..transforms.sql import SQLTransform, Transform
from ..views import hvPlotUIView
from .embeddings import Embeddings
from .llm import Llm
from .memory import memory
from .models import Sql, String, Table
from .models import (
Decision, Sql, String, Table,
)
from .translate import param_to_pydantic


@@ -88,7 +90,7 @@ def _link_code_editor(self, value, callback, language):
)
icons = pn.Row(copy_icon, download_icon)
code_col = pn.Column(code_editor, icons, sizing_mode="stretch_both")
placeholder = pn.Column()
placeholder = pn.Column(sizing_mode="stretch_both")
tabs = pn.Tabs(
("Code", code_col),
("Output", placeholder),
@@ -104,13 +106,21 @@ def _chat_invoke(self, contents: list | str, user: str, instance: ChatInterface)
def __panel__(self):
return self.interface

def _system_prompt_with_context(self, messages: list | str) -> str:
def _system_prompt_with_context(self, messages: list | str, context: str = "") -> str:
system_prompt = self.system_prompt
if self.embeddings:
context = self.embeddings.query(messages)
if context:
system_prompt += f"{system_prompt}\n### CONTEXT: {context}".strip()
return system_prompt

def _get_schema(self, source: Source, table: str):
try:
source.load_schema = True
return source.get_schema(table)
finally:
source.load_schema = False

def invoke(self, messages: list | str):
message = None
system_prompt = self._system_prompt_with_context(messages)
@@ -221,10 +231,13 @@ class LumenBaseAgent(Agent):

def _render_lumen(self, component: Component, message: pn.chat.ChatMessage = None):
async def _render_component(spec, active):
if active == 0:
yield pn.indicators.LoadingSpinner(
value=True, name="Rendering component...", height=50, width=50
)
yield pn.indicators.LoadingSpinner(
value=True, name="Rendering component...", height=50, width=50
)

if active != 1:
return

# store the spec in the cache instead of memory to save tokens
memory["current_spec"] = spec
try:
@@ -269,6 +282,8 @@ def answer(self, messages: list | str):
system_prompt = self._system_prompt_with_context(messages)
if self.debug:
print(f"{self.name} is being instructed that it should {system_prompt}")
# strip quotes, or else the quoted table names cause a grammar issue
tables = tuple(table.replace('"', "") for table in tables)
table_model = create_model("Table", table=(Literal[tables], ...))
table = self.llm.invoke(
messages,
@@ -290,9 +305,36 @@ def invoke(self, messages: list | str):
self._render_lumen(pipeline)


class TableListAgent(LumenBaseAgent):
"""
The TableListAgent is responsible for listing the available tables in the current source;
do not use this if the user wants a specific table.
"""

system_prompt = param.String(
default="You are an agent responsible for listing the available tables in the current source."
)

requires = param.List(default=["current_source"], readonly=True)

def answer(self, messages: list | str):
tables = memory["current_source"].get_tables()
if not tables:
return
tables = tuple(table.replace('"', "") for table in tables)
table_bullets = "\n".join(f"- {table}" for table in tables)
self.interface.send(
f"Available tables:\n{table_bullets}", user=self.name, respond=False
)
return tables

def invoke(self, messages: list | str):
self.answer(messages)


class SQLAgent(LumenBaseAgent):
"""
The SQLAgent is responsible for modifying SQL queries based on the user prompt.
The SQLAgent is responsible for generating and modifying SQL queries to answer user questions about statistics.
"""

system_prompt = param.String(
Expand All @@ -313,8 +355,11 @@ async def _render_sql_result(query, active):
table = memory["current_table"]
source.tables[table] = query
try:
memory["current_pipeline"] = pipeline = Pipeline(source=source, table=table)
memory["current_pipeline"] = pipeline = Pipeline(
source=source, table=table
)
yield pipeline
tabs.active = 1
except Exception as e:
yield pn.pane.Alert(
f"Error executing SQL query: {e}", alert_type="danger"
@@ -336,7 +381,7 @@ def answer(self, messages: list | str):
return None
sql_expr = source.get_sql_expr(table)
system_prompt = self._system_prompt_with_context(messages)
schema = source.get_schema(table)
schema = self._get_schema(source, table)
sql_prompt = self._sql_prompt(sql_expr, table, schema)
for chunk in self.llm.stream(
messages,
@@ -363,7 +408,6 @@ def answer(self, messages: list | str):

def invoke(self, messages: list | str):
sql = self.answer(messages)
raise ValueError(sql)
self._render_sql(sql)


@@ -396,23 +440,38 @@ def _transform_picker_prompt(self) -> str:
for name, transform in self._available_transforms.items():
if doc := (transform.__doc__ or "").strip():
doc = doc.split("\n\n")[0].strip().replace("\n", "")
prompt += f"- {name}: {doc}\n"
prompt += f"- {name!r}: {doc}\n"
return prompt

def _transform_prompt(
self, model: BaseModel, transform: Transform, table: str, schema: dict
) -> str:
prompt = f"{transform.__doc__}"
prompt += (
f"\n\nThe data follows the following JSON schema:\n\n```json\n{str(schema)}\n```"
)
if not schema:
raise ValueError(f"No schema found for table {table!r}")
else:
print(f"Used schema: {schema}")
prompt += f"\n\nThe data follows the following JSON schema:\n\n```json\n{str(schema)}\n```"
if "current_transform" in memory:
prompt += f"The previous transform specification was: {memory['current_transform']}"
return prompt

def _find_transform(
self, messages: list | str, system_prompt: str
) -> Type[Transform] | None:
decision = self.llm.invoke(
messages,
system=(
"Decide whether a transformation is needed to compute stats or aggregation; "
"if it's just getting data, return False."
),
response_model=Decision,
allow_partial=False,
)
if decision is None or not decision.required:
print(f"{self.name} decided that no transformation is needed because {decision}")
return

picker_prompt = self._transform_picker_prompt()
transforms = self._available_transforms
transform_model = create_model(
@@ -427,14 +486,14 @@ def _find_transform(
print(
f"{self.name} thought {transform_name=!r} would be the right thing to do."
)
return transforms[transform_name] if transform_name else None
return transforms[transform_name.strip("'")] if transform_name else None

def _construct_transform(
self, messages: list | str, transform: Type[Transform], system_prompt: str
) -> Transform:
table = memory["current_table"]
excluded = transform._internal_params + ["controls", "type"]
schema = memory["current_source"].get_schema(table)
schema = self._get_schema(memory["current_source"], table)
model = param_to_pydantic(transform, excluded=excluded, schema=schema)[
transform.__name__
]
@@ -474,9 +533,12 @@ def answer(self, messages: list | str) -> Transform:
except Exception as e:
self.interface.send(
f"Generated invalid transform resulting in following error: {e}",
user=self.name,
user="Exception",
respond=False,
)
pipeline.transforms = pipeline.transforms[:-1]
memory.pop("current_transform")
print(f"{memory=}")
pipeline._stale = True
return pipeline

@@ -505,9 +567,7 @@ def _view_prompt(
) -> str:
doc = view.__doc__.split("\n\n")[0]
prompt = f"{doc}"
prompt += (
f"\n\nThe data follows the following JSON schema:\n\n```json\n{str(schema)}\n```"
)
prompt += f"\n\nThe data follows the following JSON schema:\n\n```json\n{str(schema)}\n```"
if "current_view" in memory:
prompt += f"The previous view specification was: {memory['current_view']}"
return prompt
@@ -519,7 +579,7 @@ def answer(self, messages: list | str) -> Transform:

# Find parameters
view = hvPlotUIView
schema = pipeline.get_schema()
schema = self._get_schema(pipeline.source, table)
excluded = view._internal_params + [
"controls",
"type",
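
A rough illustration of what the new TableListAgent adds to agents.py: it enumerates the current source's tables via get_tables(), strips quoted identifiers (the same quote-stripping TableAgent now applies before building its Literal model), and posts them as a bullet list. The helper and FakeSource below are hypothetical stand-ins, not Lumen's Agent or Source classes:

```python
# Illustrative sketch (not Lumen's Agent class) of what TableListAgent.answer
# does with the tables returned by the current source.

def list_tables(source) -> str | None:
    tables = source.get_tables()
    if not tables:
        return None
    # Strip double quotes from quoted identifiers before display
    tables = tuple(table.replace('"', "") for table in tables)
    bullets = "\n".join(f"- {table}" for table in tables)
    return f"Available tables:\n{bullets}"


class FakeSource:
    """Hypothetical stand-in exposing the get_tables() call used in the diff."""

    def get_tables(self):
        return ['"penguins"', "windturbines"]


print(list_tables(FakeSource()))
# Available tables:
# - penguins
# - windturbines
```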
6 changes: 4 additions & 2 deletions lumen/ai/assistant.py
@@ -57,8 +57,8 @@ def _invoke_on_init(self):
self.invoke('Initializing chat')

def _generate_picker_prompt(self, agents):
prompt = 'Current you have the following items in memory: {list(memory)}'
prompt += "\nSelect most relevant agent for the user's query:\n" + "\n".join(
# prompt = f'Current you have the following items in memory: {list(memory)}'
prompt = "\nSelect most relevant agent for the user's query:\n" + "\n".join(
f"- {agent.name}: {agent.__doc__.strip()}" for agent in agents
)
return prompt
@@ -99,6 +99,8 @@ def _get_agent(self, messages: list | str):
agent for agent in self.agents if any(ur in agent.provides for ur in unmet_dependencies)
]
subagent_name = self._choose_agent(messages, subagents)
if subagent_name is None:
continue
subagent = agents[subagent_name]
agent_chain.append((subagent, unmet_dependencies))
if not (unmet_dependencies:= tuple(r for r in subagent.requires if r not in memory)):
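
For context, the _get_agent change above guards against _choose_agent returning no subagent for an unmet dependency, skipping that candidate instead of failing on agents[None]. A simplified, hypothetical sketch of the resolution loop with that guard (names and signatures are stand-ins, not Lumen's actual API):

```python
# Simplified, hypothetical sketch of dependency resolution with the new
# "no subagent chosen" guard; not Lumen's actual Assistant._get_agent.

def build_agent_chain(agent, agents, memory, choose_agent, max_rounds=5):
    agent_chain = []
    unmet = tuple(r for r in agent.requires if r not in memory)
    for _ in range(max_rounds):
        if not unmet:
            break
        subagents = [
            a for a in agents.values()
            if any(dep in a.provides for dep in unmet)
        ]
        name = choose_agent(subagents)
        if name is None:
            # New guard from this commit: nothing suitable was picked,
            # so skip rather than raising KeyError on agents[None].
            continue
        subagent = agents[name]
        agent_chain.append((subagent, unmet))
        unmet = tuple(r for r in subagent.requires if r not in memory)
    return agent_chain
```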
16 changes: 6 additions & 10 deletions lumen/ai/llm.py
@@ -1,7 +1,5 @@
from functools import partial

from llama_cpp import ChatCompletionChunk

import panel as pn
import param

@@ -30,8 +28,8 @@ def __init__(self, model: str | None = None, **params):
params = dict(self._models[model], **params)
else:
raise ValueError(
"No model named {model!r} available. Known models include "
"{', '.join(LLM_METADATA)}."
f"No model named {model!r} available. Known models include "
f"{', '.join(self._models)}."
)
super().__init__(**params)
self._client = None
@@ -68,20 +66,18 @@ def invoke(
response_model = Partial[response_model]
kwargs['response_model'] = response_model

errored = False
print(messages, "\n\n")
output = None
for r in range(self.retry):
try:
output = client(messages=messages, **kwargs)
break
except Exception as e:
print(f"Error encountered: {e}")
if 'response_model' in kwargs:
errored = True
kwargs['response_model'] = Maybe(response_model)
messages = messages + [{"role": "system", "content": f"You just encountered the following error, make sure you don't repeat it: {e}" }]

if errored:
output = output.result

print(f"Invoked output: {output!r}")
return output

def stream(
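
The Llm.invoke change above simplifies the retry loop: each failure is fed back to the model as a system message so the next attempt can avoid repeating it. A minimal sketch of that pattern, where the client callable is a hypothetical stand-in for the instructor-patched completion call, not the exact Lumen signature:

```python
# Minimal sketch of retry-with-error-feedback; "client" is a hypothetical
# stand-in for an instructor-patched chat completion callable.

def invoke_with_retry(client, messages, retries=2, **kwargs):
    output = None
    for _ in range(retries):
        try:
            output = client(messages=messages, **kwargs)
            break
        except Exception as e:
            print(f"Error encountered: {e}")
            # Feed the error back so the next attempt can avoid it
            messages = messages + [{
                "role": "system",
                "content": (
                    "You just encountered the following error, "
                    f"make sure you don't repeat it: {e}"
                ),
            }]
    return output
```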
5 changes: 5 additions & 0 deletions lumen/ai/memory.py
@@ -47,8 +47,13 @@ def __setitem__(self, key, value):
def get(self, key, default=None):
return self._curcontext.get(key, default)

def pop(self, key, default=None):
return self._curcontext.pop(key, default)

def _render_item(self, key, item):
if isinstance(item, Component):
if hasattr(item, "password"):
item.password = "$variables.PASSWORD"
item = item.to_spec()
if isinstance(item, str):
item = f'```yaml\n{item}\n```'
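
The memory change above adds a dict-style pop() to the session context (used by TransformAgent to discard a failed current_transform) and masks passwords before a component is serialized. A toy sketch of the pop semantics; MiniMemory is illustrative, not Lumen's memory class:

```python
# Toy illustration of the new pop() on the memory wrapper; MiniMemory is a
# hypothetical stand-in for Lumen's session-context memory.

class MiniMemory:
    def __init__(self):
        self._curcontext = {}

    def __setitem__(self, key, value):
        self._curcontext[key] = value

    def get(self, key, default=None):
        return self._curcontext.get(key, default)

    def pop(self, key, default=None):
        return self._curcontext.pop(key, default)


memory = MiniMemory()
memory["current_transform"] = {"type": "aggregate"}
memory.pop("current_transform")         # discard the failed transform spec
print(memory.pop("current_transform"))  # safe second call: prints None
```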
5 changes: 5 additions & 0 deletions lumen/ai/models.py
@@ -21,3 +21,8 @@ class Sql(BaseModel):
class Yaml(BaseModel):

spec: str = Field(description="Lumen spec YAML to reflect user query")


class Decision(BaseModel):

required: bool = Field(description="Whether a transformation is required to achieve the user's request")
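
The new Decision model is what TransformAgent requests from the LLM before picking a transform. Below is a small self-contained example of the model itself; the commented-out invoke() call mirrors the diff but is pseudocode for the structured-output call, not an exact Lumen/instructor signature:

```python
# Self-contained illustration of the new Decision response model; the
# commented-out invoke() call is pseudocode mirroring the diff.
from pydantic import BaseModel, Field


class Decision(BaseModel):

    required: bool = Field(
        description="Whether a transformation is required to achieve the user's request"
    )


# decision = llm.invoke(
#     messages,
#     system=(
#         "Decide whether a transformation is needed to compute stats or "
#         "aggregation; if it's just getting data, return False."
#     ),
#     response_model=Decision,
#     allow_partial=False,
# )
# if decision is None or not decision.required:
#     return  # skip the transform step entirely

print(Decision(required=False))  # required=False
```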
2 changes: 1 addition & 1 deletion setup.py
@@ -23,7 +23,7 @@ def get_setup_version(reponame):
"hvplot",
"holoviews >=1.17.0",
"packaging",
"intake",
"intake <2",
"jinja2 >3.0"
]
