Commit

Add lumen ai enhancements (#531)
ahuang11 authored Apr 5, 2024
1 parent 25d62e2 commit af922f2
Showing 6 changed files with 186 additions and 88 deletions.
122 changes: 75 additions & 47 deletions lumen/ai/agents.py
@@ -17,6 +17,7 @@
from ..pipeline import Pipeline
from ..sources import FileSource, InMemorySource, Source
from ..transforms.sql import SQLTransform, Transform
+from ..validation import ValidationError
from ..views import hvPlotUIView
from .embeddings import Embeddings
from .llm import Llm
@@ -32,7 +33,7 @@ class Agent(Viewer):
embeddings.
"""

-    debug = param.Boolean(default=True)
+    debug = param.Boolean(default=False)

embeddings = param.ClassSelector(class_=Embeddings)

@@ -55,9 +56,18 @@ class Agent(Viewer):
__abstract = True

def __init__(self, **params):
+        def _exception_handler(exception):
+            self.interface.send(
+                f"Sorry I'm unable to handle this: {exception!r}",
+                user="System",
+                respond=False,
+            )
+
if "interface" not in params:
params["interface"] = ChatInterface(callback=self._chat_invoke)
super().__init__(**params)
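        # with debug disabled, uncaught exceptions anywhere in the Panel app are
        # rerouted into the chat by the handler above; pn.config.exception_handler
        # is global Panel state, so this affects the whole application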
+        if not self.debug:
+            pn.config.exception_handler = _exception_handler

def _link_code_editor(self, value, callback, language):
code_editor = pn.widgets.CodeEditor(
@@ -100,6 +110,7 @@ def _link_code_editor(self, value, callback, language):
return tabs

def _chat_invoke(self, contents: list | str, user: str, instance: ChatInterface):
+        print("-" * 50)
return self.invoke(contents)

def __panel__(self):
@@ -216,15 +227,41 @@ def enable_add(event):

class ChatAgent(Agent):
"""
-    The ChatAgent is a general chat agent unrelated to other roles.
+    Responsible for chatting about high level or simple data exploration and suggesting
+    ways to get started with data exploration.
"""

-    system_prompt = param.String(default="Be a helpful chatbot.")
+    system_prompt = param.String(
+        default=(
+            "Be a helpful chatbot to talk about high-level data exploration "
+            "like their types or suggestions to get started."
+        )
+    )

+    response_model = param.ClassSelector(
+        default=String, class_=BaseModel, is_instance=False
+    )

requires = param.List(default=["current_source"], readonly=True)

+    def _system_prompt_with_context(
+        self, messages: list | str, context: str = ""
+    ) -> str:
+        source = memory.get("current_source")
+        tables = source.get_tables() if source else []
+        if len(tables) > 1:
+            context = f"Available tables: {', '.join(tables)}"
+        else:
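            # assumes `tables` is non-empty on this branch (a single-table
            # source); otherwise tables[0] below raises an IndexError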
memory["current_table"] = table = memory.get("current_table", tables[0])
schema = self._get_schema(memory["current_source"], table)
if schema:
context = f"{table} with schema: {schema}"

system_prompt = self.system_prompt
if context:
system_prompt += f"{system_prompt}\n### CONTEXT: {context}".strip()
return system_prompt


class LumenBaseAgent(Agent):

@@ -242,20 +279,25 @@ async def _render_component(spec, active):
            # store the spec in memory rather than the chat log to save tokens
memory["current_spec"] = spec
try:
-                yield type(component).from_spec(load_yaml(spec))
+                yield type(component).from_spec(load_yaml(spec)).__panel__()
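                # presumably yields the rendered Panel object so spec edits
                # re-render in place instead of showing the raw component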
except Exception as e:
yield pn.pane.Alert(
f"Error rendering component: {e}", alert_type="danger"
f"Error rendering component: {e}. Please undo or continue the conversation.",
alert_type="danger",
)
-        # maybe offer undo

# layout widgets
-        spec = yaml.safe_dump(component.to_spec())

+        component_spec = component.to_spec()
+        spec = yaml.safe_dump(component_spec)
tabs = self._link_code_editor(spec, _render_component, "yaml")
message_kwargs = dict(value=tabs, user=self.user)
if message:
self.interface.stream(message=message, **message_kwargs)
else:
self.interface.send(respond=False, **message_kwargs)
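        # presumably focuses the rendered-output tab once the tabs land in the chat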
+        tabs.active = 1


class TableAgent(LumenBaseAgent):
@@ -306,36 +348,9 @@ def invoke(self, messages: list | str):
self._render_lumen(pipeline)


-class TableListAgent(LumenBaseAgent):
-    """
-    Responsible for listing the available tables if the user's request is vague;
-    do not use this if the user wants a specific table.
-    """
-
-    system_prompt = param.String(
-        default="You are an agent responsible for listing the available tables in the current source."
-    )
-
-    requires = param.List(default=["current_source"], readonly=True)
-
-    def answer(self, messages: list | str):
-        tables = memory["current_source"].get_tables()
-        if not tables:
-            return
-        tables = tuple(table.replace('"', "") for table in tables)
-        table_bullets = "\n".join(f"- {table}" for table in tables)
-        self.interface.send(
-            f"Available tables:\n{table_bullets}", user=self.name, respond=False
-        )
-        return tables
-
-    def invoke(self, messages: list | str):
-        self.answer(messages)


class SQLAgent(LumenBaseAgent):
"""
-    Responsible for generating and modifying SQL queries to answer user questions about statistics.
+    Responsible for generating and modifying SQL queries to answer user questions about statistics like min/max/average.
"""

system_prompt = param.String(
@@ -356,18 +371,20 @@ async def _render_sql_result(query, active):
table = memory["current_table"]
source.tables[table] = query
try:
memory["current_pipeline"] = pipeline = Pipeline(
source=source, table=table
)
memory["current_pipeline"] = Pipeline(source=source, table=table)
# Need this separate or else create random memory entry
pipeline = memory["current_pipeline"].__panel__()
yield pipeline
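                # switch to the result tab only once the query has run successfully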
+                tabs.active = 1
except Exception as e:
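                # drop the half-built pipeline so downstream agents don't consume it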
memory.pop("current_pipeline")
yield pn.pane.Alert(
f"Error executing SQL query: {e}", alert_type="danger"
f"Error executing SQL query: {e}; please undo or continue the conversation",
alert_type="danger",
)

tabs = self._link_code_editor(query, _render_sql_result, "sql")
self.interface.stream(tabs, user="SQL", replace=True)
tabs.active = 1

def _sql_prompt(self, sql: str, table: str, schema: dict) -> str:
prompt = f"The SQL expression for the table {table!r} is {sql!r}."
@@ -485,6 +502,7 @@ def _find_transform(
messages,
system=f"{system_prompt}\n{picker_prompt}",
response_model=transform_model,
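            # the transform-picking call appears pinned to GPT-4, presumably
            # for more reliable structured output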
model="gpt-4",
)

if transform.transform_required:
@@ -551,12 +569,7 @@ def answer(self, messages: list | str) -> Transform:
pipeline.add_transform(transform)
try:
pipeline._update_data(force=True)
-        except Exception as e:
-            self.interface.send(
-                f"Generated invalid transform resulting in following error: {e}",
-                user="Exception",
-                respond=False,
-            )
+        except Exception:
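            # roll back the failed transform and clear it from shared memory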
pipeline.transforms = pipeline.transforms[:-1]
memory.pop("current_transform")
print(f"{memory=}")
@@ -576,7 +589,7 @@ class hvPlotAgent(LumenBaseAgent):
"""

system_prompt = param.String(
default="Generate the plot the user requested. Note that x, y, by and groupby arguments may not reference the same columns."
default="Generate the plot the user requested. Note that x, y, by and groupby arguments may not reference the same columns. Be sure to add `download: csv`"
)

requires = param.List(default=["current_pipeline"], readonly=True)
@@ -593,7 +606,7 @@ def _view_prompt(
prompt += f"The previous view specification was: {memory['current_view']}"
return prompt

-    def answer(self, messages: list | str) -> Transform:
+    def answer(self, messages: list | str, retry: bool = True) -> Transform:
pipeline = memory["current_pipeline"]
table = memory["current_table"]
system_prompt = self._system_prompt_with_context(messages)
@@ -629,6 +642,21 @@ def answer(self, messages: list | str) -> Transform:
# Instantiate
spec = dict(kwargs)
spec["responsive"] = True

+        try:
+            spec["type"] = "hvplot_ui"
+            view.validate(spec)
+        except ValidationError as e:
+            if retry:
+                new_messages = [
+                    {"role": "assistant", "content": f"{spec=}"},
+                    {"role": "user", "content": f"\nThat didn't work; please note: {e}"}
+                ]
+                messages.extend(new_messages)
+                print(f"RETRYING...\n\n\n{messages}")
+                return self.answer(messages, retry=False)

spec.pop("type", None)
memory["current_view"] = dict(spec, type=view.view_type)
if self.debug:
print(f"{self.name} settled on {spec=!r}.")
(Diffs for the remaining 5 changed files did not load.)