refactor: Override superclass to_jsonschema in PostgresSQLToJSONSchema #546

Merged
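The core of the change: rather than registering the Postgres-specific handlers on the superclass dispatcher via SQLToJSONSchema.to_jsonschema.register (which needs `# type: ignore[attr-defined]` comments and writes into the superclass's shared singledispatch registry), the subclass now declares its own functools.singledispatchmethod named to_jsonschema that falls back to super(). A minimal, self-contained sketch of that pattern follows; the names SQLToJSONSchemaBase, PostgresConverterSketch, and DATETIME are simplified stand-ins, not the real singer_sdk or SQLAlchemy classes:

import functools


class DATETIME:
    """Stand-in for a SQLAlchemy DateTime type (illustration only)."""


class SQLToJSONSchemaBase:
    """Plays the role of singer_sdk's SQLToJSONSchema in this sketch."""

    @functools.singledispatchmethod
    def to_jsonschema(self, column_type: object) -> dict:
        # Generic fallback for any type without a more specific handler.
        return {"type": ["string"]}


class PostgresConverterSketch(SQLToJSONSchemaBase):
    # Declare a fresh dispatcher in the subclass so the .register calls below
    # stay local to this class instead of mutating the superclass's registry.
    @functools.singledispatchmethod
    def to_jsonschema(self, column_type: object) -> dict:
        # Anything this class does not handle is delegated to the superclass.
        return super().to_jsonschema(column_type)

    @to_jsonschema.register
    def _datetime(self, column_type: DATETIME) -> dict:
        return {"type": ["string", "null"], "format": "date-time"}


if __name__ == "__main__":
    conv = PostgresConverterSketch()
    print(conv.to_jsonschema(DATETIME()))  # handled by the subclass handler
    print(conv.to_jsonschema(42))          # falls back to the base class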
31 changes: 15 additions & 16 deletions tap_postgres/client.py
@@ -10,9 +10,7 @@
import json
import select
import typing as t
from functools import cached_property
from types import MappingProxyType
from typing import TYPE_CHECKING, Any

import psycopg2
import singer_sdk.helpers._typing
@@ -26,11 +24,10 @@
from singer_sdk.streams.core import REPLICATION_INCREMENTAL
from sqlalchemy.dialects import postgresql

if TYPE_CHECKING:
if t.TYPE_CHECKING:
from collections.abc import Iterable, Mapping

from singer_sdk.helpers.types import Context
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import Engine
from sqlalchemy.engine.reflection import Inspector

@@ -44,7 +41,12 @@ def __init__(self, dates_as_string: bool, json_as_object: bool, *args, **kwargs)
self.dates_as_string = dates_as_string
self.json_as_object = json_as_object

@SQLToJSONSchema.to_jsonschema.register # type: ignore[attr-defined]
@functools.singledispatchmethod
def to_jsonschema(self, column_type: t.Any) -> dict:
"""Customize the JSON Schema for Postgres types."""
return super().to_jsonschema(column_type)

@to_jsonschema.register
def array_to_jsonschema(self, column_type: postgresql.ARRAY) -> dict:
"""Override the default mapping for NUMERIC columns.

@@ -55,32 +57,29 @@ def array_to_jsonschema(self, column_type: postgresql.ARRAY) -> dict:
"items": self.to_jsonschema(column_type.item_type),
}

@SQLToJSONSchema.to_jsonschema.register # type: ignore[attr-defined]
@to_jsonschema.register
def json_to_jsonschema(self, column_type: postgresql.JSON) -> dict:
"""Override the default mapping for JSON and JSONB columns."""
if self.json_as_object:
return {"type": ["object", "null"]}
return {"type": ["string", "number", "integer", "array", "object", "boolean"]}

@SQLToJSONSchema.to_jsonschema.register # type: ignore[attr-defined]
@to_jsonschema.register
def datetime_to_jsonschema(self, column_type: sqlalchemy.types.DateTime) -> dict:
"""Override the default mapping for DATETIME columns."""
if self.dates_as_string:
return {"type": ["string", "null"]}
return super().datetime_to_jsonschema(column_type)

@SQLToJSONSchema.to_jsonschema.register # type: ignore[attr-defined]
@to_jsonschema.register
def date_to_jsonschema(self, column_type: sqlalchemy.types.Date) -> dict:
"""Override the default mapping for DATE columns."""
if self.dates_as_string:
return {"type": ["string", "null"]}
return super().date_to_jsonschema(column_type)


def patched_conform(
elem: Any,
property_schema: dict,
) -> Any:
def patched_conform(elem: t.Any, property_schema: dict) -> t.Any:
"""Overrides Singer SDK type conformance.

Most logic here is from singer_sdk.helpers._typing._conform_primitive_property, as
@@ -272,11 +271,11 @@ class PostgresLogBasedStream(SQLStream):
replication_key = "_sdc_lsn"

@property
def config(self) -> Mapping[str, Any]:
def config(self) -> Mapping[str, t.Any]:
"""Return a read-only config dictionary."""
return MappingProxyType(self._config)

@cached_property
@functools.cached_property
def schema(self) -> dict:
"""Override schema for log-based replication adding _sdc columns."""
schema_dict = t.cast(dict, self._singer_catalog_entry.schema.to_dict())
@@ -293,7 +292,7 @@

def _increment_stream_state(
self,
latest_record: dict[str, Any],
latest_record: dict[str, t.Any],
*,
context: Context | None = None,
) -> None:
@@ -326,7 +325,7 @@
check_sorted=self.check_sorted,
)

def get_records(self, context: Context | None) -> Iterable[dict[str, Any]]:
def get_records(self, context: Context | None) -> Iterable[dict[str, t.Any]]:
"""Return a generator of row-type dictionary objects."""
status_interval = 5.0 # if no records in 5 seconds the tap can exit
start_lsn = self.get_starting_replication_key_value(context=context)
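For context, a rough usage sketch of the converter defined in this diff; it assumes the superclass constructor needs no extra arguments, and the exact schemas for branches the PR does not touch still come from the singer_sdk base class:

import sqlalchemy
from sqlalchemy.dialects import postgresql

converter = PostgresSQLToJSONSchema(dates_as_string=True, json_as_object=False)

# ARRAY dispatch recurses into the element type, so ARRAY(Date) is described
# as an array whose items use the date handler; with dates_as_string=True that
# handler returns {"type": ["string", "null"]}.
array_schema = converter.to_jsonschema(postgresql.ARRAY(sqlalchemy.types.Date()))

# JSONB subclasses JSON in SQLAlchemy, so json_to_jsonschema covers both; with
# json_as_object=False it allows any JSON value type.
jsonb_schema = converter.to_jsonschema(postgresql.JSONB())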