From 16a50fd40e6a778992a2549b0fca035236a43a9c Mon Sep 17 00:00:00 2001 From: Daniel Imberman Date: Wed, 16 Feb 2022 11:51:24 -0800 Subject: [PATCH] Don't generate table names with "." (#110) * Don't generate table names with "." Since snowflake is the only DB that can handle table names with periods, we should ensure taht tables we generate don't have periods. * fix tests * fix final test --- src/astro/sql/operators/sql_dataframe.py | 8 +++++-- src/astro/sql/table.py | 4 +++- tests/operators/test_dataframe.py | 24 +++++++++++++++---- .../postgres_simple_tasks/test_astro.sql | 5 ++++ .../test_inheritance.sql | 1 + tests/parsers/test_sql_directory_parser.py | 19 +++++++++++++++ 6 files changed, 54 insertions(+), 7 deletions(-) create mode 100644 tests/parsers/postgres_simple_tasks/test_astro.sql create mode 100644 tests/parsers/postgres_simple_tasks/test_inheritance.sql diff --git a/src/astro/sql/operators/sql_dataframe.py b/src/astro/sql/operators/sql_dataframe.py index 3038d2c8e..571a3a1f6 100644 --- a/src/astro/sql/operators/sql_dataframe.py +++ b/src/astro/sql/operators/sql_dataframe.py @@ -140,9 +140,13 @@ def _get_dataframe(self, table: Table): self.hook = PostgresHook( postgres_conn_id=table.conn_id, schema=table.database ) + schema = table.schema or get_schema() query = ( - sql.SQL("SELECT * FROM {input_table}") - .format(input_table=sql.Identifier(table.table_name)) + sql.SQL("SELECT * FROM {schema}.{input_table}") + .format( + schema=sql.Identifier(schema), + input_table=sql.Identifier(table.table_name), + ) .as_string(self.hook.get_conn()) ) return self.hook.get_pandas_df(query) diff --git a/src/astro/sql/table.py b/src/astro/sql/table.py index 95dccbf7a..eb7d97e9e 100644 --- a/src/astro/sql/table.py +++ b/src/astro/sql/table.py @@ -71,7 +71,9 @@ def to_table(self, table_name: str, schema: str) -> Table: def create_table_name(context): ti: TaskInstance = context["ti"] dag_run: DagRun = ti.get_dagrun() - table_name = f"{dag_run.dag_id}_{ti.task_id}_{dag_run.id}".replace("-", "_") + table_name = f"{dag_run.dag_id}_{ti.task_id}_{dag_run.id}".replace( + "-", "_" + ).replace(".", "__") if not table_name.isidentifier(): table_name = f'"{table_name}"' return table_name diff --git a/tests/operators/test_dataframe.py b/tests/operators/test_dataframe.py index 431d93192..a969abb97 100644 --- a/tests/operators/test_dataframe.py +++ b/tests/operators/test_dataframe.py @@ -117,7 +117,11 @@ def my_df_func(df: pandas.DataFrame): res = self.create_and_run_task( my_df_func, (), - {"df": Table("actor", conn_id="postgres_conn", database="pagila")}, + { + "df": Table( + "actor", conn_id="postgres_conn", database="pagila", schema="public" + ) + }, ) assert ( XCom.get_one( @@ -133,7 +137,11 @@ def my_df_func(df: pandas.DataFrame): res = self.create_and_run_task( my_df_func, - (Table("actor", conn_id="postgres_conn", database="pagila"),), + ( + Table( + "actor", conn_id="postgres_conn", database="pagila", schema="public" + ), + ), {}, ) assert ( @@ -150,8 +158,16 @@ def my_df_func(actor_df: pandas.DataFrame, film_df: pandas.DataFrame): res = self.create_and_run_task( my_df_func, - (Table("actor", conn_id="postgres_conn", database="pagila"),), - {"film_df": Table("film", conn_id="postgres_conn", database="pagila")}, + ( + Table( + "actor", conn_id="postgres_conn", database="pagila", schema="public" + ), + ), + { + "film_df": Table( + "film", conn_id="postgres_conn", database="pagila", schema="public" + ) + }, ) assert ( XCom.get_one( diff --git a/tests/parsers/postgres_simple_tasks/test_astro.sql b/tests/parsers/postgres_simple_tasks/test_astro.sql new file mode 100644 index 000000000..b30cb1d84 --- /dev/null +++ b/tests/parsers/postgres_simple_tasks/test_astro.sql @@ -0,0 +1,5 @@ +--- +conn_id: postgres_conn +database: pagila +--- +SELECT * FROM actor \ No newline at end of file diff --git a/tests/parsers/postgres_simple_tasks/test_inheritance.sql b/tests/parsers/postgres_simple_tasks/test_inheritance.sql new file mode 100644 index 000000000..e6e08704f --- /dev/null +++ b/tests/parsers/postgres_simple_tasks/test_inheritance.sql @@ -0,0 +1 @@ +SELECT * FROM {{test_astro}} LIMIT 10 \ No newline at end of file diff --git a/tests/parsers/test_sql_directory_parser.py b/tests/parsers/test_sql_directory_parser.py index 13fac4da1..db6052e85 100644 --- a/tests/parsers/test_sql_directory_parser.py +++ b/tests/parsers/test_sql_directory_parser.py @@ -110,3 +110,22 @@ def test_parse_creates_xcom(self): rendered_tasks = aql.render(dir_path + "/single_task_dag") test_utils.run_dag(self.dag) + + def test_parse_to_dataframe(self): + """ + Runs two tasks with a direct dependency, the DAG will fail if task two can not inherit the table produced by task 1 + :return: + """ + import pandas as pd + + from astro.dataframe import dataframe as adf + + @adf + def dataframe_func(df: pd.DataFrame): + print(df.to_string) + + with self.dag: + rendered_tasks = aql.render(dir_path + "/postgres_simple_tasks") + dataframe_func(rendered_tasks["test_inheritance"]) + + test_utils.run_dag(self.dag)