diff --git a/duckdb_engine/__init__.py b/duckdb_engine/__init__.py index 1b7204c7..1dfba79c 100644 --- a/duckdb_engine/__init__.py +++ b/duckdb_engine/__init__.py @@ -28,6 +28,7 @@ ) from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.engine.interfaces import Dialect as RootDialect from sqlalchemy.engine.reflection import cache from sqlalchemy.engine.url import URL from sqlalchemy.exc import NoSuchTableError @@ -47,7 +48,7 @@ if TYPE_CHECKING: from sqlalchemy.base import Connection from sqlalchemy.engine.interfaces import _IndexDict - + from sqlalchemy.sql.type_api import _ResultProcessor register_extension_types() @@ -215,6 +216,16 @@ def quote_schema(self, schema: str, force: Any = None) -> str: return self.format_schema(schema) +class DuckDBNullType(sqltypes.NullType): + def result_processor( + self, dialect: RootDialect, coltype: sqltypes.TypeEngine + ) -> Optional["_ResultProcessor"]: + if coltype == "JSON": + return sqltypes.JSON().result_processor(dialect, coltype) + else: + return super().result_processor(dialect, coltype) + + class Dialect(PGDialect_psycopg2): name = "duckdb" driver = "duckdb_engine" @@ -247,6 +258,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["use_native_hstore"] = False super().__init__(*args, **kwargs) + def type_descriptor(self, typeobj: Type[sqltypes.TypeEngine]) -> Any: # type: ignore[override] + res = super().type_descriptor(typeobj) + + if isinstance(res, sqltypes.NullType): + return DuckDBNullType() + + return res + def connect(self, *cargs: Any, **cparams: Any) -> "Connection": core_keys = get_core_config() preload_extensions = cparams.pop("preload_extensions", []) diff --git a/duckdb_engine/datatypes.py b/duckdb_engine/datatypes.py index b2f7cc1e..bec1e3c3 100644 --- a/duckdb_engine/datatypes.py +++ b/duckdb_engine/datatypes.py @@ -188,6 +188,7 @@ def __init__(self, fields: Dict[str, TV]): "timestamp_ms": sqltypes.TIMESTAMP, "timestamp_ns": sqltypes.TIMESTAMP, "enum": sqltypes.Enum, + "json": sqltypes.JSON, } diff --git a/duckdb_engine/tests/test_datatypes.py b/duckdb_engine/tests/test_datatypes.py index 2f9eebc7..71c27987 100644 --- a/duckdb_engine/tests/test_datatypes.py +++ b/duckdb_engine/tests/test_datatypes.py @@ -1,12 +1,24 @@ +import decimal +import json import warnings -from typing import Type +from typing import Any, Dict, Type from uuid import uuid4 import duckdb from pytest import importorskip, mark -from sqlalchemy import Column, Integer, MetaData, String, Table, inspect, text +from sqlalchemy import ( + Column, + Integer, + MetaData, + Sequence, + String, + Table, + inspect, + select, + text, +) from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Engine, create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session from sqlalchemy.sql import sqltypes @@ -45,6 +57,58 @@ def test_unsigned_integer_type( assert session.query(table).one() +@mark.remote_data() +def test_raw_json(engine: Engine) -> None: + importorskip("duckdb", "0.9.3.dev4040") + + with engine.connect() as conn: + assert conn.execute(text("load json")) + + assert conn.execute(text("select {'Hello': 'world'}::JSON")).fetchone() == ( + {"Hello": "world"}, + ) + + +@mark.remote_data() +def test_custom_json_serializer() -> None: + def default(o: Any) -> Any: + if isinstance(o, decimal.Decimal): + return {"__tag": "decimal", "value": str(o)} + + def object_hook(pairs: Dict[str, Any]) -> Any: + if pairs.get("__tag", None) == "decimal": + return decimal.Decimal(pairs["value"]) + else: + return pairs + + engine = create_engine( + "duckdb://", + json_serializer=json.JSONEncoder(default=default).encode, + json_deserializer=json.JSONDecoder(object_hook=object_hook).decode, + ) + + Base = declarative_base() + + class Entry(Base): + __tablename__ = "test_json" + id = Column(Integer, Sequence("id_seq"), primary_key=True) + data = Column(JSON, nullable=False) + + Base.metadata.create_all(engine) + + with engine.connect() as conn: + session = Session(bind=conn) + + data = {"hello": decimal.Decimal("42")} + + session.add(Entry(data=data)) # type: ignore[call-arg] + session.commit() + + (res,) = session.execute(select(Entry)).one() + + assert res.data == data + + def test_json(engine: Engine, session: Session) -> None: base = declarative_base()