From 0781dde1f1958ceacf9f3c5bd725893e59c5e263 Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Fri, 4 Aug 2023 14:05:06 -0400 Subject: [PATCH 1/2] generic training plan --- src/vanna/__init__.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py index c4d52b99..2301c507 100644 --- a/src/vanna/__init__.py +++ b/src/vanna/__init__.py @@ -639,6 +639,31 @@ def get_training_plan_postgres(filter_databases: Union[List[str], None] = None, return plan +def get_training_plan_generic(df) -> TrainingPlan: + # For each of the following, we look at the df columns to see if there's a match: + database_column = df.columns[df.columns.str.lower().str.contains("database") | df.columns.str.lower().str.contains("table_catalog")].to_list()[0] + schema_column = df.columns[df.columns.str.lower().str.contains("table_schema")].to_list()[0] + table_column = df.columns[df.columns.str.lower().str.contains("table_name")].to_list()[0] + column_column = df.columns[df.columns.str.lower().str.contains("column_name")].to_list()[0] + data_type_column = df.columns[df.columns.str.lower().str.contains("data_type")].to_list()[0] + + plan = TrainingPlan([]) + + for database in df[database_column].unique().tolist(): + for schema in df.query(f'{database_column} == "{database}"')[schema_column].unique().tolist(): + for table in df.query(f'{database_column} == "{database}" and {schema_column} == "{schema}"')[table_column].unique().tolist(): + df_columns_filtered_to_table = df.query(f'{database_column} == "{database}" and {schema_column} == "{schema}" and {table_column} == "{table}"') + doc = f"The following columns are in the {table} table in the {database} database:\n\n" + doc += df_columns_filtered_to_table[[database_column, schema_column, table_column, column_column, data_type_column]].to_markdown() + + plan._plan.append(TrainingPlanItem( + item_type=TrainingPlanItem.ITEM_TYPE_IS, + item_group=f"{database}.{schema}", + item_name=table, + item_value=doc + )) + + return plan def get_training_plan_experimental(filter_databases: Union[List[str], None] = None, filter_schemas: Union[List[str], None] = None, include_information_schema: bool = False, use_historical_queries: bool = True) -> TrainingPlan: """ From 00fcfed036948592179c1277c215e6898f9f5089 Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Fri, 4 Aug 2023 14:07:32 -0400 Subject: [PATCH 2/2] bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ea4e4b76..0a55a15c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "vanna" -version = "0.0.20" +version = "0.0.21" authors = [ { name="Zain Hoda", email="zain@vanna.ai" }, ]