Skip to content

Commit

Permalink
Merge pull request #139 from vanna-ai/training-plan-base
Browse files Browse the repository at this point in the history
add generic training plan
  • Loading branch information
zainhoda authored Dec 17, 2023
2 parents 8364f3d + 9f33650 commit e6aea6b
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,65 @@ def _get_information_schema_tables(self, database: str) -> pd.DataFrame:

return df_tables

def get_training_plan_generic(self, df) -> TrainingPlan:
# For each of the following, we look at the df columns to see if there's a match:
database_column = df.columns[
df.columns.str.lower().str.contains("database")
| df.columns.str.lower().str.contains("table_catalog")
].to_list()[0]
schema_column = df.columns[
df.columns.str.lower().str.contains("table_schema")
].to_list()[0]
table_column = df.columns[
df.columns.str.lower().str.contains("table_name")
].to_list()[0]
column_column = df.columns[
df.columns.str.lower().str.contains("column_name")
].to_list()[0]
data_type_column = df.columns[
df.columns.str.lower().str.contains("data_type")
].to_list()[0]

plan = TrainingPlan([])

for database in df[database_column].unique().tolist():
for schema in (
df.query(f'{database_column} == "{database}"')[schema_column]
.unique()
.tolist()
):
for table in (
df.query(
f'{database_column} == "{database}" and {schema_column} == "{schema}"'
)[table_column]
.unique()
.tolist()
):
df_columns_filtered_to_table = df.query(
f'{database_column} == "{database}" and {schema_column} == "{schema}" and {table_column} == "{table}"'
)
doc = f"The following columns are in the {table} table in the {database} database:\n\n"
doc += df_columns_filtered_to_table[
[
database_column,
schema_column,
table_column,
column_column,
data_type_column,
]
].to_markdown()

plan._plan.append(
TrainingPlanItem(
item_type=TrainingPlanItem.ITEM_TYPE_IS,
item_group=f"{database}.{schema}",
item_name=table,
item_value=doc,
)
)

return plan

def get_training_plan_snowflake(
self,
filter_databases: Union[List[str], None] = None,
Expand Down

0 comments on commit e6aea6b

Please sign in to comment.