Skip to content

Commit

Permalink
Change export_file to return File object #453
Browse files Browse the repository at this point in the history
Fix #454

To allow users to perform subsequent actions on an exported file (while maintaining a functional structure), the export_file function should return a File object. Requested by @jlaneve.
  • Loading branch information
dimberman authored and tatiana committed Jun 13, 2022
1 parent 45dfe27 commit 220f704
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/astro/sql/operators/export_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
self.if_exists = if_exists
self.kwargs = kwargs

def execute(self, context: dict) -> None:
def execute(self, context: dict) -> File:
"""Write SQL table to csv/parquet on local/S3/GCS.
Infers SQL database type based on connection.
Expand All @@ -53,6 +53,7 @@ def execute(self, context: dict) -> None:
# Write file if overwrite == True or if file doesn't exist.
if self.if_exists == "replace" or not self.output_file.exists():
self.output_file.create_from_dataframe(df)
return self.output_file
else:
raise FileExistsError(f"{self.output_file.path} file already exists.")

Expand All @@ -66,7 +67,20 @@ def export_file(
) -> XComArg:
"""Convert SaveFile into a function. Returns XComArg.
Returns an XComArg object.
Returns an XComArg object of type File which matches the output_file parameter.
This will allow users to perform further actions with the exported file.
e.g.
with sample_dag:
table = aql.load_file(input_file=File(path=data_path), output_table=test_table)
exported_file = aql.export_file(
input_data=table,
output_file=File(path="/tmp/saved_df.csv"),
if_exists="replace",
)
res_df = aql.load_file(input_file=exported_file)
:param output_file: Path and conn_id
:param input_data: Input table / dataframe
Expand Down
37 changes: 37 additions & 0 deletions tests/sql/operators/test_export_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,43 @@ def test_save_temp_table_to_local(sample_dag, sql_server, test_table):
assert input_df.equals(output_df)


@pytest.mark.parametrize("sql_server", [Database.SQLITE.value], indirect=True)
@pytest.mark.parametrize(
    "test_table",
    [
        {
            "path": str(CWD) + "/../../data/homes.csv",
            "load_table": True,
            "param": {
                "name": test_utils.get_table_name("test_stats_check_1"),
            },
        }
    ],
    indirect=True,
    ids=["temp_table"],
)
def test_save_returns_output_file(sample_dag, test_table, sql_server):
    """Check that the value returned by ``aql.export_file`` can feed ``aql.load_file``.

    The DAG loads the homes CSV into ``test_table``, exports it to a local
    CSV, passes the export task's return value straight back into
    ``aql.load_file`` and validates that the resulting dataframe is not
    empty.  After the run, the exported CSV is compared with the source CSV.
    """
    source_csv = str(CWD) + "/../../data/homes.csv"
    # Named once so the export target and the final read cannot drift apart.
    exported_csv = "/tmp/saved_df.csv"

    @aql.dataframe
    def validate(df: pd.DataFrame):
        assert not df.empty

    with sample_dag:
        loaded_table = aql.load_file(
            input_file=File(path=source_csv),
            output_table=test_table,
        )
        exported_file = aql.export_file(
            input_data=loaded_table,
            output_file=File(path=exported_csv),
            if_exists="replace",
        )
        # The point of the test: the export result is used as a load input.
        reloaded = aql.load_file(input_file=exported_file)
        validate(reloaded)
    test_utils.run_dag(sample_dag)

    expected_df = pd.read_csv(source_csv)
    actual_df = pd.read_csv(exported_csv)
    assert expected_df.equals(actual_df)


@pytest.mark.parametrize("sql_server", SUPPORTED_DATABASES, indirect=True)
@pytest.mark.parametrize(
"test_table",
Expand Down

0 comments on commit 220f704

Please sign in to comment.