From 472165af2e41b5a1b2d51ffda9f65518bb765d24 Mon Sep 17 00:00:00 2001 From: Daniel Imberman Date: Wed, 16 Feb 2022 11:32:18 -0800 Subject: [PATCH] Allow passing db context via op_kwargs (#106) * Allow passing db context via op_kwargs For queries where users don't want to pass a table, object, this feature will allow users to define context at runtime using op_kwargs. example: ```python @aql.transform def test_astro(): return "SELECT * FROM actor" with dag: actor_table = test_astro(database="pagile", conn_id="my_postgres_conn") ``` * simplify * fix test * fix test --- src/astro/sql/operators/sql_decorator.py | 18 +++++++++-------- tests/operators/test_postgres_decorator.py | 23 ++++++++++++++++++++++ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/astro/sql/operators/sql_decorator.py b/src/astro/sql/operators/sql_decorator.py index 48f3f1dd5a..d716db5187 100644 --- a/src/astro/sql/operators/sql_decorator.py +++ b/src/astro/sql/operators/sql_decorator.py @@ -68,14 +68,9 @@ def __init__( :param kwargs: """ self.raw_sql = raw_sql - self.conn_id = conn_id self.autocommit = autocommit self.parameters = parameters - self.database = database - self.schema = schema self.handler = handler - self.warehouse = warehouse - self.role = role self.kwargs = kwargs or {} self.sql = sql self.op_kwargs: Dict = self.kwargs.get("op_kwargs") or {} @@ -84,6 +79,12 @@ def __init__( else: self.output_table = None + self.database = self.op_kwargs.pop("database", database) + self.conn_id = self.op_kwargs.pop("conn_id", conn_id) + self.schema = self.op_kwargs.pop("schema", schema) + self.warehouse = self.op_kwargs.pop("warehouse", warehouse) + self.role = self.op_kwargs.pop("role", role) + super().__init__( **kwargs, ) @@ -221,9 +222,10 @@ def _set_variables_from_first_table(self): # If there is no first table via op_ags or kwargs, we check the parameters if not first_table: - param_tables = [t for t in self.parameters.values() if type(t) == Table] - if param_tables: - first_table = param_tables[0] + if self.parameters: + param_tables = [t for t in self.parameters.values() if type(t) == Table] + if param_tables: + first_table = param_tables[0] if first_table: self.conn_id = first_table.conn_id or self.conn_id diff --git a/tests/operators/test_postgres_decorator.py b/tests/operators/test_postgres_decorator.py index 4519447327..73ba3f4c07 100644 --- a/tests/operators/test_postgres_decorator.py +++ b/tests/operators/test_postgres_decorator.py @@ -195,6 +195,29 @@ def sample_pg(input_table: Table): ) assert df.iloc[0].to_dict()["colors"] == "red" + def test_postgres_set_op_kwargs(self): + self.hook_target = PostgresHook( + postgres_conn_id="postgres_conn", schema="pagila" + ) + + @aql.transform + def sample_pg(): + return "SELECT * FROM actor WHERE last_name LIKE 'G%%'" + + self.create_and_run_task( + sample_pg, + (), + { + "conn_id": "postgres_conn", + "database": "pagila", + }, + ) + df = pd.read_sql( + f"SELECT * FROM tmp_astro.test_dag_sample_pg_1", + con=self.hook_target.get_conn(), + ) + assert df.iloc[0].to_dict()["first_name"] == "PENELOPE" + def test_postgres(self): self.hook_target = PostgresHook( postgres_conn_id="postgres_conn", schema="pagila"