Skip to content

Commit

Permalink
Allow templating fields in load and save (#109)
Browse files Browse the repository at this point in the history
* Allow templating fields in load and save

Allows users to template paths and input/output tables when loading or
saving files.

* nit

* add template env var

* fix save
  • Loading branch information
dimberman authored Feb 16, 2022
1 parent 00cc5ed commit fc9d70d
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 12 deletions.
1 change: 1 addition & 0 deletions .env-template
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ SNOWFLAKE_ROLE=<add snowflake role here>
SNOWFLAKE_DATABASE=<add snowflake database here>
SNOWFLAKE_WAREHOUSE=<add snowflake warehouse here>
AIRFLOW__ASTRO__CONN_AWS_DEFAULT=<add aws api key here>:<add aws secret key here>@
AIRFLOW_VAR_FOO=templated_file_name
1 change: 1 addition & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,5 @@ jobs:
SNOWFLAKE_REGION: us-east-1
SNOWFLAKE_ROLE: TRANSFORMER
AIRFLOW__CORE__ENABLE_XCOM_PICKLING: True
AIRFLOW_VAR_FOO: templated_file_name

6 changes: 6 additions & 0 deletions src/astro/sql/operators/agnostic_load_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ class AgnosticLoadFile(BaseOperator):
:type output_conn_id: str
"""

template_fields = (
"output_table",
"file_conn_id",
"path",
)

def __init__(
self,
path,
Expand Down
19 changes: 10 additions & 9 deletions src/astro/sql/operators/agnostic_save_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,10 @@
class SaveFile(BaseOperator):
"""Write SQL table to csv/parquet on local/S3/GCS.
:param input_table: Table to convert to file
:type input_table: Table
:param output_file_path: Path and name of the file to create.
:type output_file_path: str
:param table: Input table name.
:type table: str
:param input_conn_id: Database connection id.
:type input_conn_id: str
:param output_conn_id: File system connection id (if S3 or GCS).
:type output_conn_id: str
:param overwrite: Overwrite file if exists. Default False.
Expand All @@ -51,18 +49,23 @@ class SaveFile(BaseOperator):
:type output_file_format: str
"""

template_fields = (
"input_table",
"output_file_path",
"output_conn_id",
"output_file_format",
)

def __init__(
self,
table="",
output_file_path="",
input_table: Table = None,
output_file_path="",
output_conn_id=None,
output_file_format="csv",
overwrite=None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.table = table
self.output_file_path = output_file_path
self.input_table = input_table
self.output_conn_id = output_conn_id
Expand Down Expand Up @@ -160,7 +163,6 @@ def create_table_name(context):

def save_file(
output_file_path,
table=None,
input_table=None,
output_conn_id=None,
overwrite=False,
Expand Down Expand Up @@ -193,7 +195,6 @@ def save_file(
return SaveFile(
task_id=task_id,
output_file_path=output_file_path,
table=table,
input_table=input_table,
output_conn_id=output_conn_id,
overwrite=overwrite,
Expand Down
9 changes: 9 additions & 0 deletions src/astro/sql/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@


class Table:
template_fields = (
"table_name",
"conn_id",
"database",
"schema",
"warehouse",
"role",
)

def __init__(
self,
table_name="",
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
10 changes: 7 additions & 3 deletions tests/operators/test_agnostic_load_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,9 +512,11 @@ def create_task_parameters(database_name, file_type):
sql_server_params["conn_id"] = conn_id_value

task_params = {
"path": str(CWD) + f"/../data/sample.{file_type}",
"path": str(CWD) + "/../data/{{ var.value.foo }}/sample." + file_type,
"file_conn_id": "",
"output_table": Table(table_name=OUTPUT_TABLE_NAME, **sql_server_params),
"output_table": Table(
table_name=OUTPUT_TABLE_NAME + "_{{ var.value.foo }}", **sql_server_params
),
}
return task_params

Expand All @@ -531,7 +533,9 @@ def test_load_file(sample_dag, sql_server, file_type):

test_utils.create_and_run_task(sample_dag, load_file, (), task_params)

df = sql_hook.get_pandas_df(f"SELECT * FROM {schema}.{OUTPUT_TABLE_NAME}")
df = sql_hook.get_pandas_df(
f"SELECT * FROM {schema}.{OUTPUT_TABLE_NAME}_templated_file_name"
)

assert len(df) == 3
expected = pd.DataFrame(
Expand Down

0 comments on commit fc9d70d

Please sign in to comment.