Release 1.5.3 (#1835)
Bug fix:
- Support using SQL operators (`run_raw_sql`, `transform`, `dataframe`)
to convert a Pandas dataframe into a table when using a DuckDB in-memory
database. [#1831](https://github.com/astronomer/astro-sdk/pull/1833)
tatiana authored Mar 8, 2023
1 parent a4abd1c commit fe4e83d
Showing 9 changed files with 136 additions and 64 deletions.
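For context, the bug fixed here (astronomer/astro-sdk#1831) stems from how SQLAlchemy connections behave against an in-memory DuckDB: a table loaded through one connection is invisible to the next one. A minimal sketch of the failure mode, assuming the `duckdb-engine` dialect is installed:

```python
# Sketch of the bug this release fixes (assumes duckdb-engine): every new
# SQLAlchemy connection to an in-memory DuckDB URI opens its own fresh,
# empty database.
import pandas as pd
from sqlalchemy import create_engine, text

engine = create_engine("duckdb:///:memory:")

df = pd.DataFrame({"a": [1, 2, 3]})
df.to_sql("tmp_table", con=engine.connect(), index=False)  # written via connection #1

with engine.connect() as conn:  # connection #2: a different in-memory database
    conn.execute(text("SELECT * FROM tmp_table"))  # fails: table does not exist
```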
6 changes: 6 additions & 0 deletions .github/ci-test-connections.yaml
@@ -132,6 +132,12 @@ connections:
schema:
login:
password:
+  - conn_id: duckdb_memory
+    conn_type: duckdb
+    host:
+    schema:
+    login:
+    password:
- conn_id: minio_conn
conn_type: aws
description: null
7 changes: 7 additions & 0 deletions python-sdk/docs/CHANGELOG.md
@@ -1,5 +1,12 @@
# Changelog

+## 1.5.3
+
+### Bug Fixes
+
+- Support using SQL operators (`run_raw_sql`, `transform`, `dataframe`) to convert a Pandas dataframe into a table when using a DuckDB in-memory database. [#1831](https://github.com/astronomer/astro-sdk/pull/1833)


## 1.5.2

### Improvements
44 changes: 44 additions & 0 deletions python-sdk/example_dags/example_duckdb.py
@@ -0,0 +1,44 @@
"""
pipeline_2
DAG auto-generated by Astro Cloud IDE.
"""

import pandas as pd
import pendulum
from airflow.decorators import dag

from astro import sql as aql
from astro.table import Table


@aql.dataframe(task_id="python_1")
def python_1_func():
return pd.DataFrame({"a": [1, 2, 3]})


@aql.run_raw_sql(
conn_id="duckdb_memory",
task_id="sql_duck",
handler=lambda x: pd.DataFrame(x.fetchall(), columns=x.keys()),
)
def sql_duck_func(python_1: Table):
return """
SELECT * FROM {{python_1}}
"""


@dag(
schedule_interval="0 0 * * * *",
start_date=pendulum.from_format("2023-02-23", "YYYY-MM-DD").in_tz("UTC"),
)
def pipeline_2():
python_1 = python_1_func()

sql_duck = sql_duck_func(
python_1,
)

sql_duck << python_1 # skipcq: PYL-W0104


dag_obj = pipeline_2()
2 changes: 1 addition & 1 deletion python-sdk/src/astro/__init__.py
@@ -1,6 +1,6 @@
"""A decorator that allows users to run SQL queries natively in Airflow."""

__version__ = "1.5.2"
__version__ = "1.5.3"


# This is needed to allow Airflow to pick up specific metadata fields it needs
3 changes: 2 additions & 1 deletion python-sdk/src/astro/databases/base.py
@@ -73,6 +73,7 @@ def __init__(
self.conn_id = conn_id
self.sql: str | ClauseElement = ""
self.load_options = load_options
+        self.table = table

def __repr__(self):
return f'{self.__class__.__name__}(conn_id="{self.conn_id})'
@@ -613,7 +614,7 @@ def load_pandas_dataframe_to_table(

source_dataframe.to_sql(
self.get_table_qualified_name(target_table),
-            con=self.sqlalchemy_engine,
+            con=self.connection,
if_exists=if_exists,
chunksize=chunk_size,
method="multi",
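The `con=` swap above matters because `pandas.DataFrame.to_sql` accepts either an `Engine` or a `Connection`: given the engine, pandas checks out its own short-lived connection, which for an in-memory DuckDB is a different database from the connection the SDK now caches. A hedged sketch of the fixed path, again assuming `duckdb-engine`:

```python
import pandas as pd
import sqlalchemy

engine = sqlalchemy.create_engine("duckdb:///:memory:")  # assumes duckdb-engine
conn = engine.connect()  # the single long-lived connection the SDK caches

# Loading through the shared connection keeps the table visible to every
# later query issued on that same connection:
pd.DataFrame({"a": [1]}).to_sql("t", con=conn, index=False)
assert conn.execute(sqlalchemy.text("SELECT count(*) FROM t")).scalar() == 1
```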
9 changes: 9 additions & 0 deletions python-sdk/src/astro/databases/duckdb.py
@@ -2,6 +2,7 @@

import socket

+import sqlalchemy
from duckdb_provider.hooks.duckdb_hook import DuckDBHook
from sqlalchemy import MetaData as SqlaMetaData
from sqlalchemy.sql.schema import Table as SqlaTable
@@ -32,6 +33,14 @@ def __init__(
def sql_type(self) -> str:
return "duckdb"

+    # We are caching this property to persist the DuckDB in-memory connection, to avoid
+    # the problem described in
+    # https://github.com/astronomer/astro-sdk/issues/1831
+    @cached_property
+    def connection(self) -> sqlalchemy.engine.base.Connection:  # skipcq PYL-W0236
+        """Return a Sqlalchemy connection object for the given database."""
+        return self.sqlalchemy_engine.connect()

@cached_property
def hook(self) -> DuckDBHook:
"""Retrieve Airflow hook to interface with the DuckDB database."""
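With `cached_property`, the first access to `.connection` opens the connection and every later access returns that same object, so the in-memory database survives across queries. A standalone sketch of the pattern (a toy class, not the SDK class itself):

```python
from functools import cached_property

import sqlalchemy


class InMemoryDuckDB:
    """Toy stand-in for the SDK class, to show the caching behavior."""

    @cached_property
    def sqlalchemy_engine(self) -> sqlalchemy.engine.Engine:
        # Assumes the duckdb-engine dialect is installed.
        return sqlalchemy.create_engine("duckdb:///:memory:")

    @cached_property
    def connection(self) -> sqlalchemy.engine.base.Connection:
        # Evaluated once, then stored on the instance: all callers share one
        # connection, and therefore one in-memory database.
        return self.sqlalchemy_engine.connect()


db = InMemoryDuckDB()
assert db.connection is db.connection  # the same connection every time
```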
108 changes: 56 additions & 52 deletions python-sdk/src/astro/sql/operators/base_decorator.py
@@ -16,6 +16,7 @@
from astro.databases.base import BaseDatabase
from astro.sql.operators.upstream_task_mixin import UpstreamTaskMixin
from astro.table import BaseTable, Table
+from astro.utils.compat.functools import cached_property
from astro.utils.compat.typing import Context
from astro.utils.table import find_first_table

@@ -26,8 +27,6 @@ class BaseSQLDecoratedOperator(UpstreamTaskMixin, DecoratedOperator):
template_fields: Sequence[str] = ("sql", "parameters", "op_args", "op_kwargs")
template_ext: Sequence[str] = (".sql",)

-    database_impl: BaseDatabase

def __init__(
self,
conn_id: str | None = None,
@@ -97,6 +96,10 @@ def _resolve_xcom_op_args(self, context: Context) -> None:
args.append(item)
self.op_args = args # type: ignore

+    @cached_property
+    def database_impl(self) -> BaseDatabase:
+        return create_database(self.conn_id)

def _enrich_context(self, context: Context) -> Context:
"""
Prepare the sql and context for execution.
@@ -136,14 +139,14 @@ def _enrich_context(self, context: Context) -> Context:
):
raise ValueError("source and target table must belong to the same datasource")

-        self.database_impl = create_database(self.conn_id, first_table)
+        self.database_impl.table = first_table

# Find and load dataframes from op_arg and op_kwarg into Table
self.create_output_table_if_needed()
-        self.op_args = load_op_arg_dataframes_into_sql(  # type: ignore
+        self.op_args = self.load_op_arg_dataframes_into_sql(  # type: ignore
conn_id=self.conn_id, op_args=self.op_args, output_table=self.output_table # type: ignore
)
-        self.op_kwargs = load_op_kwarg_dataframes_into_sql(
+        self.op_kwargs = self.load_op_kwarg_dataframes_into_sql(
conn_id=self.conn_id,
op_kwargs=self.op_kwargs,
output_table=self.output_table,
@@ -359,51 +362,52 @@ def get_source_code(self, py_callable: Callable) -> str | None:
self.log.warning("Can't get source code facet of Operator {self.operator.task_id}")
return None

+    def load_op_arg_dataframes_into_sql(self, conn_id: str, op_args: tuple, output_table: BaseTable) -> tuple:
+        """
+        Identify dataframes in op_args and load them to the table.
+        :param conn_id: Connection identifier to be used to load content to the target_table
+        :param op_args: user-defined decorator's kwargs
+        :param output_table: Similar table where the dataframe content will be written to
+        :return: New op_args, in which dataframes are replaced by tables
+        """
+        final_args: list[Table | BaseTable] = []
+        database = self.database_impl or create_database(conn_id=conn_id)
+        for arg in op_args:
+            if isinstance(arg, pd.DataFrame):
+                target_table = output_table.create_similar_table()
+                database.load_pandas_dataframe_to_table(source_dataframe=arg, target_table=target_table)
+                final_args.append(target_table)
+            elif isinstance(arg, BaseTable):
+                arg = database.populate_table_metadata(arg)
+                final_args.append(arg)
+            else:
+                final_args.append(arg)
+        return tuple(final_args)
+
+    def load_op_kwarg_dataframes_into_sql(
+        self, conn_id: str, op_kwargs: dict, output_table: BaseTable
+    ) -> dict:
+        """
+        Identify dataframes in op_kwargs and load them to a table.
+        :param conn_id: Connection identifier to be used to load content to the target_table
+        :param op_kwargs: user-defined decorator's kwargs
+        :param output_table: Similar table where the dataframe content will be written to
+        :return: New op_kwargs, in which dataframes are replaced by tables
+        """
+        final_kwargs = {}
+        database = self.database_impl or create_database(conn_id=conn_id)
+        database.table = output_table
+        for key, value in op_kwargs.items():
+            if isinstance(value, pd.DataFrame):
+                target_table = output_table.create_similar_table()
+                df_table = cast(BaseTable, target_table.create_similar_table())
+                database.load_pandas_dataframe_to_table(source_dataframe=value, target_table=df_table)
+                final_kwargs[key] = df_table
+            elif isinstance(value, BaseTable):
+                value = database.populate_table_metadata(value)
+                final_kwargs[key] = value
+            else:
+                final_kwargs[key] = value
+        return final_kwargs
-def load_op_arg_dataframes_into_sql(conn_id: str, op_args: tuple, output_table: BaseTable) -> tuple:
-    """
-    Identify dataframes in op_args and load them to the table.
-    :param conn_id: Connection identifier to be used to load content to the target_table
-    :param op_args: user-defined decorator's kwargs
-    :param output_table: Similar table where the dataframe content will be written to
-    :return: New op_args, in which dataframes are replaced by tables
-    """
-    final_args: list[Table | BaseTable] = []
-    database = create_database(conn_id=conn_id)
-    for arg in op_args:
-        if isinstance(arg, pd.DataFrame):
-            target_table = output_table.create_similar_table()
-            database.load_pandas_dataframe_to_table(source_dataframe=arg, target_table=target_table)
-            final_args.append(target_table)
-        elif isinstance(arg, BaseTable):
-            arg = database.populate_table_metadata(arg)
-            final_args.append(arg)
-        else:
-            final_args.append(arg)
-    return tuple(final_args)
-
-
-def load_op_kwarg_dataframes_into_sql(conn_id: str, op_kwargs: dict, output_table: BaseTable) -> dict:
-    """
-    Identify dataframes in op_kwargs and load them to a table.
-    :param conn_id: Connection identifier to be used to load content to the target_table
-    :param op_kwargs: user-defined decorator's kwargs
-    :param output_table: Similar table where the dataframe content will be written to
-    :return: New op_kwargs, in which dataframes are replaced by tables
-    """
-    final_kwargs = {}
-    database = create_database(conn_id=conn_id, table=output_table)
-    for key, value in op_kwargs.items():
-        if isinstance(value, pd.DataFrame):
-            target_table = output_table.create_similar_table()
-            df_table = cast(BaseTable, target_table.create_similar_table())
-            database.load_pandas_dataframe_to_table(source_dataframe=value, target_table=df_table)
-            final_kwargs[key] = df_table
-        elif isinstance(value, BaseTable):
-            value = database.populate_table_metadata(value)
-            final_kwargs[key] = value
-        else:
-            final_kwargs[key] = value
-    return final_kwargs
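One consequence of making `database_impl` a `cached_property` (rather than the removed class annotation) is that tests can still inject a database by plain assignment, as the updated tests below do with `operator.database_impl = SqliteDatabase()`: `functools.cached_property` defines no `__set__`, so an instance attribute simply shadows it. A minimal sketch of that behavior:

```python
from functools import cached_property


class Op:
    @cached_property
    def database_impl(self) -> str:
        return "database created lazily from conn_id"


op = Op()
op.database_impl = "database injected by a test"  # fine: cached_property has no __set__
assert op.database_impl == "database injected by a test"
```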
17 changes: 10 additions & 7 deletions python-sdk/tests/sql/operators/test_base_decorator.py
@@ -3,12 +3,9 @@
import pandas as pd
import pytest

+from astro.databases.sqlite import SqliteDatabase
from astro.sql import RawSQLOperator
-from astro.sql.operators.base_decorator import (
-    BaseSQLDecoratedOperator,
-    load_op_arg_dataframes_into_sql,
-    load_op_kwarg_dataframes_into_sql,
-)
+from astro.sql.operators.base_decorator import BaseSQLDecoratedOperator
from astro.table import BaseTable, Table


@@ -34,7 +31,10 @@ def test_load_op_arg_dataframes_into_sql():
df_1 = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
df_2 = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
op_args = (df_1, df_2, Table(conn_id="sqlite_default"), "str")
-    results = load_op_arg_dataframes_into_sql(
+    operator = BaseSQLDecoratedOperator(task_id="test", python_callable=lambda: 1)
+    operator.database_impl = SqliteDatabase()
+
+    results = operator.load_op_arg_dataframes_into_sql(
conn_id="sqlite_default", op_args=op_args, output_table=Table(conn_id="sqlite_default")
)

@@ -50,7 +50,10 @@ def test_load_op_kwarg_dataframes_into_sql():
df_1 = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
df_2 = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
op_kwargs = {"df_1": df_1, "df_2": df_2, "table": Table(conn_id="sqlite_default"), "some_str": "str"}
-    results = load_op_kwarg_dataframes_into_sql(
+
+    operator = BaseSQLDecoratedOperator(task_id="test", python_callable=lambda: 1)
+    operator.database_impl = SqliteDatabase()
+    results = operator.load_op_kwarg_dataframes_into_sql(
conn_id="sqlite_default", op_kwargs=op_kwargs, output_table=Table(conn_id="sqlite_default")
)

@@ -133,7 +133,7 @@ def validate_table_exists(table: Table):
},
{
"database": Database.DUCKDB,
"table": Table(conn_id="redshift_conn"),
"table": Table(conn_id="duckdb_conn"),
"file": File(DEFAULT_FILEPATH),
},
],
@@ -144,13 +144,11 @@ def test_drop_table_without_table_metadata(database_table_fixture, sample_dag):
"""Test drop table operator for all databases."""
database, test_table = database_table_fixture
assert database.table_exists(test_table)
-
with sample_dag:
aql.drop_table(
table=test_table,
)
test_utils.run_dag(sample_dag)
-
assert not database.table_exists(test_table)


