Skip to content

Commit

Permalink
Merge pull request #120 from vanna-ai/sync-functions
Browse files Browse the repository at this point in the history
Sync functions between local and remote
  • Loading branch information
zainhoda authored Sep 22, 2023
2 parents ab7c6f5 + b4efdf1 commit fe7ce17
Show file tree
Hide file tree
Showing 5 changed files with 336 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/vanna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
end
subgraph OpenAI_Chat
get_prompt
get_sql_prompt
submit_prompt
generate_question
generate_plotly_code
Expand Down
57 changes: 53 additions & 4 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import plotly.express as px
import plotly.graph_objects as go
import requests
import re

from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
from ..types import TrainingPlan, TrainingPlanItem
Expand All @@ -21,11 +22,11 @@ def __init__(self, config=None):
self.config = config
self.run_sql_is_set = False

def generate_sql_from_question(self, question: str, **kwargs) -> str:
def generate_sql(self, question: str, **kwargs) -> str:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
prompt = self.get_prompt(
prompt = self.get_sql_prompt(
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
Expand All @@ -35,6 +36,35 @@ def generate_sql_from_question(self, question: str, **kwargs) -> str:
llm_response = self.submit_prompt(prompt, **kwargs)
return llm_response

def generate_followup_questions(self, question: str, **kwargs) -> list[str]:
    """Generate followup questions a user might ask after *question*.

    Gathers similar question/SQL pairs, related DDL and related
    documentation, builds a followup-questions prompt from them, submits
    it to the LLM, and splits the response into one question per line.

    Args:
        question: The question the user originally asked.

    Returns:
        A list of followup-question strings with any leading "1. "-style
        numbering stripped.  (Bug fix: the annotation previously said
        ``str`` even though a list is returned.)
    """
    question_sql_list = self.get_similar_question_sql(question, **kwargs)
    ddl_list = self.get_related_ddl(question, **kwargs)
    doc_list = self.get_related_documentation(question, **kwargs)
    prompt = self.get_followup_questions_prompt(
        question=question,
        question_sql_list=question_sql_list,
        ddl_list=ddl_list,
        doc_list=doc_list,
        **kwargs,
    )
    llm_response = self.submit_prompt(prompt, **kwargs)

    # Strip leading "1. ", "2. " ... numbering that LLMs typically emit.
    numbers_removed = re.sub(r'^\d+\.\s*', '', llm_response, flags=re.MULTILINE)
    return numbers_removed.split("\n")

def generate_questions(self, **kwargs) -> list[str]:
    """
    **Example:**
    ```python
    vn.generate_questions()
    ```
    Generate a list of questions that you can ask Vanna.AI.
    """
    # An empty query string returns the most similar stored question/SQL
    # pairs; surface only the question text from each pair.
    stored_pairs = self.get_similar_question_sql(question="", **kwargs)
    return [pair['question'] for pair in stored_pairs]

# ----------------- Use Any Embeddings API ----------------- #
@abstractmethod
def generate_embedding(self, data: str, **kwargs) -> list[float]:
Expand Down Expand Up @@ -65,10 +95,18 @@ def add_ddl(self, ddl: str, **kwargs) -> str:
def add_documentation(self, doc: str, **kwargs) -> str:
    """Store a documentation string as training data and return its id."""
    pass

@abstractmethod
def get_training_data(self, **kwargs) -> pd.DataFrame:
    """Return all stored training data (question/SQL pairs, DDL,
    documentation) as a single DataFrame."""
    pass

@abstractmethod
def remove_training_data(self, id: str, **kwargs) -> bool:
    """Remove a single training-data entry by its id.

    Args:
        id: Identifier previously returned by ``add_question_sql``,
            ``add_ddl`` or ``add_documentation``.

    Returns:
        True if an entry was removed, False otherwise.
    """
    # Bug fix: the abstract signature was missing `self`, so concrete
    # implementations (e.g. the ChromaDB store) did not match it.
    pass

# ----------------- Use Any Language Model API ----------------- #

@abstractmethod
def get_prompt(
def get_sql_prompt(
self,
question: str,
question_sql_list: list,
Expand All @@ -78,6 +116,17 @@ def get_prompt(
):
pass

@abstractmethod
def get_followup_questions_prompt(
    self,
    question: str,
    question_sql_list: list,
    ddl_list: list,
    doc_list: list,
    **kwargs
):
    """Build the LLM prompt used to generate followup questions.

    Args:
        question: The user's original question.
        question_sql_list: Similar question/SQL pairs for context.
        ddl_list: Related DDL statements.
        doc_list: Related documentation strings.
    """
    pass

@abstractmethod
def submit_prompt(self, prompt, **kwargs) -> str:
    """Send a prepared prompt to the language model and return its
    raw text response."""
    pass
Expand Down Expand Up @@ -415,7 +464,7 @@ def ask(
question = input("Enter a question: ")

try:
sql = self.generate_sql_from_question(question=question)
sql = self.generate_sql(question=question)
except Exception as e:
print(e)
return None, None, None
Expand Down
92 changes: 86 additions & 6 deletions src/vanna/chromadb/chromadb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions
import pandas as pd

from ..base import VannaBase

Expand Down Expand Up @@ -39,32 +40,111 @@ def generate_embedding(self, data: str, **kwargs) -> list[float]:
return embedding[0]
return embedding

def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
    """Store a question/SQL training pair and return its generated id.

    The pair is serialized to JSON and embedded so it can later be
    retrieved by similarity search.

    Args:
        question: Natural-language question.
        sql: The SQL statement that answers it.

    Returns:
        The new entry's id; its "-sql" suffix marks the collection it
        belongs to (used by remove_training_data).
    """
    question_sql_json = json.dumps(
        {
            "question": question,
            "sql": sql,
        }
    )
    # Renamed from `id` to avoid shadowing the builtin.
    entry_id = str(uuid.uuid4()) + "-sql"
    self.sql_collection.add(
        documents=question_sql_json,
        embeddings=self.generate_embedding(question_sql_json),
        ids=entry_id,
    )
    return entry_id

def add_ddl(self, ddl: str, **kwargs) -> str:
    """Store a DDL statement in the DDL collection and return its id."""
    # The "-ddl" suffix encodes which collection the id belongs to,
    # so remove_training_data can route it back here.
    ddl_id = f"{uuid.uuid4()}-ddl"
    self.ddl_collection.add(
        documents=ddl,
        embeddings=self.generate_embedding(ddl),
        ids=ddl_id,
    )
    return ddl_id

def add_documentation(self, doc: str, **kwargs) -> str:
    """Store a documentation string and return its generated id."""
    # The "-doc" suffix lets remove_training_data route the id back to
    # this collection.
    doc_id = f"{uuid.uuid4()}-doc"
    self.documentation_collection.add(
        documents=doc,
        embeddings=self.generate_embedding(doc),
        ids=doc_id,
    )
    return doc_id

def get_training_data(self, **kwargs) -> pd.DataFrame:
    """Return all stored training data as a single DataFrame.

    Columns: ``id``, ``question``, ``content``, ``training_data_type``
    (one of "sql", "ddl", "documentation").  Question/SQL pairs are
    stored as JSON documents and decoded here; DDL and documentation
    entries are plain strings with no associated question.

    Improvement: the three near-identical frame-building sections were
    collapsed into one private helper.
    """

    def _frame(ids, questions, contents, data_type):
        # One row per stored document, tagged with its collection type.
        frame = pd.DataFrame({
            'id': ids,
            'question': questions,
            'content': contents,
        })
        frame["training_data_type"] = data_type
        return frame

    df = pd.DataFrame()

    sql_data = self.sql_collection.get()
    if sql_data is not None:
        # SQL entries are JSON blobs holding both question and sql.
        documents = [json.loads(doc) for doc in sql_data['documents']]
        df = pd.concat([df, _frame(
            sql_data['ids'],
            [doc['question'] for doc in documents],
            [doc['sql'] for doc in documents],
            "sql",
        )])

    ddl_data = self.ddl_collection.get()
    if ddl_data is not None:
        documents = ddl_data['documents']
        df = pd.concat([df, _frame(
            ddl_data['ids'],
            [None for _ in documents],
            list(documents),
            "ddl",
        )])

    doc_data = self.documentation_collection.get()
    if doc_data is not None:
        documents = doc_data['documents']
        df = pd.concat([df, _frame(
            doc_data['ids'],
            [None for _ in documents],
            list(documents),
            "documentation",
        )])

    return df

def remove_training_data(self, id: str, **kwargs) -> bool:
    """Delete the training entry whose id carries a known collection suffix.

    Returns True when the id matched one of the three collections and a
    delete was issued; False for an unrecognized id.
    """
    # The id suffix (set by the add_* methods) identifies the collection.
    suffix_to_collection = {
        "-sql": self.sql_collection,
        "-ddl": self.ddl_collection,
        "-doc": self.documentation_collection,
    }
    for suffix, collection in suffix_to_collection.items():
        if id.endswith(suffix):
            collection.delete(ids=id)
            return True
    return False

# Static method to extract the documents from the results of a query
@staticmethod
Expand Down
82 changes: 66 additions & 16 deletions src/vanna/openai/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import abstractmethod

import openai
import pandas as pd

from ..base import VannaBase

Expand Down Expand Up @@ -37,29 +38,56 @@ def user_message(message: str) -> dict:
def assistant_message(message: str) -> dict:
    """Wrap *message* as an OpenAI chat-completion assistant message."""
    return dict(role="assistant", content=message)

def get_prompt(
@staticmethod
def str_to_approx_token_count(string: str) -> int:
return len(string) / 4

@staticmethod
def add_ddl_to_prompt(initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000) -> str:
    """Append DDL statements to *initial_prompt*, skipping any statement
    that would push the approximate token count past *max_tokens*."""
    if not ddl_list:
        return initial_prompt

    initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

    for ddl in ddl_list:
        # Re-measure the prompt each iteration so the budget accounts
        # for statements already appended.
        prompt_tokens = OpenAI_Chat.str_to_approx_token_count(initial_prompt)
        if prompt_tokens + OpenAI_Chat.str_to_approx_token_count(ddl) < max_tokens:
            initial_prompt += f"{ddl}\n\n"

    return initial_prompt

@staticmethod
def add_documentation_to_prompt(initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000) -> str:
    """Append documentation strings to *initial_prompt*, skipping any
    entry that would push the approximate token count past *max_tokens*."""
    if not documentation_list:
        return initial_prompt

    initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

    for documentation in documentation_list:
        # Re-measure each iteration so already-appended docs count
        # against the budget.
        prompt_tokens = OpenAI_Chat.str_to_approx_token_count(initial_prompt)
        if prompt_tokens + OpenAI_Chat.str_to_approx_token_count(documentation) < max_tokens:
            initial_prompt += f"{documentation}\n\n"

    return initial_prompt

@staticmethod
def add_sql_to_prompt(initial_prompt: str, sql_list: list[str], max_tokens: int = 14000) -> str:
    """Append question/SQL example pairs to *initial_prompt*, skipping
    any pair whose SQL would push the approximate token count past
    *max_tokens*.  Each element of *sql_list* is a dict with "question"
    and "sql" keys."""
    if not sql_list:
        return initial_prompt

    initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

    for example in sql_list:
        # Only the SQL text is counted toward the budget, matching how
        # the prompt grows dominated by the statements themselves.
        prompt_tokens = OpenAI_Chat.str_to_approx_token_count(initial_prompt)
        if prompt_tokens + OpenAI_Chat.str_to_approx_token_count(example["sql"]) < max_tokens:
            initial_prompt += f"{example['question']}\n{example['sql']}\n\n"

    return initial_prompt

def get_sql_prompt(
self,
question: str,
question_sql_list: list,
ddl_list: list,
doc_list: list,
**kwargs,
) -> str:
):
initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"

if len(ddl_list) > 0:
initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000)

for ddl in ddl_list:
if len(initial_prompt) < 50000: # Add DDL if it fits
initial_prompt += f"{ddl}\n\n"

if len(doc_list) > 0:
initial_prompt += f"The following information may or may not be useful in constructing the SQL to answer the question\n"

for doc in doc_list:
if len(initial_prompt) < 60000: # Add Documentation if it fits
initial_prompt += f"{doc}\n\n"
initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000)

message_log = [OpenAI_Chat.system_message(initial_prompt)]

Expand All @@ -75,6 +103,28 @@ def get_prompt(

return message_log

def get_followup_questions_prompt(
    self,
    question: str,
    question_sql_list: list,
    ddl_list: list,
    doc_list: list,
    df: pd.DataFrame = None,
    **kwargs
):
    """Build the message log used to ask the LLM for followup questions.

    Args:
        question: The question the user originally asked.
        question_sql_list: Similar question/SQL pairs for context.
        ddl_list: Related DDL statements.
        doc_list: Related documentation strings.
        df: Unused; kept for backward compatibility. Bug fix: it was a
            required second positional parameter, but the caller in
            ``VannaBase.generate_followup_questions`` and the abstract
            base signature never supply it, so every call raised
            TypeError. It is now optional and trailing, matching the
            base-class parameter order.

    Returns:
        A list of chat messages (system + user) for submit_prompt.
    """
    initial_prompt = f"The user initially asked the question: '{question}': \n\n"

    initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000)

    initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000)

    initial_prompt = OpenAI_Chat.add_sql_to_prompt(initial_prompt, question_sql_list, max_tokens=14000)

    message_log = [OpenAI_Chat.system_message(initial_prompt)]
    message_log.append(OpenAI_Chat.user_message("Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions."))

    return message_log

def generate_question(self, sql: str, **kwargs) -> str:
response = self.submit_prompt(
[
Expand Down Expand Up @@ -150,7 +200,7 @@ def submit_prompt(self, prompt, **kwargs) -> str:
len(message["content"]) / 4
) # Use 4 as an approximation for the number of characters per token

if "engine" in self.config:
if self.config is not None and "engine" in self.config:
print(
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
)
Expand All @@ -161,7 +211,7 @@ def submit_prompt(self, prompt, **kwargs) -> str:
stop=None,
temperature=0.7,
)
elif "model" in self.config:
elif self.config is not None and "model" in self.config:
print(
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
)
Expand Down
Loading

0 comments on commit fe7ce17

Please sign in to comment.