Skip to content

Commit

Permalink
DateTime and more: Use CrateDB's DateTime
Browse files Browse the repository at this point in the history
... instead of `sa.DateTime` and `sa.TIMESTAMP`. Introduce
`visit_TIMESTAMP` from PGTypeCompiler to render SQL DDL clauses like
`TIMESTAMP WITH|WITHOUT TIME ZONE`.
  • Loading branch information
amotl committed Jun 24, 2024
1 parent f52cd8c commit 009d233
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 30 deletions.
12 changes: 11 additions & 1 deletion src/sqlalchemy_cratedb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def visit_SMALLINT(self, type_, **kw):
return 'SHORT'

def visit_datetime(self, type_, **kw):
return 'TIMESTAMP'
return self.visit_TIMESTAMP(type_, **kw)

def visit_date(self, type_, **kw):
return 'TIMESTAMP'
Expand All @@ -245,6 +245,16 @@ def visit_FLOAT_VECTOR(self, type_, **kw):
raise ValueError("FloatVector must be initialized with dimension size")
return f"FLOAT_VECTOR({dimensions})"

def visit_TIMESTAMP(self, type_, **kw):
"""
Support for `TIMESTAMP WITH|WITHOUT TIME ZONE`.
From `sqlalchemy.dialects.postgresql.base.PGTypeCompiler`.
"""
return "TIMESTAMP %s" % (
(type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE",
)


class CrateCompiler(compiler.SQLCompiler):

Expand Down
11 changes: 6 additions & 5 deletions src/sqlalchemy_cratedb/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
"boolean": sqltypes.Boolean,
"short": sqltypes.SmallInteger,
"smallint": sqltypes.SmallInteger,
"timestamp": sqltypes.TIMESTAMP,
"timestamp with time zone": sqltypes.TIMESTAMP,
"timestamp": sqltypes.TIMESTAMP(timezone=False),
"timestamp with time zone": sqltypes.TIMESTAMP(timezone=True),
"object": ObjectType,
"integer": sqltypes.Integer,
"long": sqltypes.NUMERIC,
Expand All @@ -61,8 +61,8 @@
TYPES_MAP["boolean_array"] = ARRAY(sqltypes.Boolean)
TYPES_MAP["short_array"] = ARRAY(sqltypes.SmallInteger)
TYPES_MAP["smallint_array"] = ARRAY(sqltypes.SmallInteger)
TYPES_MAP["timestamp_array"] = ARRAY(sqltypes.TIMESTAMP)
TYPES_MAP["timestamp with time zone_array"] = ARRAY(sqltypes.TIMESTAMP)
TYPES_MAP["timestamp_array"] = ARRAY(sqltypes.TIMESTAMP(timezone=False))
TYPES_MAP["timestamp with time zone_array"] = ARRAY(sqltypes.TIMESTAMP(timezone=True))
TYPES_MAP["long_array"] = ARRAY(sqltypes.NUMERIC)
TYPES_MAP["bigint_array"] = ARRAY(sqltypes.NUMERIC)
TYPES_MAP["double_array"] = ARRAY(sqltypes.DECIMAL)
Expand Down Expand Up @@ -147,8 +147,9 @@ def process(value):


colspecs = {
sqltypes.Date: Date,
sqltypes.DateTime: DateTime,
sqltypes.Date: Date
sqltypes.TIMESTAMP: DateTime,
}


Expand Down
10 changes: 6 additions & 4 deletions tests/create_table_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ class User(self.Base):
'\n\tlong_col1 LONG, \n\tlong_col2 LONG, '
'\n\tbool_col BOOLEAN, '
'\n\tshort_col SHORT, '
'\n\tdatetime_col TIMESTAMP, \n\tdate_col TIMESTAMP, '
'\n\tfloat_col FLOAT, \n\tdouble_col DOUBLE, '
'\n\tdatetime_col TIMESTAMP WITHOUT TIME ZONE, '
'\n\tdate_col TIMESTAMP, '
'\n\tfloat_col FLOAT, '
'\n\tdouble_col DOUBLE, '
'\n\tPRIMARY KEY (string_col)\n)\n\n'),
())

Expand Down Expand Up @@ -271,7 +273,7 @@ class DummyTable(self.Base):
fake_cursor.execute.assert_called_with(
('\nCREATE TABLE t (\n\t'
'pk STRING NOT NULL, \n\t'
'a TIMESTAMP DEFAULT now(), \n\t'
'a TIMESTAMP WITHOUT TIME ZONE DEFAULT now(), \n\t'
'PRIMARY KEY (pk)\n)\n\n'), ())

def test_column_server_default_string(self):
Expand All @@ -297,7 +299,7 @@ class DummyTable(self.Base):
fake_cursor.execute.assert_called_with(
('\nCREATE TABLE t (\n\t'
'pk STRING NOT NULL, \n\t'
'a TIMESTAMP DEFAULT now(), \n\t'
'a TIMESTAMP WITHOUT TIME ZONE DEFAULT now(), \n\t'
'PRIMARY KEY (pk)\n)\n\n'), ())

def test_column_server_default_text_constant(self):
Expand Down
63 changes: 43 additions & 20 deletions tests/datetime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from __future__ import absolute_import

from datetime import datetime, tzinfo, timedelta
from datetime import tzinfo, timedelta
import datetime as dt
from unittest import TestCase, skipIf
from unittest.mock import patch, MagicMock
Expand All @@ -31,6 +31,7 @@
from sqlalchemy.orm import Session, sessionmaker

from sqlalchemy_cratedb import SA_VERSION, SA_1_4
from sqlalchemy_cratedb.dialect import DateTime

try:
from sqlalchemy.orm import declarative_base
Expand All @@ -57,6 +58,15 @@ def dst(self, date_time):
return timedelta(seconds=-7200)


INPUT_DATE = dt.date(2009, 5, 13)
INPUT_DATETIME_NOTZ = dt.datetime(2009, 5, 13, 19, 19, 30, 123456)
INPUT_DATETIME_TZ = dt.datetime(2009, 5, 13, 19, 19, 30, 123456, tzinfo=CST())
OUTPUT_DATE = INPUT_DATE
OUTPUT_TIME = dt.time(19, 19, 30, 123000)
OUTPUT_DATETIME_NOTZ = dt.datetime(2009, 5, 13, 19, 19, 30, 123000)
OUTPUT_DATETIME_TZ = dt.datetime(2009, 5, 13, 19, 19, 30, 123000)


@skipIf(SA_VERSION < SA_1_4, "SQLAlchemy 1.3 suddenly has problems with these test cases")
@patch('crate.client.connection.Cursor', FakeCursor)
class SqlAlchemyDateAndDateTimeTest(TestCase):
Expand All @@ -69,7 +79,7 @@ class Character(Base):
__tablename__ = 'characters'
name = sa.Column(sa.String, primary_key=True)
date = sa.Column(sa.Date)
timestamp = sa.Column(sa.DateTime)
datetime = sa.Column(sa.DateTime)

fake_cursor.description = (
('characters_name', None, None, None, None, None, None),
Expand All @@ -91,7 +101,7 @@ def test_date_can_handle_datetime(self):
def test_date_can_handle_tz_aware_datetime(self):
character = self.Character()
character.name = "Athur"
character.timestamp = datetime(2009, 5, 13, 19, 19, 30, tzinfo=CST())
character.datetime = INPUT_DATETIME_NOTZ
self.session.add(character)


Expand All @@ -102,7 +112,8 @@ class FooBar(Base):
__tablename__ = "foobar"
name = sa.Column(sa.String, primary_key=True)
date = sa.Column(sa.Date)
datetime = sa.Column(sa.DateTime)
datetime_notz = sa.Column(DateTime(timezone=False))
datetime_tz = sa.Column(DateTime(timezone=True))


@pytest.fixture
Expand All @@ -124,22 +135,28 @@ def test_datetime_notz(session):
# Insert record.
foo_item = FooBar(
name="foo",
date=dt.date(2009, 5, 13),
datetime=dt.datetime(2009, 5, 13, 19, 19, 30, 123456),
date=INPUT_DATE,
datetime_notz=INPUT_DATETIME_NOTZ,
datetime_tz=INPUT_DATETIME_NOTZ,
)
session.add(foo_item)
session.commit()
session.execute(sa.text("REFRESH TABLE foobar"))

# Query record.
result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime)).mappings().first()
result = session.execute(sa.select(
FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz)).mappings().first()

# Compare outcome.
assert result["date"].year == 2009
assert result["datetime"].year == 2009
assert result["datetime"].tzname() is None
assert result["datetime"].timetz() == dt.time(19, 19, 30, 123000)
assert result["datetime"].tzinfo is None
assert result["date"] == OUTPUT_DATE
assert result["datetime_notz"] == OUTPUT_DATETIME_NOTZ
assert result["datetime_notz"].tzname() is None
assert result["datetime_notz"].timetz() == OUTPUT_TIME
assert result["datetime_notz"].tzinfo is None
assert result["datetime_tz"] == OUTPUT_DATETIME_NOTZ
assert result["datetime_tz"].tzname() is None
assert result["datetime_tz"].timetz() == OUTPUT_TIME
assert result["datetime_tz"].tzinfo is None


@pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Test case not supported on SQLAlchemy 1.3")
Expand All @@ -151,19 +168,25 @@ def test_datetime_tz(session):
# Insert record.
foo_item = FooBar(
name="foo",
date=dt.date(2009, 5, 13),
datetime=dt.datetime(2009, 5, 13, 19, 19, 30, 123456, tzinfo=CST()),
date=INPUT_DATE,
datetime_notz=INPUT_DATETIME_TZ,
datetime_tz=INPUT_DATETIME_TZ,
)
session.add(foo_item)
session.commit()
session.execute(sa.text("REFRESH TABLE foobar"))

# Query record.
result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime)).mappings().first()
result = session.execute(sa.select(
FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz)).mappings().first()

# Compare outcome.
assert result["date"].year == 2009
assert result["datetime"].year == 2009
assert result["datetime"].tzname() is None
assert result["datetime"].timetz() == dt.time(19, 19, 30, 123000)
assert result["datetime"].tzinfo is None
assert result["date"] == OUTPUT_DATE
assert result["datetime_notz"] == OUTPUT_DATETIME_TZ
assert result["datetime_notz"].tzname() is None
assert result["datetime_notz"].timetz() == OUTPUT_TIME
assert result["datetime_notz"].tzinfo is None
assert result["datetime_tz"] == OUTPUT_DATETIME_TZ
assert result["datetime_tz"].tzname() is None
assert result["datetime_tz"].timetz() == OUTPUT_TIME
assert result["datetime_tz"].tzinfo is None

0 comments on commit 009d233

Please sign in to comment.