diff --git a/src/astro/sql/operators/sql_decorator.py b/src/astro/sql/operators/sql_decorator.py index 48f3f1dd5..d716db518 100644 --- a/src/astro/sql/operators/sql_decorator.py +++ b/src/astro/sql/operators/sql_decorator.py @@ -68,14 +68,9 @@ def __init__( :param kwargs: """ self.raw_sql = raw_sql - self.conn_id = conn_id self.autocommit = autocommit self.parameters = parameters - self.database = database - self.schema = schema self.handler = handler - self.warehouse = warehouse - self.role = role self.kwargs = kwargs or {} self.sql = sql self.op_kwargs: Dict = self.kwargs.get("op_kwargs") or {} @@ -84,6 +79,12 @@ def __init__( else: self.output_table = None + self.database = self.op_kwargs.pop("database", database) + self.conn_id = self.op_kwargs.pop("conn_id", conn_id) + self.schema = self.op_kwargs.pop("schema", schema) + self.warehouse = self.op_kwargs.pop("warehouse", warehouse) + self.role = self.op_kwargs.pop("role", role) + super().__init__( **kwargs, ) @@ -221,9 +222,10 @@ def _set_variables_from_first_table(self): # If there is no first table via op_ags or kwargs, we check the parameters if not first_table: - param_tables = [t for t in self.parameters.values() if type(t) == Table] - if param_tables: - first_table = param_tables[0] + if self.parameters: + param_tables = [t for t in self.parameters.values() if type(t) == Table] + if param_tables: + first_table = param_tables[0] if first_table: self.conn_id = first_table.conn_id or self.conn_id diff --git a/tests/operators/test_postgres_decorator.py b/tests/operators/test_postgres_decorator.py index 451944732..73ba3f4c0 100644 --- a/tests/operators/test_postgres_decorator.py +++ b/tests/operators/test_postgres_decorator.py @@ -195,6 +195,29 @@ def sample_pg(input_table: Table): ) assert df.iloc[0].to_dict()["colors"] == "red" + def test_postgres_set_op_kwargs(self): + self.hook_target = PostgresHook( + postgres_conn_id="postgres_conn", schema="pagila" + ) + + @aql.transform + def sample_pg(): + return "SELECT * FROM actor WHERE last_name LIKE 'G%%'" + + self.create_and_run_task( + sample_pg, + (), + { + "conn_id": "postgres_conn", + "database": "pagila", + }, + ) + df = pd.read_sql( + f"SELECT * FROM tmp_astro.test_dag_sample_pg_1", + con=self.hook_target.get_conn(), + ) + assert df.iloc[0].to_dict()["first_name"] == "PENELOPE" + def test_postgres(self): self.hook_target = PostgresHook( postgres_conn_id="postgres_conn", schema="pagila"