diff --git a/python-sdk/src/astro/custom_backend/serializer.py b/python-sdk/src/astro/custom_backend/serializer.py index f7338a3d7..69e94f334 100644 --- a/python-sdk/src/astro/custom_backend/serializer.py +++ b/python-sdk/src/astro/custom_backend/serializer.py @@ -5,21 +5,27 @@ from json import JSONDecodeError from typing import Any -import airflow import numpy as np import pandas +import sqlalchemy +from packaging import version -if airflow.__version__ >= "2.3": +try: from sqlalchemy.engine.row import LegacyRow as SQLAlcRow -else: +except ImportError: from sqlalchemy.engine.result import RowProxy as SQLAlcRow + from astro.files import File from astro.table import Table, TempTable log = logging.getLogger("astro.utils.serializer") +def is_newer_sqlalchemy(): + return version(sqlalchemy.__version__) >= version.parse("1.4.0") + + def serialize(obj: Table | File | Any) -> dict | Any: # noqa """ Serialize astro SDK objects (tables, files and dataframes) into json safe dictionary @@ -52,7 +58,7 @@ def serialize(obj: Table | File | Any) -> dict | Any: # noqa "key_map": obj._keymap, # skipcq PYL-W021 "key_style": obj._key_style, # skipcq PYL-W021 } - if airflow.__version__ >= "2.3": + if is_newer_sqlalchemy(): serialized_obj["data"] = obj._data # skipcq PYL-W021 return serialized_obj @@ -90,7 +96,7 @@ def deserialize(obj: dict | str | list) -> Table | File | Any: # noqa log.debug("Found file dictionary %s, will attempt to deserialize", obj) return _deserialize_file(obj) elif obj["class"] == "SQLAlcRow": - if airflow.__version__ >= "2.3": + if is_newer_sqlalchemy(): return SQLAlcRow(None, None, obj["key_map"], obj["key_style"], obj["data"]) else: return SQLAlcRow(None, None, obj["key_map"], obj["key_style"])