Skip to content

Commit

Permalink
Types: Unlock supporting timezone-aware DateTime fields
Browse files Browse the repository at this point in the history
Co-authored-by: Marios Trivyzas <[email protected]>
  • Loading branch information
amotl and matriv committed Jun 24, 2024
1 parent 877ebaa commit 833328b
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

## Unreleased
- Added/reactivated documentation as `sqlalchemy-cratedb`
- Unlocked supporting timezone-aware `DateTime` fields

## 2024/06/13 0.37.0
- Added support for CrateDB's [FLOAT_VECTOR] data type and its accompanying
Expand Down
5 changes: 0 additions & 5 deletions src/sqlalchemy_cratedb/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
CrateTypeCompiler,
CrateDDLCompiler
)
from crate.client.exceptions import TimezoneUnawareException
from .sa_version import SA_VERSION, SA_1_4, SA_2_0
from .type import FloatVector, ObjectArray, ObjectType

Expand Down Expand Up @@ -113,14 +112,10 @@ def process(value):

class DateTime(sqltypes.DateTime):

TZ_ERROR_MSG = "Timezone aware datetime objects are not supported"

def bind_processor(self, dialect):
def process(value):
if value is not None:
assert isinstance(value, datetime)
if value.tzinfo is not None:
raise TimezoneUnawareException(DateTime.TZ_ERROR_MSG)
return value.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
return value
return process
Expand Down
82 changes: 78 additions & 4 deletions tests/datetime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
from __future__ import absolute_import

from datetime import datetime, tzinfo, timedelta
import datetime as dt
from unittest import TestCase, skipIf
from unittest.mock import patch, MagicMock

import pytest
import sqlalchemy as sa
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker

from sqlalchemy_cratedb import SA_VERSION, SA_1_4

Expand Down Expand Up @@ -87,9 +88,82 @@ def test_date_can_handle_datetime(self):
]
self.session.query(self.Character).first()

def test_date_cannot_handle_tz_aware_datetime(self):
def test_date_can_handle_tz_aware_datetime(self):
character = self.Character()
character.name = "Athur"
character.timestamp = datetime(2009, 5, 13, 19, 19, 30, tzinfo=CST())
self.session.add(character)
self.assertRaises(DBAPIError, self.session.commit)


Base = declarative_base()


class FooBar(Base):
__tablename__ = "foobar"
name = sa.Column(sa.String, primary_key=True)
date = sa.Column(sa.Date)
datetime = sa.Column(sa.DateTime)


@pytest.fixture
def session(cratedb_service):
engine = cratedb_service.database.engine
session = sessionmaker(bind=engine)()

Base.metadata.drop_all(engine, checkfirst=True)
Base.metadata.create_all(engine, checkfirst=True)
return session


@pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Test case not supported on SQLAlchemy 1.3")
def test_datetime_notz(session):
"""
An integration test for `sa.Date` and `sa.DateTime`, not using timezones.
"""

# Insert record.
foo_item = FooBar(
name="foo",
date=dt.date(2009, 5, 13),
datetime=dt.datetime(2009, 5, 13, 19, 19, 30, 123456),
)
session.add(foo_item)
session.commit()
session.execute(sa.text("REFRESH TABLE foobar"))

# Query record.
result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime)).mappings().first()

# Compare outcome.
assert result["date"].year == 2009
assert result["datetime"].year == 2009
assert result["datetime"].tzname() is None
assert result["datetime"].timetz() == dt.time(19, 19, 30, 123000)
assert result["datetime"].tzinfo is None


@pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Test case not supported on SQLAlchemy 1.3")
def test_datetime_tz(session):
"""
An integration test for `sa.Date` and `sa.DateTime`, now using timezones.
"""

# Insert record.
foo_item = FooBar(
name="foo",
date=dt.date(2009, 5, 13),
datetime=dt.datetime(2009, 5, 13, 19, 19, 30, 123456, tzinfo=CST()),
)
session.add(foo_item)
session.commit()
session.execute(sa.text("REFRESH TABLE foobar"))

# Query record.
result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime)).mappings().first()

# Compare outcome.
assert result["date"].year == 2009
assert result["datetime"].year == 2009
assert result["datetime"].tzname() is None
assert result["datetime"].timetz() == dt.time(19, 19, 30, 123000)
assert result["datetime"].tzinfo is None

0 comments on commit 833328b

Please sign in to comment.