Skip to content

Commit

Permalink
Bugfix: Pass operator kwargs to dataframe decorator (#632)
Browse files Browse the repository at this point in the history
closes #630

(cherry picked from commit 2b9a3dd)
  • Loading branch information
kaxil committed Aug 16, 2022
1 parent be6280d commit ab220e4
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/astro/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,27 +223,27 @@ def dataframe(
conn_id: str = "",
database: Optional[str] = None,
schema: Optional[str] = None,
task_id: Optional[str] = None,
identifiers_as_lower: Optional[bool] = True,
**kwargs: Any,
) -> Callable[..., pd.DataFrame]:
"""
This decorator will allow users to write python functions while treating SQL tables as dataframes
This decorator allows a user to run python functions in Airflow but with the huge benefit that SQL tables
will automatically be turned into dataframes and resulting dataframes can automatically used in astro.sql functions
"""
param_map = {
"conn_id": conn_id,
"database": database,
"schema": schema,
"identifiers_as_lower": identifiers_as_lower,
}
if task_id:
param_map["task_id"] = task_id
kwargs.update(
{
"conn_id": conn_id,
"database": database,
"schema": schema,
"identifiers_as_lower": identifiers_as_lower,
}
)
decorated_function: Callable[..., pd.DataFrame] = task_decorator_factory(
python_callable=python_callable,
multiple_outputs=multiple_outputs,
decorated_operator_class=DataframeOperator, # type: ignore
**param_map,
**kwargs,
)
return decorated_function
17 changes: 17 additions & 0 deletions tests/sql/operators/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,20 @@ def sample_pg(input_table: Table): # skipcq: PY-D0003
)
validate_result(pg_df)
test_utils.run_dag(sample_dag)


@pytest.mark.parametrize(
"kwargs",
[{"task_id": "task1", "queue": "new_1"}, {"queue": "new_2", "owner": "astro-sdk"}],
)
def test_pass_kwargs_to_base_operator(kwargs):
"""Test that kwargs passed to decorator are passed to BaseOperator"""

@aql.dataframe(**kwargs)
def sample_df_1(): # skipcq: PY-D0003
return pandas.DataFrame(
{"numbers": [1, 2, 3], "colors": ["red", "white", "blue"]}
)

task1 = sample_df_1()
assert all(getattr(task1.operator, k) == v for k, v in kwargs.items())

0 comments on commit ab220e4

Please sign in to comment.