From 8378e763d5b0235c4cb39e676d308c39c9aa1199 Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Thu, 12 Jan 2023 12:46:35 +0000
Subject: [PATCH] Use cache to reduce redundant database calls (#1488)

We were calling the DB unnecessarily in various places to fetch the same
information. Those lookups can be cached, as the logs below show: the repeated
"Using connection ID 'sqlite_default' for task execution." calls drop from
7 to 1.

*Before 1*:
```
[2022-12-23 02:09:16,425] {dag.py:3622} INFO - Running task top_five_animations
[2022-12-23 02:09:16,438] {taskinstance.py:1511} INFO - Exporting the following env vars:
...
[2022-12-23 02:09:16,439] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 02:09:16,440] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 02:09:16,440] {base_decorator.py:124} INFO - Returning table Table(name='top_animation', conn_id='sqlite_default', metadata=Metadata(schema=None, database=None), columns=[], temp=False, uri='astro://@?table=top_animation', extra={})
[2022-12-23 02:09:16,440] {base_decorator.py:124} INFO - Returning table Table(name='top_animation', conn_id='sqlite_default', metadata=Metadata(schema=None, database=None), columns=[], temp=False, uri='astro://@?table=top_animation', extra={})
[2022-12-23 02:09:16,441] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 02:09:16,442] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 02:09:16,445] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 02:09:16,450] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 02:09:16,450] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 02:09:16,461] {taskinstance.py:1322} INFO - Marking task as SUCCESS. dag_id=calculate_popular_movies, task_id=top_five_animations, execution_date=20221223T020915, start_date=, end_date=20221223T020916
[2022-12-23 02:09:16,461] {taskinstance.py:1322} INFO - Marking task as SUCCESS. dag_id=calculate_popular_movies, task_id=top_five_animations, execution_date=20221223T020915, start_date=, end_date=20221223T020916
[2022-12-23 02:09:16,464] {dag.py:3626} INFO - top_five_animations ran successfully!
[2022-12-23 02:09:16,464] {dag.py:3629} INFO - *****************************************************
[2022-12-23 02:09:16,465] {dagrun.py:606} INFO - Marking run successful
```

*After 1*:
```
[2022-12-23 02:20:18,669] {dag.py:3622} INFO - Running task top_five_animations
[2022-12-23 02:20:18,680] {taskinstance.py:1511} INFO - Exporting the following env vars:
...
[2022-12-23 02:20:18,681] {base_decorator.py:124} INFO - Returning table Table(name='top_animation', conn_id='sqlite_default', metadata=Metadata(schema=None, database=None), columns=[], temp=False, uri='astro://@?table=top_animation', extra={})
[2022-12-23 02:20:18,681] {base_decorator.py:124} INFO - Returning table Table(name='top_animation', conn_id='sqlite_default', metadata=Metadata(schema=None, database=None), columns=[], temp=False, uri='astro://@?table=top_animation', extra={})
[2022-12-23 02:20:18,686] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 02:20:18,708] {taskinstance.py:1322} INFO - Marking task as SUCCESS. dag_id=calculate_popular_movies, task_id=top_five_animations, execution_date=20221223T022017, start_date=, end_date=20221223T022018
[2022-12-23 02:20:18,708] {taskinstance.py:1322} INFO - Marking task as SUCCESS. dag_id=calculate_popular_movies, task_id=top_five_animations, execution_date=20221223T022017, start_date=, end_date=20221223T022018
[2022-12-23 02:20:18,711] {dag.py:3626} INFO - top_five_animations ran successfully!
[2022-12-23 02:20:18,711] {dag.py:3629} INFO - *****************************************************
[2022-12-23 02:20:18,713] {dagrun.py:606} INFO - Marking run successful
[2022-12-23 02:20:18,715] {dagrun.py:657} INFO - DagRun Finished: dag_id=calculate_popular_movies, execution_date=2022-12-23T02:20:17.396648+00:00, run_id=manual__2022-12-23T02:20:17.396648+00:00, run_start_date=2022-12-23 02:20:17.396648+00:00, run_end_date=2022-12-23 02:20:18.713514+00:00, run_duration=1.316866, state=success, external_trigger=False, run_type=manual, data_interval_start=2022-12-23T02:20:17.396648+00:00, data_interval_end=2022-12-23T02:20:17.396648+00:00, dag_hash=None
```

*Before 2*:
```
[2022-12-23 01:55:54,386] {load_file.py:92} INFO - Loading https://raw.githubusercontent.com/astronomer/astro-sdk/main/tests/data/imdb_v2.csv into TempTable(name='_tmp_ztujoeesefaqclout728qnyomrc96suvgsntxnen11z4n40ia9wd99roe', conn_id='sqlite_default', metadata=Metadata(schema=None, database=None), columns=[], temp=True)
...
[2022-12-23 01:55:54,388] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 01:55:54,393] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 01:55:54,499] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 01:55:54,507] {base.py:499} INFO - Loading file(s) with Pandas...
[2022-12-23 01:55:54,606] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 01:55:54,658] {load_file.py:124} INFO - Completed loading the data into TempTable(name='_tmp_ztujoeesefaqclout728qnyomrc96suvgsntxnen11z4n40ia9wd99roe', conn_id='sqlite_default', metadata=Metadata(schema=None, database=None), columns=[], temp=True).
[2022-12-23 01:55:54,663] {taskinstance.py:1322} INFO - Marking task as SUCCESS. dag_id=calculate_popular_movies, task_id=imdb_movies, execution_date=20221223T015554, start_date=, end_date=20221223T015554
```

*After 2*:
```
[2022-12-23 01:56:37,620] {load_file.py:92} INFO - Loading https://raw.githubusercontent.com/astronomer/astro-sdk/main/tests/data/imdb_v2.csv into TempTable(name='_tmp_rnagpj5gmps5a3oplvlwvlmv6u918qw21inanpxg2j56lo725mrzgp9jo', conn_id='sqlite_default', metadata=Metadata(schema=None, database=None), columns=[], temp=True)
...
[2022-12-23 01:56:37,621] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 01:56:37,625] {base.py:73} INFO - Using connection ID 'sqlite_default' for task execution.
[2022-12-23 01:56:37,730] {base.py:501} INFO - Loading file(s) with Pandas...
[2022-12-23 01:56:37,881] {load_file.py:124} INFO - Completed loading the data into TempTable(name='_tmp_rnagpj5gmps5a3oplvlwvlmv6u918qw21inanpxg2j56lo725mrzgp9jo', conn_id='sqlite_default', metadata=Metadata(schema=None, database=None), columns=[], temp=True).
[2022-12-23 01:56:37,886] {taskinstance.py:1322} INFO - Marking task as SUCCESS. dag_id=calculate_popular_movies, task_id=imdb_movies, execution_date=20221223T015637, start_date=, end_date=20221223T015637
```
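The gist of the change is to memoize the connection-type lookup. A minimal standalone sketch of the idea (the helper name mirrors the new `_get_conn` in `astro/databases/__init__.py` below; the mapping is abbreviated for illustration, and running it assumes Airflow is installed with the connection defined):

```
from functools import lru_cache

# functools.cache only exists on Python >= 3.9; lru_cache(maxsize=None)
# is the equivalent spelling that also works on 3.7/3.8.
cache = lru_cache(maxsize=None)

# Abbreviated stand-in for the real mapping in astro/databases/__init__.py.
CONN_TYPE_TO_MODULE_PATH = {"sqlite": "astro.databases.sqlite"}


@cache
def _get_conn(conn_id: str) -> str:
    """Resolve conn_id to a module path once; repeat calls are served from the cache."""
    from airflow.hooks.base import BaseHook

    # This is the round-trip to the Airflow metadata DB that used to repeat per call.
    conn_type = BaseHook.get_connection(conn_id).conn_type
    return CONN_TYPE_TO_MODULE_PATH[conn_type]
```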
---
 .../pre_commit_context_typing_compat.py       |  2 +-
 python-sdk/pyproject.toml                     |  3 +-
 python-sdk/src/astro/databases/__init__.py    | 16 ++++---
 .../src/astro/databases/aws/redshift.py       |  5 ++-
 python-sdk/src/astro/databases/base.py        |  5 ++-
 .../src/astro/databases/google/bigquery.py    |  5 ++-
 python-sdk/src/astro/databases/postgres.py    |  3 +-
 python-sdk/src/astro/databases/snowflake.py   |  3 +-
 python-sdk/src/astro/databases/sqlite.py      |  5 ++-
 python-sdk/src/astro/files/operators/files.py |  2 +-
 python-sdk/src/astro/sql/operators/append.py  |  2 +-
 .../src/astro/sql/operators/base_decorator.py |  2 +-
 python-sdk/src/astro/sql/operators/cleanup.py |  2 +-
 .../data_validations/check_column.py          |  2 +-
 .../src/astro/sql/operators/dataframe.py      |  2 +-
 python-sdk/src/astro/sql/operators/drop.py    |  2 +-
 .../src/astro/sql/operators/export_to_file.py |  2 +-
 .../src/astro/sql/operators/load_file.py      |  2 +-
 python-sdk/src/astro/sql/operators/merge.py   |  2 +-
 python-sdk/src/astro/sql/operators/raw_sql.py |  2 +-
 .../src/astro/sql/operators/transform.py      |  6 ++-
 python-sdk/src/astro/utils/compat/__init__.py |  0
 .../src/astro/utils/compat/functools.py       | 18 ++++++++
 .../{typing_compat.py => compat/typing.py}    |  0
 python-sdk/src/astro/utils/table.py           |  2 +-
 python-sdk/tests/databases/test_bigquery.py   | 19 +++++++++
 python-sdk/tests/databases/test_snowflake.py  | 30 ++++++++++++-
 .../databases/test_bigquery.py                | 33 ---------------
 .../databases/test_snowflake.py               | 42 -------------------
 .../tests_integration/sql/test_table.py       | 12 +++---
 30 files changed, 118 insertions(+), 113 deletions(-)
 create mode 100644 python-sdk/src/astro/utils/compat/__init__.py
 create mode 100644 python-sdk/src/astro/utils/compat/functools.py
 rename python-sdk/src/astro/utils/{typing_compat.py => compat/typing.py} (100%)

diff --git a/python-sdk/dev/scripts/pre_commit_context_typing_compat.py b/python-sdk/dev/scripts/pre_commit_context_typing_compat.py
index 07e1dde36..f4b454959 100755
--- a/python-sdk/dev/scripts/pre_commit_context_typing_compat.py
+++ b/python-sdk/dev/scripts/pre_commit_context_typing_compat.py
@@ -12,7 +12,7 @@
 SOURCES_ROOT = Path(__file__).parents[2]
 ASTRO_ROOT = SOURCES_ROOT / "src" / "astro"
-TYPING_COMPAT_PATH = "python-sdk/src/astro/utils/typing_compat.py"
+TYPING_COMPAT_PATH = "python-sdk/src/astro/utils/compat/typing.py"
 
 
 class ImportCrawler(NodeVisitor):
diff --git a/python-sdk/pyproject.toml b/python-sdk/pyproject.toml
index 88c9e3cc6..8eee2e76c 100644
--- a/python-sdk/pyproject.toml
+++ b/python-sdk/pyproject.toml
@@ -24,7 +24,8 @@ dependencies = [
     "python-frontmatter",
     "smart-open",
     "SQLAlchemy>=1.3.18",
-    "apache-airflow-providers-common-sql"
+    "apache-airflow-providers-common-sql",
+    "cached_property>=1.5.0;python_version<='3.7'"
 ]
 
 keywords = ["airflow", "provider", "astronomer", "sql", "decorator", "task flow", "elt", "etl", "dag"]
diff --git a/python-sdk/src/astro/databases/__init__.py b/python-sdk/src/astro/databases/__init__.py
index 80ac14f2d..19462f7b4 100644
--- a/python-sdk/src/astro/databases/__init__.py
+++ b/python-sdk/src/astro/databases/__init__.py
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
 
 from astro.options import LoadOptionsList
+from astro.utils.compat.functools import cache
 from astro.utils.path import get_class_name, get_dict_with_module_names_to_dot_notations
 
 if TYPE_CHECKING:  # pragma: no cover
@@ -34,13 +35,18 @@ def create_database(
     :param conn_id: Database connection ID in Airflow
     :param table: (optional) The Table object
     """
-    from airflow.hooks.base import BaseHook
-
-    conn_type = BaseHook.get_connection(conn_id).conn_type
-    module_path = CONN_TYPE_TO_MODULE_PATH[conn_type]
-    module = importlib.import_module(module_path)
+    module = importlib.import_module(_get_conn(conn_id))
     class_name = get_class_name(module_ref=module, suffix="Database")
     database_class = getattr(module, class_name)
     load_options = load_options_list and load_options_list.get(database_class)
     database: BaseDatabase = database_class(conn_id, table, load_options=load_options)
     return database
+
+
+@cache
+def _get_conn(conn_id: str) -> str:
+    from airflow.hooks.base import BaseHook
+
+    conn_type = BaseHook.get_connection(conn_id).conn_type
+    module_path = CONN_TYPE_TO_MODULE_PATH[conn_type]
+    return module_path
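With the memoized helper above, repeated factory calls for the same `conn_id` query the Airflow metadata DB only once. A usage sketch (the conn ID is illustrative and assumes a configured SQLite connection; `_get_conn` is a private helper, imported here purely to show the cache counters):

```
from astro.databases import _get_conn, create_database

db1 = create_database("sqlite_default")  # resolves conn_type via BaseHook.get_connection()
db2 = create_database("sqlite_default")  # module path comes from the cache; no second lookup

# Functions wrapped by functools' cache expose hit/miss statistics:
print(_get_conn.cache_info())  # e.g. CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)
```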
diff --git a/python-sdk/src/astro/databases/aws/redshift.py b/python-sdk/src/astro/databases/aws/redshift.py
index ca5e625cf..3099e494a 100644
--- a/python-sdk/src/astro/databases/aws/redshift.py
+++ b/python-sdk/src/astro/databases/aws/redshift.py
@@ -34,6 +34,7 @@
 from astro.options import LoadOptions
 from astro.settings import REDSHIFT_SCHEMA
 from astro.table import BaseTable, Metadata, Table
+from astro.utils.compat.functools import cached_property
 
 DEFAULT_CONN_ID = RedshiftSQLHook.default_conn_name
 NATIVE_PATHS_SUPPORTED_FILE_TYPES = {
@@ -89,7 +90,7 @@ def __init__(
     def sql_type(self):
         return "redshift"
 
-    @property
+    @cached_property
     def hook(self) -> RedshiftSQLHook:
         """Retrieve Airflow hook to interface with the Redshift database."""
         kwargs = {}
@@ -100,7 +101,7 @@ def hook(self) -> RedshiftSQLHook:
             kwargs.update({"schema": self.table.metadata.database})
         return RedshiftSQLHook(redshift_conn_id=self.conn_id, use_legacy_sql=False, **kwargs)
 
-    @property
+    @cached_property
     def sqlalchemy_engine(self) -> Engine:
         """Return SQAlchemy engine."""
         uri = self.hook.get_uri()
diff --git a/python-sdk/src/astro/databases/base.py b/python-sdk/src/astro/databases/base.py
index 3efbefbe9..e35e9319f 100644
--- a/python-sdk/src/astro/databases/base.py
+++ b/python-sdk/src/astro/databases/base.py
@@ -36,6 +36,7 @@
 from astro.options import LoadOptions
 from astro.settings import LOAD_FILE_ENABLE_NATIVE_FALLBACK, LOAD_TABLE_AUTODETECT_ROWS_COUNT, SCHEMA
 from astro.table import BaseTable, Metadata
+from astro.utils.compat.functools import cached_property
 
 
 class BaseDatabase(ABC):
@@ -85,7 +86,7 @@ def __repr__(self):
     def sql_type(self):
         raise NotImplementedError
 
-    @property
+    @cached_property
     def hook(self) -> DbApiHook:
         """Return an instance of the database-specific Airflow hook."""
         raise NotImplementedError
@@ -95,7 +96,7 @@ def hook(self) -> DbApiHook:
     def connection(self) -> sqlalchemy.engine.base.Connection:
         """Return a Sqlalchemy connection object for the given database."""
         return self.sqlalchemy_engine.connect()
 
-    @property
+    @cached_property
     def sqlalchemy_engine(self) -> sqlalchemy.engine.base.Engine:
         """Return Sqlalchemy engine."""
         return self.hook.get_sqlalchemy_engine()  # type: ignore[no-any-return]
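Context for the hunks above: `hook` and `sqlalchemy_engine` were plain `@property` methods, so every attribute access rebuilt the hook and re-fetched the Airflow connection (hence the repeated "Using connection ID" log lines). A toy standalone illustration of the behavioural difference, not SDK code:

```
from functools import cached_property  # Python >= 3.8; the compat shim below covers 3.7


class Demo:
    @property
    def uncached(self):
        print("querying the Airflow DB for the connection...")
        return object()

    @cached_property
    def cached(self):
        print("querying the Airflow DB for the connection...")
        return object()


d = Demo()
d.uncached and d.uncached  # prints twice: each access repeats the lookup
d.cached and d.cached      # prints once: the result is stored on the instance
```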
diff --git a/python-sdk/src/astro/databases/google/bigquery.py b/python-sdk/src/astro/databases/google/bigquery.py
index eb20e0e26..9779a63e6 100644
--- a/python-sdk/src/astro/databases/google/bigquery.py
+++ b/python-sdk/src/astro/databases/google/bigquery.py
@@ -48,6 +48,7 @@
 from astro.options import LoadOptions
 from astro.settings import BIGQUERY_SCHEMA, BIGQUERY_SCHEMA_LOCATION
 from astro.table import BaseTable, Metadata
+from astro.utils.compat.functools import cached_property
 
 DEFAULT_CONN_ID = BigQueryHook.default_conn_name
 NATIVE_PATHS_SUPPORTED_FILE_TYPES = {
@@ -119,12 +120,12 @@ def __init__(
     def sql_type(self) -> str:
         return "bigquery"
 
-    @property
+    @cached_property
     def hook(self) -> BigQueryHook:
         """Retrieve Airflow hook to interface with the BigQuery database."""
         return BigQueryHook(gcp_conn_id=self.conn_id, use_legacy_sql=False, location=BIGQUERY_SCHEMA_LOCATION)
 
-    @property
+    @cached_property
     def sqlalchemy_engine(self) -> Engine:
         """Return SQAlchemy engine."""
         uri = self.hook.get_uri()
diff --git a/python-sdk/src/astro/databases/postgres.py b/python-sdk/src/astro/databases/postgres.py
index f472b00dc..4822c88bf 100644
--- a/python-sdk/src/astro/databases/postgres.py
+++ b/python-sdk/src/astro/databases/postgres.py
@@ -15,6 +15,7 @@
 from astro.options import LoadOptions
 from astro.settings import POSTGRES_SCHEMA
 from astro.table import BaseTable, Metadata
+from astro.utils.compat.functools import cached_property
 
 DEFAULT_CONN_ID = PostgresHook.default_conn_name
 
@@ -43,7 +44,7 @@ def __init__(
     def sql_type(self) -> str:
         return "postgresql"
 
-    @property
+    @cached_property
     def hook(self) -> PostgresHook:
         """Retrieve Airflow hook to interface with the Postgres database."""
         conn = PostgresHook(postgres_conn_id=self.conn_id).get_connection(self.conn_id)
diff --git a/python-sdk/src/astro/databases/snowflake.py b/python-sdk/src/astro/databases/snowflake.py
index 4412ae7ac..5abc8f794 100644
--- a/python-sdk/src/astro/databases/snowflake.py
+++ b/python-sdk/src/astro/databases/snowflake.py
@@ -41,6 +41,7 @@
 from astro.options import SnowflakeLoadOptions
 from astro.settings import LOAD_TABLE_AUTODETECT_ROWS_COUNT, SNOWFLAKE_SCHEMA
 from astro.table import BaseTable, Metadata
+from astro.utils.compat.functools import cached_property
 
 DEFAULT_CONN_ID = SnowflakeHook.default_conn_name
 
@@ -267,7 +268,7 @@ def __init__(
             raise ValueError("Error: Requires a SnowflakeLoadOptions")
         self.load_options: SnowflakeLoadOptions | None = load_options
 
-    @property
+    @cached_property
     def hook(self) -> SnowflakeHook:
         """Retrieve Airflow hook to interface with the snowflake database."""
         kwargs = {}
diff --git a/python-sdk/src/astro/databases/sqlite.py b/python-sdk/src/astro/databases/sqlite.py
index 5c8c4f31f..bc80ef165 100644
--- a/python-sdk/src/astro/databases/sqlite.py
+++ b/python-sdk/src/astro/databases/sqlite.py
@@ -11,6 +11,7 @@
 from astro.databases.base import BaseDatabase
 from astro.options import LoadOptions
 from astro.table import BaseTable, Metadata
+from astro.utils.compat.functools import cached_property
 
 DEFAULT_CONN_ID = SqliteHook.default_conn_name
 
@@ -35,12 +36,12 @@ def __init__(
     def sql_type(self) -> str:
         return "sqlite"
 
-    @property
+    @cached_property
     def hook(self) -> SqliteHook:
         """Retrieve Airflow hook to interface with the Sqlite database."""
         return SqliteHook(sqlite_conn_id=self.conn_id)
 
-    @property
+    @cached_property
     def sqlalchemy_engine(self) -> Engine:
         """Return SQAlchemy engine."""
         # Airflow uses sqlite3 library and not SqlAlchemy for SqliteHook
diff --git a/python-sdk/src/astro/files/operators/files.py b/python-sdk/src/astro/files/operators/files.py
index c962103fd..990db705f 100644
--- a/python-sdk/src/astro/files/operators/files.py
+++ b/python-sdk/src/astro/files/operators/files.py
@@ -7,7 +7,7 @@
 
 from astro.files.base import File
 from astro.files.locations import create_file_location
-from astro.utils.typing_compat import Context
+from astro.utils.compat.typing import Context
 
 
 class ListFileOperator(BaseOperator):
diff --git a/python-sdk/src/astro/sql/operators/append.py b/python-sdk/src/astro/sql/operators/append.py
index 4f43eb1c3..c87add38c 100644
--- a/python-sdk/src/astro/sql/operators/append.py
+++ b/python-sdk/src/astro/sql/operators/append.py
@@ -9,7 +9,7 @@
 from astro.databases import create_database
 from astro.sql.operators.base_operator import AstroSQLBaseOperator
 from astro.table import BaseTable
-from astro.utils.typing_compat import Context
+from astro.utils.compat.typing import Context
 
 
 class AppendOperator(AstroSQLBaseOperator):
diff --git a/python-sdk/src/astro/sql/operators/base_decorator.py b/python-sdk/src/astro/sql/operators/base_decorator.py
index 5dd44d246..b12c46e5c 100644
--- a/python-sdk/src/astro/sql/operators/base_decorator.py
+++ b/python-sdk/src/astro/sql/operators/base_decorator.py
@@ -13,8 +13,8 @@
 from astro.databases.base import BaseDatabase
 from astro.sql.operators.upstream_task_mixin import UpstreamTaskMixin
 from astro.table import BaseTable, Table
+from astro.utils.compat.typing import Context
 from astro.utils.table import find_first_table
-from astro.utils.typing_compat import Context
 
 
 class BaseSQLDecoratedOperator(UpstreamTaskMixin, DecoratedOperator):
diff --git a/python-sdk/src/astro/sql/operators/cleanup.py b/python-sdk/src/astro/sql/operators/cleanup.py
index 084289496..7cd6d4d16 100644
--- a/python-sdk/src/astro/sql/operators/cleanup.py
+++ b/python-sdk/src/astro/sql/operators/cleanup.py
@@ -26,7 +26,7 @@
 from astro.sql.operators.dataframe import DataframeOperator
 from astro.sql.operators.load_file import LoadFileOperator
 from astro.table import BaseTable, TempTable
-from astro.utils.typing_compat import Context
+from astro.utils.compat.typing import Context
 
 OPERATOR_CLASSES_WITH_TABLE_OUTPUT = (
     DataframeOperator,
diff --git a/python-sdk/src/astro/sql/operators/data_validations/check_column.py b/python-sdk/src/astro/sql/operators/data_validations/check_column.py
index a2b63b68f..afc7bbe85 100644
--- a/python-sdk/src/astro/sql/operators/data_validations/check_column.py
+++ b/python-sdk/src/astro/sql/operators/data_validations/check_column.py
@@ -8,7 +8,7 @@
 
 from astro.databases import create_database
 from astro.table import BaseTable
-from astro.utils.typing_compat import Context
+from astro.utils.compat.typing import Context
 
 
 class ColumnCheckOperator(SQLColumnCheckOperator):
diff --git a/python-sdk/src/astro/sql/operators/dataframe.py b/python-sdk/src/astro/sql/operators/dataframe.py
index f4d4e2a4d..e4c482ec9 100644
--- a/python-sdk/src/astro/sql/operators/dataframe.py
+++ b/python-sdk/src/astro/sql/operators/dataframe.py
@@ -21,9 +21,9 @@
 from astro.files import File
 from astro.sql.operators.base_operator import AstroSQLBaseOperator
 from astro.sql.table import BaseTable, Table
+from astro.utils.compat.typing import Context
 from astro.utils.dataframe import convert_columns_names_capitalization
 from astro.utils.table import find_first_table
-from astro.utils.typing_compat import Context
 
 
 def _get_dataframe(
diff --git a/python-sdk/src/astro/sql/operators/drop.py b/python-sdk/src/astro/sql/operators/drop.py
index 3ddbef44c..0d9089f8a 100644
--- a/python-sdk/src/astro/sql/operators/drop.py
+++ b/python-sdk/src/astro/sql/operators/drop.py
@@ -8,7 +8,7 @@
 from astro.databases import create_database
 from astro.sql.operators.base_operator import AstroSQLBaseOperator
 from astro.table import BaseTable
-from astro.utils.typing_compat import Context
+from astro.utils.compat.typing import Context
 
 
 class DropTableOperator(AstroSQLBaseOperator):
diff --git a/python-sdk/src/astro/sql/operators/export_to_file.py b/python-sdk/src/astro/sql/operators/export_to_file.py
index 6e48a487e..4a12be950 100644
--- a/python-sdk/src/astro/sql/operators/export_to_file.py
+++ b/python-sdk/src/astro/sql/operators/export_to_file.py
@@ -12,7 +12,7 @@
 from astro.files import File
 from astro.sql.operators.base_operator import AstroSQLBaseOperator
 from astro.table import BaseTable, Table
-from astro.utils.typing_compat import Context
+from astro.utils.compat.typing import Context
 
 
 class ExportToFileOperator(AstroSQLBaseOperator):
diff --git a/python-sdk/src/astro/sql/operators/load_file.py b/python-sdk/src/astro/sql/operators/load_file.py
index 2520365c8..1c3451b8a 100644
--- a/python-sdk/src/astro/sql/operators/load_file.py
+++ b/python-sdk/src/astro/sql/operators/load_file.py
@@ -18,7 +18,7 @@
 from astro.settings import LOAD_FILE_ENABLE_NATIVE_FALLBACK
 from astro.sql.operators.base_operator import AstroSQLBaseOperator
 from astro.table import BaseTable
-from astro.utils.typing_compat import Context
+from astro.utils.compat.typing import Context
 
 
 class LoadFileOperator(AstroSQLBaseOperator):
diff --git a/python-sdk/src/astro/sql/operators/merge.py b/python-sdk/src/astro/sql/operators/merge.py
index 0d606df1e..2afe39710 100644
--- a/python-sdk/src/astro/sql/operators/merge.py
+++ b/python-sdk/src/astro/sql/operators/merge.py
@@ -10,7 +10,7 @@
 from astro.databases import create_database
 from astro.sql.operators.base_operator import AstroSQLBaseOperator
 from astro.table import BaseTable
-from astro.utils.typing_compat import Context
+from astro.utils.compat.typing import Context
 
 
 class MergeOperator(AstroSQLBaseOperator):
diff --git a/python-sdk/src/astro/sql/operators/raw_sql.py b/python-sdk/src/astro/sql/operators/raw_sql.py
index f7d47b9bd..77d756e46 100644
--- a/python-sdk/src/astro/sql/operators/raw_sql.py
+++ b/python-sdk/src/astro/sql/operators/raw_sql.py
@@ -20,7 +20,7 @@
 from astro import settings
 from astro.exceptions import IllegalLoadToDatabaseException
 from astro.sql.operators.base_decorator import BaseSQLDecoratedOperator
-from astro.utils.typing_compat import Context
+from astro.utils.compat.typing import Context
 
 
 class RawSQLOperator(BaseSQLDecoratedOperator):
diff --git a/python-sdk/src/astro/sql/operators/transform.py b/python-sdk/src/astro/sql/operators/transform.py
index 0209fd5a3..5bd7898c9 100644
--- a/python-sdk/src/astro/sql/operators/transform.py
+++ b/python-sdk/src/astro/sql/operators/transform.py
@@ -14,7 +14,7 @@
 from sqlalchemy.sql.functions import Function
 
 from astro.sql.operators.base_decorator import BaseSQLDecoratedOperator
-from astro.utils.typing_compat import Context
+from astro.utils.compat.typing import Context
 
 
 class TransformOperator(BaseSQLDecoratedOperator):
@@ -60,7 +60,9 @@ def execute(self, context: Context):
             parameters=self.parameters,
         )
         # TODO: remove pushing to XCom once we update the airflow version.
-        context["ti"].xcom_push(key="output_table_row_count", value=str(self.output_table.row_count))
+        context["ti"].xcom_push(
+            key="output_table_row_count", value=str(self.database_impl.row_count(self.output_table))
+        )
         context["ti"].xcom_push(key="output_table_conn_id", value=str(self.output_table.conn_id))
         return self.output_table
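Besides the import move, the `transform.py` hunk above also stops reading `self.output_table.row_count`, presumably so the count is computed through the operator's already-constructed `database_impl` instead of building another database object per access. Downstream code keeps reading the XCom value as before; a hypothetical consumer (the task ID is made up for illustration):

```
from airflow.decorators import task


@task
def report_row_count(ti=None):  # Airflow injects the TaskInstance at runtime
    rows = ti.xcom_pull(task_ids="transform_task", key="output_table_row_count")
    print(f"transform wrote {rows} rows")
```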
- context["ti"].xcom_push(key="output_table_row_count", value=str(self.output_table.row_count)) + context["ti"].xcom_push( + key="output_table_row_count", value=str(self.database_impl.row_count(self.output_table)) + ) context["ti"].xcom_push(key="output_table_conn_id", value=str(self.output_table.conn_id)) return self.output_table diff --git a/python-sdk/src/astro/utils/compat/__init__.py b/python-sdk/src/astro/utils/compat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python-sdk/src/astro/utils/compat/functools.py b/python-sdk/src/astro/utils/compat/functools.py new file mode 100644 index 000000000..8b3bac3b1 --- /dev/null +++ b/python-sdk/src/astro/utils/compat/functools.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import sys + +if sys.version_info >= (3, 8): + from functools import cached_property +else: + from cached_property import cached_property + +if sys.version_info >= (3, 9): + from functools import cache +else: + from functools import lru_cache + + cache = lru_cache(maxsize=None) + + +__all__ = ["cache", "cached_property"] diff --git a/python-sdk/src/astro/utils/typing_compat.py b/python-sdk/src/astro/utils/compat/typing.py similarity index 100% rename from python-sdk/src/astro/utils/typing_compat.py rename to python-sdk/src/astro/utils/compat/typing.py diff --git a/python-sdk/src/astro/utils/table.py b/python-sdk/src/astro/utils/table.py index 2077ea4d5..0b4b33955 100644 --- a/python-sdk/src/astro/utils/table.py +++ b/python-sdk/src/astro/utils/table.py @@ -6,7 +6,7 @@ from airflow.models.xcom_arg import XComArg from astro.sql.table import BaseTable -from astro.utils.typing_compat import Context +from astro.utils.compat.typing import Context def _have_same_conn_id(tables: list[BaseTable]) -> bool: diff --git a/python-sdk/tests/databases/test_bigquery.py b/python-sdk/tests/databases/test_bigquery.py index b322ac526..05c24d97e 100644 --- a/python-sdk/tests/databases/test_bigquery.py +++ b/python-sdk/tests/databases/test_bigquery.py @@ -10,6 +10,7 @@ from astro import settings from astro.databases.google.bigquery import BigqueryDatabase, S3ToBigqueryDataTransfer +from astro.exceptions import DatabaseCustomError from astro.files import File from astro.table import TEMP_PREFIX, Metadata, Table @@ -139,3 +140,21 @@ def mock_get_dataset(dataset_id): db = BigqueryDatabase(table=source_table, conn_id="test_conn") assert db.populate_table_metadata(input_table) == returned_table + + +@mock.patch("astro.databases.google.bigquery.BigqueryDatabase.hook") +def test_get_project_id_raise_exception(mock_hook): + """ + Test loading on files to bigquery natively for fallback without fallback + gracefully for wrong file location. 
+ """ + + class CustomAttributeError: + def __str__(self): + raise AttributeError + + database = BigqueryDatabase() + mock_hook.project_id = CustomAttributeError() + + with pytest.raises(DatabaseCustomError): + database.get_project_id(target_table=Table()) diff --git a/python-sdk/tests/databases/test_snowflake.py b/python-sdk/tests/databases/test_snowflake.py index dfc5c06da..ce26ecde4 100644 --- a/python-sdk/tests/databases/test_snowflake.py +++ b/python-sdk/tests/databases/test_snowflake.py @@ -6,10 +6,11 @@ import pytest from astro.databases.snowflake import SnowflakeDatabase, SnowflakeFileFormat, SnowflakeStage +from astro.exceptions import DatabaseCustomError from astro.files import File from astro.options import LoadOptions, SnowflakeLoadOptions from astro.settings import SNOWFLAKE_STORAGE_INTEGRATION_AMAZON, SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE -from astro.table import Table +from astro.table import Metadata, Table DEFAULT_CONN_ID = "snowflake_default" CUSTOM_CONN_ID = "snowflake_conn" @@ -78,6 +79,33 @@ def test_use_quotes(cols_eval): assert SnowflakeDatabase.use_quotes(cols_eval["cols"]) == cols_eval["expected_result"] +@mock.patch("astro.databases.snowflake.SnowflakeDatabase.hook") +@mock.patch("astro.databases.snowflake.SnowflakeDatabase.create_stage") +def test_load_file_to_table_natively_for_fallback_raises_exception_if_not_enable_native_fallback( + mock_stage, mock_hook +): + mock_hook.run.side_effect = [ + ValueError, # 1st run call copies the data + None, # 2nd run call drops the stage + ] + mock_stage.return_value = SnowflakeStage( + name="mock_stage", + url="gcs://bucket/prefix", + metadata=Metadata(database="SNOWFLAKE_DATABASE", schema="SNOWFLAKE_SCHEMA"), + ) + database = SnowflakeDatabase() + with pytest.raises(DatabaseCustomError): + database.load_file_to_table_natively_with_fallback( + source_file=File(str(pathlib.Path(CWD.parent, "data/sample.csv"))), + target_table=Table(), + ) + mock_hook.run.assert_has_calls( + [ + mock.call(f"DROP STAGE IF EXISTS {mock_stage.return_value.qualified_name};", autocommit=True), + ] + ) + + def test_snowflake_load_options(): path = str(CWD) + "/../../data/homes_main.csv" database = SnowflakeDatabase( diff --git a/python-sdk/tests_integration/databases/test_bigquery.py b/python-sdk/tests_integration/databases/test_bigquery.py index 53c395d41..645102023 100644 --- a/python-sdk/tests_integration/databases/test_bigquery.py +++ b/python-sdk/tests_integration/databases/test_bigquery.py @@ -234,39 +234,6 @@ def test_load_file_to_table_natively_for_fallback_wrong_file_location_with_enabl ) -@pytest.mark.integration -@pytest.mark.parametrize( - "database_table_fixture", - [ - { - "database": Database.BIGQUERY, - "table": Table(metadata=Metadata(schema=SCHEMA)), - }, - ], - indirect=True, - ids=["bigquery"], -) -@mock.patch("astro.databases.google.bigquery.BigqueryDatabase.hook") -def test_get_project_id_raise_exception( - mock_hook, - database_table_fixture, -): - """ - Test loading on files to bigquery natively for fallback without fallback - gracefully for wrong file location. 
- """ - - class CustomAttibuteError: - def __str__(self): - raise AttributeError - - mock_hook.project_id = CustomAttibuteError() - database, target_table = database_table_fixture - - with pytest.raises(DatabaseCustomError): - database.get_project_id(target_table=target_table) - - @pytest.mark.integration @pytest.mark.parametrize( "database_table_fixture", diff --git a/python-sdk/tests_integration/databases/test_snowflake.py b/python-sdk/tests_integration/databases/test_snowflake.py index 36dba67e2..7a36b9ff4 100644 --- a/python-sdk/tests_integration/databases/test_snowflake.py +++ b/python-sdk/tests_integration/databases/test_snowflake.py @@ -2,7 +2,6 @@ import os import pathlib from unittest import mock -from unittest.mock import call import pandas as pd import pytest @@ -232,47 +231,6 @@ def test_load_file_from_cloud_to_table(database_table_fixture): test_utils.assert_dataframes_are_equal(df, expected) -@pytest.mark.integration -@pytest.mark.parametrize( - "database_table_fixture", - [ - { - "database": Database.SNOWFLAKE, - "table": Table(metadata=Metadata(schema=SCHEMA)), - }, - ], - indirect=True, - ids=["snowflake"], -) -@mock.patch("astro.databases.snowflake.SnowflakeDatabase.hook") -@mock.patch("astro.databases.snowflake.SnowflakeDatabase.create_stage") -def test_load_file_to_table_natively_for_fallback_raises_exception_if_not_enable_native_fallback( - mock_stage, mock_hook, database_table_fixture -): - """Test loading on files to snowflake natively for fallback raise exception.""" - mock_hook.run.side_effect = [ - ValueError, # 1st run call copies the data - None, # 2nd run call drops the stage - ] - mock_stage.return_value = SnowflakeStage( - name="mock_stage", - url="gcs://bucket/prefix", - metadata=Metadata(database="SNOWFLAKE_DATABASE", schema="SNOWFLAKE_SCHEMA"), - ) - database, target_table = database_table_fixture - filepath = str(pathlib.Path(CWD.parent, "data/sample.csv")) - with pytest.raises(DatabaseCustomError): - database.load_file_to_table_natively_with_fallback( - source_file=File(filepath), - target_table=target_table, - ) - mock_hook.run.assert_has_calls( - [ - call(f"DROP STAGE IF EXISTS {mock_stage.return_value.qualified_name};", autocommit=True), - ] - ) - - @pytest.mark.integration @pytest.mark.parametrize( "database_table_fixture", diff --git a/python-sdk/tests_integration/sql/test_table.py b/python-sdk/tests_integration/sql/test_table.py index 1daa1bfe4..7f2a7abc1 100644 --- a/python-sdk/tests_integration/sql/test_table.py +++ b/python-sdk/tests_integration/sql/test_table.py @@ -13,7 +13,7 @@ [ ( Connection( - conn_id="test_conn", conn_type="gcpbigquery", extra={"project": "astronomer-dag-authoring"} + conn_id="test_bq", conn_type="gcpbigquery", extra={"project": "astronomer-dag-authoring"} ), "astronomer-dag-authoring.dataset.test_tb", "bigquery", @@ -21,7 +21,7 @@ ), ( Connection( - conn_id="test_conn", + conn_id="test_redshift", conn_type="redshift", schema="astro", host="local", @@ -35,7 +35,7 @@ ), ( Connection( - conn_id="test_conn", + conn_id="test_pg", conn_type="postgres", login="postgres", password="postgres", @@ -48,7 +48,7 @@ ), ( Connection( - conn_id="test_conn", + conn_id="test_snow", conn_type="snowflake", host="local", port=443, @@ -68,7 +68,7 @@ "snowflake://astro-sdk/TEST_ASTRO.ci.test_tb", ), ( - Connection(conn_id="test_conn", conn_type="sqlite", host="/tmp/sqlite.db"), + Connection(conn_id="test_sqlite", conn_type="sqlite", host="/tmp/sqlite.db"), "/tmp/sqlite.db.test_tb", f"file://{socket.gethostbyname(socket.gethostname())}:22", 
f"file://{socket.gethostbyname(socket.gethostname())}:22/tmp/sqlite.db.test_tb", @@ -83,7 +83,7 @@ def test_openlineage_dataset(mock_get_connection, gcp_cred, connection, name, na """ mock_get_connection.return_value = connection gcp_cred.return_value = "astronomer-dag-authoring", "astronomer-dag-authoring" - tb = Table(conn_id="test_conn", name="test_tb", metadata=Metadata(schema="dataset")) + tb = Table(conn_id=connection.conn_id, name="test_tb", metadata=Metadata(schema="dataset")) assert tb.openlineage_dataset_name() == name assert tb.openlineage_dataset_namespace() == namespace