Skip to content

Commit

Permalink
Merge pull request #86 from vanna-ai/plan-generic
Browse files Browse the repository at this point in the history
Generic training plan from information schema (experimental)
  • Loading branch information
zainhoda authored Aug 4, 2023
2 parents fdeef09 + 00fcfed commit b4c2410
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "vanna"
version = "0.0.20"
version = "0.0.21"
authors = [
{ name="Zain Hoda", email="[email protected]" },
]
Expand Down
25 changes: 25 additions & 0 deletions src/vanna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,31 @@ def get_training_plan_postgres(filter_databases: Union[List[str], None] = None,

return plan

def get_training_plan_generic(df) -> TrainingPlan:
# For each of the following, we look at the df columns to see if there's a match:
database_column = df.columns[df.columns.str.lower().str.contains("database") | df.columns.str.lower().str.contains("table_catalog")].to_list()[0]
schema_column = df.columns[df.columns.str.lower().str.contains("table_schema")].to_list()[0]
table_column = df.columns[df.columns.str.lower().str.contains("table_name")].to_list()[0]
column_column = df.columns[df.columns.str.lower().str.contains("column_name")].to_list()[0]
data_type_column = df.columns[df.columns.str.lower().str.contains("data_type")].to_list()[0]

plan = TrainingPlan([])

for database in df[database_column].unique().tolist():
for schema in df.query(f'{database_column} == "{database}"')[schema_column].unique().tolist():
for table in df.query(f'{database_column} == "{database}" and {schema_column} == "{schema}"')[table_column].unique().tolist():
df_columns_filtered_to_table = df.query(f'{database_column} == "{database}" and {schema_column} == "{schema}" and {table_column} == "{table}"')
doc = f"The following columns are in the {table} table in the {database} database:\n\n"
doc += df_columns_filtered_to_table[[database_column, schema_column, table_column, column_column, data_type_column]].to_markdown()

plan._plan.append(TrainingPlanItem(
item_type=TrainingPlanItem.ITEM_TYPE_IS,
item_group=f"{database}.{schema}",
item_name=table,
item_value=doc
))

return plan

def get_training_plan_experimental(filter_databases: Union[List[str], None] = None, filter_schemas: Union[List[str], None] = None, include_information_schema: bool = False, use_historical_queries: bool = True) -> TrainingPlan:
"""
Expand Down

0 comments on commit b4c2410

Please sign in to comment.