diff --git a/src/sqlalchemy_cratedb/compiler.py b/src/sqlalchemy_cratedb/compiler.py index d9e3bb7..6759400 100644 --- a/src/sqlalchemy_cratedb/compiler.py +++ b/src/sqlalchemy_cratedb/compiler.py @@ -225,7 +225,7 @@ def visit_SMALLINT(self, type_, **kw): return 'SHORT' def visit_datetime(self, type_, **kw): - return 'TIMESTAMP' + return self.visit_TIMESTAMP(type_, **kw) def visit_date(self, type_, **kw): return 'TIMESTAMP' @@ -245,6 +245,16 @@ def visit_FLOAT_VECTOR(self, type_, **kw): raise ValueError("FloatVector must be initialized with dimension size") return f"FLOAT_VECTOR({dimensions})" + def visit_TIMESTAMP(self, type_, **kw): + """ + Support for `TIMESTAMP WITH|WITHOUT TIME ZONE`. + + From `sqlalchemy.dialects.postgresql.base.PGTypeCompiler`. + """ + return "TIMESTAMP %s" % ( + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", + ) + class CrateCompiler(compiler.SQLCompiler): diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 1615b53..87379c8 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -39,8 +39,8 @@ "boolean": sqltypes.Boolean, "short": sqltypes.SmallInteger, "smallint": sqltypes.SmallInteger, - "timestamp": sqltypes.TIMESTAMP, - "timestamp with time zone": sqltypes.TIMESTAMP, + "timestamp": sqltypes.TIMESTAMP(timezone=False), + "timestamp with time zone": sqltypes.TIMESTAMP(timezone=True), "object": ObjectType, "integer": sqltypes.Integer, "long": sqltypes.NUMERIC, @@ -61,8 +61,8 @@ TYPES_MAP["boolean_array"] = ARRAY(sqltypes.Boolean) TYPES_MAP["short_array"] = ARRAY(sqltypes.SmallInteger) TYPES_MAP["smallint_array"] = ARRAY(sqltypes.SmallInteger) - TYPES_MAP["timestamp_array"] = ARRAY(sqltypes.TIMESTAMP) - TYPES_MAP["timestamp with time zone_array"] = ARRAY(sqltypes.TIMESTAMP) + TYPES_MAP["timestamp_array"] = ARRAY(sqltypes.TIMESTAMP(timezone=False)) + TYPES_MAP["timestamp with time zone_array"] = ARRAY(sqltypes.TIMESTAMP(timezone=True)) TYPES_MAP["long_array"] = ARRAY(sqltypes.NUMERIC) TYPES_MAP["bigint_array"] = ARRAY(sqltypes.NUMERIC) TYPES_MAP["double_array"] = ARRAY(sqltypes.DECIMAL) @@ -147,8 +147,9 @@ def process(value): colspecs = { + sqltypes.Date: Date, sqltypes.DateTime: DateTime, - sqltypes.Date: Date + sqltypes.TIMESTAMP: DateTime, } diff --git a/tests/create_table_test.py b/tests/create_table_test.py index f74c45e..60e67b1 100644 --- a/tests/create_table_test.py +++ b/tests/create_table_test.py @@ -67,8 +67,10 @@ class User(self.Base): '\n\tlong_col1 LONG, \n\tlong_col2 LONG, ' '\n\tbool_col BOOLEAN, ' '\n\tshort_col SHORT, ' - '\n\tdatetime_col TIMESTAMP, \n\tdate_col TIMESTAMP, ' - '\n\tfloat_col FLOAT, \n\tdouble_col DOUBLE, ' + '\n\tdatetime_col TIMESTAMP WITHOUT TIME ZONE, ' + '\n\tdate_col TIMESTAMP, ' + '\n\tfloat_col FLOAT, ' + '\n\tdouble_col DOUBLE, ' '\n\tPRIMARY KEY (string_col)\n)\n\n'), ()) @@ -271,7 +273,7 @@ class DummyTable(self.Base): fake_cursor.execute.assert_called_with( ('\nCREATE TABLE t (\n\t' 'pk STRING NOT NULL, \n\t' - 'a TIMESTAMP DEFAULT now(), \n\t' + 'a TIMESTAMP WITHOUT TIME ZONE DEFAULT now(), \n\t' 'PRIMARY KEY (pk)\n)\n\n'), ()) def test_column_server_default_string(self): @@ -297,7 +299,7 @@ class DummyTable(self.Base): fake_cursor.execute.assert_called_with( ('\nCREATE TABLE t (\n\t' 'pk STRING NOT NULL, \n\t' - 'a TIMESTAMP DEFAULT now(), \n\t' + 'a TIMESTAMP WITHOUT TIME ZONE DEFAULT now(), \n\t' 'PRIMARY KEY (pk)\n)\n\n'), ()) def test_column_server_default_text_constant(self): diff --git a/tests/datetime_test.py b/tests/datetime_test.py index e52ca53..a0c8193 100644 --- a/tests/datetime_test.py +++ b/tests/datetime_test.py @@ -21,7 +21,7 @@ from __future__ import absolute_import -from datetime import datetime, tzinfo, timedelta +from datetime import tzinfo, timedelta import datetime as dt from unittest import TestCase, skipIf from unittest.mock import patch, MagicMock @@ -31,6 +31,7 @@ from sqlalchemy.orm import Session, sessionmaker from sqlalchemy_cratedb import SA_VERSION, SA_1_4 +from sqlalchemy_cratedb.dialect import DateTime try: from sqlalchemy.orm import declarative_base @@ -57,6 +58,15 @@ def dst(self, date_time): return timedelta(seconds=-7200) +INPUT_DATE = dt.date(2009, 5, 13) +INPUT_DATETIME_NOTZ = dt.datetime(2009, 5, 13, 19, 19, 30, 123456) +INPUT_DATETIME_TZ = dt.datetime(2009, 5, 13, 19, 19, 30, 123456, tzinfo=CST()) +OUTPUT_DATE = INPUT_DATE +OUTPUT_TIME = dt.time(19, 19, 30, 123000) +OUTPUT_DATETIME_NOTZ = dt.datetime(2009, 5, 13, 19, 19, 30, 123000) +OUTPUT_DATETIME_TZ = dt.datetime(2009, 5, 13, 19, 19, 30, 123000) + + @skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases") @patch('crate.client.connection.Cursor', FakeCursor) class SqlAlchemyDateAndDateTimeTest(TestCase): @@ -69,7 +79,7 @@ class Character(Base): __tablename__ = 'characters' name = sa.Column(sa.String, primary_key=True) date = sa.Column(sa.Date) - timestamp = sa.Column(sa.DateTime) + datetime = sa.Column(sa.DateTime) fake_cursor.description = ( ('characters_name', None, None, None, None, None, None), @@ -91,7 +101,7 @@ def test_date_can_handle_datetime(self): def test_date_can_handle_tz_aware_datetime(self): character = self.Character() character.name = "Athur" - character.timestamp = datetime(2009, 5, 13, 19, 19, 30, tzinfo=CST()) + character.datetime = INPUT_DATETIME_NOTZ self.session.add(character) @@ -102,7 +112,8 @@ class FooBar(Base): __tablename__ = "foobar" name = sa.Column(sa.String, primary_key=True) date = sa.Column(sa.Date) - datetime = sa.Column(sa.DateTime) + datetime_notz = sa.Column(DateTime(timezone=False)) + datetime_tz = sa.Column(DateTime(timezone=True)) @pytest.fixture @@ -124,22 +135,28 @@ def test_datetime_notz(session): # Insert record. foo_item = FooBar( name="foo", - date=dt.date(2009, 5, 13), - datetime=dt.datetime(2009, 5, 13, 19, 19, 30, 123456), + date=INPUT_DATE, + datetime_notz=INPUT_DATETIME_NOTZ, + datetime_tz=INPUT_DATETIME_NOTZ, ) session.add(foo_item) session.commit() session.execute(sa.text("REFRESH TABLE foobar")) # Query record. - result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime)).mappings().first() + result = session.execute(sa.select( + FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz)).mappings().first() # Compare outcome. - assert result["date"].year == 2009 - assert result["datetime"].year == 2009 - assert result["datetime"].tzname() is None - assert result["datetime"].timetz() == dt.time(19, 19, 30, 123000) - assert result["datetime"].tzinfo is None + assert result["date"] == OUTPUT_DATE + assert result["datetime_notz"] == OUTPUT_DATETIME_NOTZ + assert result["datetime_notz"].tzname() is None + assert result["datetime_notz"].timetz() == OUTPUT_TIME + assert result["datetime_notz"].tzinfo is None + assert result["datetime_tz"] == OUTPUT_DATETIME_NOTZ + assert result["datetime_tz"].tzname() is None + assert result["datetime_tz"].timetz() == OUTPUT_TIME + assert result["datetime_tz"].tzinfo is None @pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Test case not supported on SQLAlchemy 1.3") @@ -151,19 +168,25 @@ def test_datetime_tz(session): # Insert record. foo_item = FooBar( name="foo", - date=dt.date(2009, 5, 13), - datetime=dt.datetime(2009, 5, 13, 19, 19, 30, 123456, tzinfo=CST()), + date=INPUT_DATE, + datetime_notz=INPUT_DATETIME_TZ, + datetime_tz=INPUT_DATETIME_TZ, ) session.add(foo_item) session.commit() session.execute(sa.text("REFRESH TABLE foobar")) # Query record. - result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime)).mappings().first() + result = session.execute(sa.select( + FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz)).mappings().first() # Compare outcome. - assert result["date"].year == 2009 - assert result["datetime"].year == 2009 - assert result["datetime"].tzname() is None - assert result["datetime"].timetz() == dt.time(19, 19, 30, 123000) - assert result["datetime"].tzinfo is None + assert result["date"] == OUTPUT_DATE + assert result["datetime_notz"] == OUTPUT_DATETIME_TZ + assert result["datetime_notz"].tzname() is None + assert result["datetime_notz"].timetz() == OUTPUT_TIME + assert result["datetime_notz"].tzinfo is None + assert result["datetime_tz"] == OUTPUT_DATETIME_TZ + assert result["datetime_tz"].tzname() is None + assert result["datetime_tz"].timetz() == OUTPUT_TIME + assert result["datetime_tz"].tzinfo is None