From 1258b4a257c7528232f41452218d7512e7a600e6 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Mon, 24 Jun 2024 20:18:59 +0200 Subject: [PATCH] DateTime and more: Use `sqlalchemy_cratedb.dialect.DateTime` ... ... instead of `sa.DateTime` and `sa.TIMESTAMP`. Introduce `visit_TIMESTAMP` from PGTypeCompiler to render SQL DDL clauses like `TIMESTAMP WITH|WITHOUT TIME ZONE`. --- src/sqlalchemy_cratedb/compiler.py | 12 ++++++- src/sqlalchemy_cratedb/dialect.py | 11 ++++--- tests/datetime_test.py | 53 +++++++++++++++++------------- 3 files changed, 48 insertions(+), 28 deletions(-) diff --git a/src/sqlalchemy_cratedb/compiler.py b/src/sqlalchemy_cratedb/compiler.py index d9e3bb7..6759400 100644 --- a/src/sqlalchemy_cratedb/compiler.py +++ b/src/sqlalchemy_cratedb/compiler.py @@ -225,7 +225,7 @@ def visit_SMALLINT(self, type_, **kw): return 'SHORT' def visit_datetime(self, type_, **kw): - return 'TIMESTAMP' + return self.visit_TIMESTAMP(type_, **kw) def visit_date(self, type_, **kw): return 'TIMESTAMP' @@ -245,6 +245,16 @@ def visit_FLOAT_VECTOR(self, type_, **kw): raise ValueError("FloatVector must be initialized with dimension size") return f"FLOAT_VECTOR({dimensions})" + def visit_TIMESTAMP(self, type_, **kw): + """ + Support for `TIMESTAMP WITH|WITHOUT TIME ZONE`. + + From `sqlalchemy.dialects.postgresql.base.PGTypeCompiler`. + """ + return "TIMESTAMP %s" % ( + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", + ) + class CrateCompiler(compiler.SQLCompiler): diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 1615b53..87379c8 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -39,8 +39,8 @@ "boolean": sqltypes.Boolean, "short": sqltypes.SmallInteger, "smallint": sqltypes.SmallInteger, - "timestamp": sqltypes.TIMESTAMP, - "timestamp with time zone": sqltypes.TIMESTAMP, + "timestamp": sqltypes.TIMESTAMP(timezone=False), + "timestamp with time zone": sqltypes.TIMESTAMP(timezone=True), "object": ObjectType, "integer": sqltypes.Integer, "long": sqltypes.NUMERIC, @@ -61,8 +61,8 @@ TYPES_MAP["boolean_array"] = ARRAY(sqltypes.Boolean) TYPES_MAP["short_array"] = ARRAY(sqltypes.SmallInteger) TYPES_MAP["smallint_array"] = ARRAY(sqltypes.SmallInteger) - TYPES_MAP["timestamp_array"] = ARRAY(sqltypes.TIMESTAMP) - TYPES_MAP["timestamp with time zone_array"] = ARRAY(sqltypes.TIMESTAMP) + TYPES_MAP["timestamp_array"] = ARRAY(sqltypes.TIMESTAMP(timezone=False)) + TYPES_MAP["timestamp with time zone_array"] = ARRAY(sqltypes.TIMESTAMP(timezone=True)) TYPES_MAP["long_array"] = ARRAY(sqltypes.NUMERIC) TYPES_MAP["bigint_array"] = ARRAY(sqltypes.NUMERIC) TYPES_MAP["double_array"] = ARRAY(sqltypes.DECIMAL) @@ -147,8 +147,9 @@ def process(value): colspecs = { + sqltypes.Date: Date, sqltypes.DateTime: DateTime, - sqltypes.Date: Date + sqltypes.TIMESTAMP: DateTime, } diff --git a/tests/datetime_test.py b/tests/datetime_test.py index d127e52..a0c8193 100644 --- a/tests/datetime_test.py +++ b/tests/datetime_test.py @@ -21,7 +21,7 @@ from __future__ import absolute_import -from datetime import datetime, tzinfo, timedelta +from datetime import tzinfo, timedelta import datetime as dt from unittest import TestCase, skipIf from unittest.mock import patch, MagicMock @@ -31,6 +31,7 @@ from sqlalchemy.orm import Session, sessionmaker from sqlalchemy_cratedb import SA_VERSION, SA_1_4 +from sqlalchemy_cratedb.dialect import DateTime try: from sqlalchemy.orm import declarative_base @@ -78,7 +79,7 @@ class Character(Base): __tablename__ = 'characters' name = sa.Column(sa.String, primary_key=True) date = sa.Column(sa.Date) - timestamp = sa.Column(sa.DateTime) + datetime = sa.Column(sa.DateTime) fake_cursor.description = ( ('characters_name', None, None, None, None, None, None), @@ -100,7 +101,7 @@ def test_date_can_handle_datetime(self): def test_date_can_handle_tz_aware_datetime(self): character = self.Character() character.name = "Athur" - character.timestamp = INPUT_DATETIME_NOTZ + character.datetime = INPUT_DATETIME_NOTZ self.session.add(character) @@ -111,8 +112,8 @@ class FooBar(Base): __tablename__ = "foobar" name = sa.Column(sa.String, primary_key=True) date = sa.Column(sa.Date) - datetime = sa.Column(sa.DateTime) - timestamp = sa.Column(sa.TIMESTAMP) + datetime_notz = sa.Column(DateTime(timezone=False)) + datetime_tz = sa.Column(DateTime(timezone=True)) @pytest.fixture @@ -135,23 +136,27 @@ def test_datetime_notz(session): foo_item = FooBar( name="foo", date=INPUT_DATE, - datetime=INPUT_DATETIME_NOTZ, - timestamp=INPUT_DATETIME_NOTZ, + datetime_notz=INPUT_DATETIME_NOTZ, + datetime_tz=INPUT_DATETIME_NOTZ, ) session.add(foo_item) session.commit() session.execute(sa.text("REFRESH TABLE foobar")) # Query record. - result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime, FooBar.timestamp)).mappings().first() + result = session.execute(sa.select( + FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz)).mappings().first() # Compare outcome. assert result["date"] == OUTPUT_DATE - assert result["datetime"] == OUTPUT_DATETIME_NOTZ - assert result["timestamp"] == OUTPUT_DATETIME_NOTZ - assert result["datetime"].tzname() is None - assert result["datetime"].timetz() == dt.time(19, 19, 30, 123000) - assert result["datetime"].tzinfo is None + assert result["datetime_notz"] == OUTPUT_DATETIME_NOTZ + assert result["datetime_notz"].tzname() is None + assert result["datetime_notz"].timetz() == OUTPUT_TIME + assert result["datetime_notz"].tzinfo is None + assert result["datetime_tz"] == OUTPUT_DATETIME_NOTZ + assert result["datetime_tz"].tzname() is None + assert result["datetime_tz"].timetz() == OUTPUT_TIME + assert result["datetime_tz"].tzinfo is None @pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Test case not supported on SQLAlchemy 1.3") @@ -163,21 +168,25 @@ def test_datetime_tz(session): # Insert record. foo_item = FooBar( name="foo", - date=dt.date(2009, 5, 13), - datetime=INPUT_DATETIME_TZ, - timestamp=INPUT_DATETIME_TZ, + date=INPUT_DATE, + datetime_notz=INPUT_DATETIME_TZ, + datetime_tz=INPUT_DATETIME_TZ, ) session.add(foo_item) session.commit() session.execute(sa.text("REFRESH TABLE foobar")) # Query record. - result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime, FooBar.timestamp)).mappings().first() + result = session.execute(sa.select( + FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz)).mappings().first() # Compare outcome. assert result["date"] == OUTPUT_DATE - assert result["datetime"] == OUTPUT_DATETIME_TZ - assert result["timestamp"] == OUTPUT_DATETIME_TZ - assert result["datetime"].tzname() is None - assert result["datetime"].timetz() == OUTPUT_TIME - assert result["datetime"].tzinfo is None + assert result["datetime_notz"] == OUTPUT_DATETIME_TZ + assert result["datetime_notz"].tzname() is None + assert result["datetime_notz"].timetz() == OUTPUT_TIME + assert result["datetime_notz"].tzinfo is None + assert result["datetime_tz"] == OUTPUT_DATETIME_TZ + assert result["datetime_tz"].tzname() is None + assert result["datetime_tz"].timetz() == OUTPUT_TIME + assert result["datetime_tz"].tzinfo is None