From 4fb90ac257088f8f9730b30e9feb8ba561827dad Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Tue, 16 Jan 2024 00:02:26 +0100 Subject: [PATCH] Types: Unlock supporting timezone-aware `DateTime` fields --- CHANGES.md | 1 + src/sqlalchemy_cratedb/dialect.py | 5 -- tests/datetime_test.py | 80 +++++++++++++++++++++++++++++-- 3 files changed, 77 insertions(+), 9 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 3b65ced..32a3343 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,7 @@ ## Unreleased - Added/reactivated documentation as `sqlalchemy-cratedb` +- Unlocking supporting timezone-aware `DateTime` fields ## 2024/06/13 0.37.0 - Added support for CrateDB's [FLOAT_VECTOR] data type and its accompanying diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 43af2fc..8c6fe23 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -31,7 +31,6 @@ CrateTypeCompiler, CrateDDLCompiler ) -from crate.client.exceptions import TimezoneUnawareException from .sa_version import SA_VERSION, SA_1_4, SA_2_0 from .type import FloatVector, ObjectArray, ObjectType @@ -113,14 +112,10 @@ def process(value): class DateTime(sqltypes.DateTime): - TZ_ERROR_MSG = "Timezone aware datetime objects are not supported" - def bind_processor(self, dialect): def process(value): if value is not None: assert isinstance(value, datetime) - if value.tzinfo is not None: - raise TimezoneUnawareException(DateTime.TZ_ERROR_MSG) return value.strftime('%Y-%m-%dT%H:%M:%S.%fZ') return value return process diff --git a/tests/datetime_test.py b/tests/datetime_test.py index 53c30fc..5ed45a1 100644 --- a/tests/datetime_test.py +++ b/tests/datetime_test.py @@ -22,12 +22,13 @@ from __future__ import absolute_import from datetime import datetime, tzinfo, timedelta +import datetime as dt from unittest import TestCase, skipIf from unittest.mock import patch, MagicMock +import pytest import sqlalchemy as sa -from sqlalchemy.exc import DBAPIError -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from sqlalchemy_cratedb import SA_VERSION, SA_1_4 @@ -87,9 +88,80 @@ def test_date_can_handle_datetime(self): ] self.session.query(self.Character).first() - def test_date_cannot_handle_tz_aware_datetime(self): + def test_date_can_handle_tz_aware_datetime(self): character = self.Character() character.name = "Athur" character.timestamp = datetime(2009, 5, 13, 19, 19, 30, tzinfo=CST()) self.session.add(character) - self.assertRaises(DBAPIError, self.session.commit) + + +Base = declarative_base() + + +class FooBar(Base): + __tablename__ = "foobar" + name = sa.Column(sa.String, primary_key=True) + date = sa.Column(sa.Date) + datetime = sa.Column(sa.DateTime) + + +@pytest.fixture +def session(cratedb_service): + engine = cratedb_service.database.engine + session = sessionmaker(bind=engine)() + + Base.metadata.drop_all(engine, checkfirst=True) + Base.metadata.create_all(engine, checkfirst=True) + return session + + +def test_datetime_notz(session): + """ + An integration test for `sa.Date` and `sa.DateTime`, not using timezones. + """ + + # Insert record. + foo_item = FooBar( + name="foo", + date=dt.date(2009, 5, 13), + datetime=dt.datetime(2009, 5, 13, 19, 19, 30), + ) + session.add(foo_item) + session.commit() + session.execute(sa.text("REFRESH TABLE foobar")) + + # Query record. + result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime)).mappings().first() + + # Compare outcome. + assert result["date"].year == 2009 + assert result["datetime"].year == 2009 + assert result["datetime"].tzname() is None + assert result["datetime"].timetz() == dt.time(19, 19, 30) + assert result["datetime"].tzinfo is None + + +def test_datetime_tz(session): + """ + An integration test for `sa.Date` and `sa.DateTime`, now using timezones. + """ + + # Insert record. + foo_item = FooBar( + name="foo", + date=dt.date(2009, 5, 13), + datetime=dt.datetime(2009, 5, 13, 19, 19, 30, tzinfo=CST()), + ) + session.add(foo_item) + session.commit() + session.execute(sa.text("REFRESH TABLE foobar")) + + # Query record. + result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime)).mappings().first() + + # Compare outcome. + assert result["date"].year == 2009 + assert result["datetime"].year == 2009 + assert result["datetime"].tzname() is None + assert result["datetime"].timetz() == dt.time(19, 19, 30) + assert result["datetime"].tzinfo is None