chore: Harmonize imports of sqlalchemy module, use sa where applicable
It follows the convention of importing SQLAlchemy as
`import sqlalchemy as sa`. In this spirit, all references, even simple
ones like SQLAlchemy's base types `TEXT` or `BIGINT`, are written as
`sa.TEXT`, `sa.BIGINT`, etc., so it is easy to tell type definitions
coming from SQLAlchemy's built-in dialects apart from type definitions
coming from third-party dialects.
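
As a minimal sketch of that convention (illustrative only, not part of the
commit; the table and column names are made up), built-in types read
`sa.<TYPE>` while dialect-specific types keep their bare imported names:

import sqlalchemy as sa
from crate.client.sqlalchemy.types import ObjectType  # third-party dialect type

# Built-in SQLAlchemy types carry the `sa.` prefix; the third-party
# `ObjectType` stands out by its bare name.
table = sa.Table(
    "example",
    sa.MetaData(),
    sa.Column("id", sa.BIGINT, primary_key=True),
    sa.Column("name", sa.TEXT),
    sa.Column("payload", ObjectType),
)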
amotl committed Dec 16, 2023
1 parent e3a2c4c commit 74296f7
Showing 3 changed files with 91 additions and 108 deletions.
79 changes: 32 additions & 47 deletions target_cratedb/connector.py
@@ -5,25 +5,10 @@
 from builtins import issubclass
 from datetime import datetime

-import sqlalchemy
+import sqlalchemy as sa
 from crate.client.sqlalchemy.types import ObjectType, ObjectTypeImpl, _ObjectArray
 from singer_sdk import typing as th
 from singer_sdk.helpers._typing import is_array_type, is_boolean_type, is_integer_type, is_number_type, is_object_type
-from sqlalchemy.types import (
-    ARRAY,
-    BIGINT,
-    BOOLEAN,
-    DATE,
-    DATETIME,
-    DECIMAL,
-    FLOAT,
-    INTEGER,
-    TEXT,
-    TIME,
-    TIMESTAMP,
-    VARCHAR,
-)
 from target_postgres.connector import NOTYPE, PostgresConnector

 from target_cratedb.sqlalchemy.patch import polyfill_refresh_after_dml_engine
@@ -39,7 +24,7 @@ class CrateDBConnector(PostgresConnector):
     allow_merge_upsert: bool = False  # Whether MERGE UPSERT is supported.
     allow_temp_tables: bool = False  # Whether temp tables are supported.

-    def create_engine(self) -> sqlalchemy.Engine:
+    def create_engine(self) -> sa.Engine:
         """
         Create an SQLAlchemy engine object.
@@ -50,7 +35,7 @@ def create_engine(self) -> sqlalchemy.Engine:
         return engine

     @staticmethod
-    def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
+    def to_sql_type(jsonschema_type: dict) -> sa.types.TypeEngine:
         """Return a JSON Schema representation of the provided type.
         Note: Needs to be patched to invoke other static methods on `CrateDBConnector`.
@@ -112,7 +97,7 @@ def pick_individual_type(jsonschema_type: dict):
         if "null" in jsonschema_type["type"]:
             return None
         if "integer" in jsonschema_type["type"]:
-            return BIGINT()
+            return sa.BIGINT()
         if "object" in jsonschema_type["type"]:
             return ObjectType
         if "array" in jsonschema_type["type"]:
@@ -157,16 +142,16 @@ def pick_individual_type(jsonschema_type: dict):
             # Discover/translate inner types.
             inner_type = resolve_array_inner_type(jsonschema_type)
             if inner_type is not None:
-                return ARRAY(inner_type)
+                return sa.ARRAY(inner_type)

             # When type discovery fails, assume `TEXT`.
-            return ARRAY(TEXT())
+            return sa.ARRAY(sa.TEXT())

         if jsonschema_type.get("format") == "date-time":
-            return TIMESTAMP()
+            return sa.TIMESTAMP()
         individual_type = th.to_sql_type(jsonschema_type)
-        if isinstance(individual_type, VARCHAR):
-            return TEXT()
+        if isinstance(individual_type, sa.VARCHAR):
+            return sa.TEXT()
         return individual_type

     @staticmethod
@@ -182,18 +167,18 @@ def pick_best_sql_type(sql_type_array: list):
             An instance of the best SQL type class based on defined precedence order.
         """
         precedence_order = [
-            TEXT,
-            TIMESTAMP,
-            DATETIME,
-            DATE,
-            TIME,
-            DECIMAL,
-            FLOAT,
-            BIGINT,
-            INTEGER,
-            BOOLEAN,
+            sa.TEXT,
+            sa.TIMESTAMP,
+            sa.DATETIME,
+            sa.DATE,
+            sa.TIME,
+            sa.DECIMAL,
+            sa.FLOAT,
+            sa.BIGINT,
+            sa.INTEGER,
+            sa.BOOLEAN,
             NOTYPE,
-            ARRAY,
+            sa.ARRAY,
             FloatVector,
             ObjectTypeImpl,
         ]
@@ -202,12 +187,12 @@ def pick_best_sql_type(sql_type_array: list):
             for obj in sql_type_array:
                 if isinstance(obj, sql_type):
                     return obj
-        return TEXT()
+        return sa.TEXT()

     def _sort_types(
         self,
-        sql_types: t.Iterable[sqlalchemy.types.TypeEngine],
-    ) -> list[sqlalchemy.types.TypeEngine]:
+        sql_types: t.Iterable[sa.types.TypeEngine],
+    ) -> list[sa.types.TypeEngine]:
         """Return the input types sorted from most to least compatible.
         Note: Needs to be patched to supply handlers for `_ObjectArray` and `NOTYPE`.
@@ -227,7 +212,7 @@ def _sort_types(
         """

         def _get_type_sort_key(
-            sql_type: sqlalchemy.types.TypeEngine,
+            sql_type: sa.types.TypeEngine,
         ) -> tuple[int, int]:
             # return rank, with higher numbers ranking first

@@ -257,10 +242,10 @@ def _get_type_sort_key(
     def copy_table_structure(
         self,
         full_table_name: str,
-        from_table: sqlalchemy.Table,
-        connection: sqlalchemy.engine.Connection,
+        from_table: sa.Table,
+        connection: sa.engine.Connection,
         as_temp_table: bool = False,
-    ) -> sqlalchemy.Table:
+    ) -> sa.Table:
         """Copy table structure.
         Note: Needs to be patched to prevent `Primary key columns cannot be nullable` errors.
@@ -275,17 +260,17 @@ def copy_table_structure(
             The new table object.
         """
         _, schema_name, table_name = self.parse_full_table_name(full_table_name)
-        meta = sqlalchemy.MetaData(schema=schema_name)
+        meta = sa.MetaData(schema=schema_name)
         columns = []
         if self.table_exists(full_table_name=full_table_name):
             raise RuntimeError("Table already exists")
-        column: sqlalchemy.Column
+        column: sa.Column
         for column in from_table.columns:
             # CrateDB: Prevent `Primary key columns cannot be nullable` errors.
             if column.primary_key and column.nullable:
                 column.nullable = False
             columns.append(column._copy())
-        new_table = sqlalchemy.Table(table_name, meta, *columns)
+        new_table = sa.Table(table_name, meta, *columns)
         new_table.create(bind=connection)
         return new_table

@@ -299,11 +284,11 @@ def prepare_schema(self, schema_name: str) -> None:
 def resolve_array_inner_type(jsonschema_type: dict) -> t.Union[sa.types.TypeEngine, None]:
     if "items" in jsonschema_type:
         if is_boolean_type(jsonschema_type["items"]):
-            return BOOLEAN()
+            return sa.BOOLEAN()
         if is_number_type(jsonschema_type["items"]):
-            return FLOAT()
+            return sa.FLOAT()
         if is_integer_type(jsonschema_type["items"]):
-            return BIGINT()
+            return sa.BIGINT()
         if is_object_type(jsonschema_type["items"]):
             return ObjectType()
         if is_array_type(jsonschema_type["items"]):
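
For context on the precedence list in `pick_best_sql_type` above: it is
scanned top-down, and the first class that matches an instance in the
candidate list wins. A hypothetical usage sketch (not part of the commit,
assuming `CrateDBConnector` is importable from this package):

import sqlalchemy as sa

from target_cratedb.connector import CrateDBConnector

# `sa.TEXT` precedes `sa.BIGINT` in the precedence order, so the
# TEXT instance is picked from the candidates.
best = CrateDBConnector.pick_best_sql_type([sa.BIGINT(), sa.TEXT()])
print(type(best))  # <class 'sqlalchemy.sql.sqltypes.TEXT'>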
55 changes: 27 additions & 28 deletions target_cratedb/sinks.py
@@ -3,9 +3,8 @@
 import time
 from typing import List, Optional, Union

-import sqlalchemy
+import sqlalchemy as sa
 from pendulum import now
-from sqlalchemy import Column, Executable, MetaData, Table, bindparam, insert, select, update
 from target_postgres.sinks import PostgresSink

 from target_cratedb.connector import CrateDBConnector
@@ -116,7 +115,7 @@ def process_batch(self, context: dict) -> None:
         # Use one connection so we do this all in a single transaction
         with self.connector._connect() as connection, connection.begin():
             # Check structure of table
-            table: sqlalchemy.Table = self.connector.prepare_table(
+            table: sa.Table = self.connector.prepare_table(
                 full_table_name=self.full_table_name,
                 schema=self.schema,
                 primary_keys=self.key_properties,
@@ -134,7 +133,7 @@ def process_batch(self, context: dict) -> None:
             # FIXME: Upserts do not work yet.
             """
             # Create a temp table (Creates from the table above)
-            temp_table: sqlalchemy.Table = self.connector.copy_table_structure(
+            temp_table: sa.Table = self.connector.copy_table_structure(
                 full_table_name=self.temp_table_name,
                 from_table=table,
                 as_temp_table=True,
@@ -162,11 +161,11 @@ def process_batch(self, context: dict) -> None:

     def upsertX(
         self,
-        from_table: sqlalchemy.Table,
-        to_table: sqlalchemy.Table,
+        from_table: sa.Table,
+        to_table: sa.Table,
         schema: dict,
-        join_keys: List[Column],
-        connection: sqlalchemy.engine.Connection,
+        join_keys: List[sa.Column],
+        connection: sa.engine.Connection,
     ) -> Optional[int]:
         """Merge upsert data from one table to another.
@@ -185,45 +184,45 @@ def upsertX(

         if self.append_only is True:
             # Insert
-            select_stmt = select(from_table.columns).select_from(from_table)
+            select_stmt = sa.select(from_table.columns).select_from(from_table)
             insert_stmt = to_table.insert().from_select(names=list(from_table.columns), select=select_stmt)
             connection.execute(insert_stmt)
         else:
             join_predicates = []
             for key in join_keys:
-                from_table_key: sqlalchemy.Column = from_table.columns[key]  # type: ignore[call-overload]
-                to_table_key: sqlalchemy.Column = to_table.columns[key]  # type: ignore[call-overload]
+                from_table_key: sa.Column = from_table.columns[key]  # type: ignore[call-overload]
+                to_table_key: sa.Column = to_table.columns[key]  # type: ignore[call-overload]
                 join_predicates.append(from_table_key == to_table_key)  # type: ignore[call-overload]

-            join_condition = sqlalchemy.and_(*join_predicates)
+            join_condition = sa.and_(*join_predicates)

             where_predicates = []
             for key in join_keys:
-                to_table_key: sqlalchemy.Column = to_table.columns[key]  # type: ignore[call-overload,no-redef]
+                to_table_key: sa.Column = to_table.columns[key]  # type: ignore[call-overload,no-redef]
                 where_predicates.append(to_table_key.is_(None))
-            where_condition = sqlalchemy.and_(*where_predicates)
+            where_condition = sa.and_(*where_predicates)

             select_stmt = (
-                select(from_table.columns)
+                sa.select(from_table.columns)
                 .select_from(from_table.outerjoin(to_table, join_condition))
                 .where(where_condition)
             )
-            insert_stmt = insert(to_table).from_select(names=list(from_table.columns), select=select_stmt)
+            insert_stmt = sa.insert(to_table).from_select(names=list(from_table.columns), select=select_stmt)

             connection.execute(insert_stmt)

             # Update
             where_condition = join_condition
             update_columns = {}
             for column_name in self.schema["properties"].keys():
-                from_table_column: sqlalchemy.Column = from_table.columns[column_name]
-                to_table_column: sqlalchemy.Column = to_table.columns[column_name]
+                from_table_column: sa.Column = from_table.columns[column_name]
+                to_table_column: sa.Column = to_table.columns[column_name]
                 # Prevent: `Updating a primary key is not supported`
                 if to_table_column.primary_key:
                     continue
                 update_columns[to_table_column] = from_table_column

-            update_stmt = update(to_table).where(where_condition).values(update_columns)
+            update_stmt = sa.update(to_table).where(where_condition).values(update_columns)
             connection.execute(update_stmt)

         return None
@@ -264,7 +263,7 @@ def activate_version(self, new_version: int) -> None:
         self.logger.info("Hard delete: %s", self.config.get("hard_delete"))
         if self.config["hard_delete"] is True:
             connection.execute(
-                sqlalchemy.text(
+                sa.text(
                     f'DELETE FROM "{self.schema_name}"."{self.table_name}" '  # noqa: S608
                     f'WHERE "{self.version_column_name}" <= {new_version} '
                     f'OR "{self.version_column_name}" IS NULL'
@@ -284,24 +283,24 @@ def activate_version(self, new_version: int) -> None:
                 connection=connection,
             )
         # Need to deal with the case where data doesn't exist for the version column
-        query = sqlalchemy.text(
+        query = sa.text(
             f'UPDATE "{self.schema_name}"."{self.table_name}"\n'
             f'SET "{self.soft_delete_column_name}" = :deletedate \n'
             f'WHERE "{self.version_column_name}" < :version '
             f'OR "{self.version_column_name}" IS NULL \n'
             f' AND "{self.soft_delete_column_name}" IS NULL\n'
         )
         query = query.bindparams(
-            bindparam("deletedate", value=deleted_at, type_=datetime_type),
-            bindparam("version", value=new_version, type_=integer_type),
+            sa.bindparam("deletedate", value=deleted_at, type_=datetime_type),
+            sa.bindparam("version", value=new_version, type_=integer_type),
        )
         connection.execute(query)

     def generate_insert_statement(
         self,
         full_table_name: str,
-        columns: List[Column],
-    ) -> Union[str, Executable]:
+        columns: List[sa.Column],
+    ) -> Union[str, sa.sql.Executable]:
         """Generate an insert statement for the given records.
         Args:
@@ -312,6 +311,6 @@ def generate_insert_statement(
             An insert statement.
         """
         # FIXME:
-        metadata = MetaData(schema=self.schema_name)
-        table = Table(full_table_name, metadata, *columns)
-        return insert(table)
+        metadata = sa.MetaData(schema=self.schema_name)
+        table = sa.Table(full_table_name, metadata, *columns)
+        return sa.insert(table)
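
To make the `sa.text()` / `sa.bindparam()` pattern from `activate_version`
above easier to follow in isolation, here is a self-contained sketch
(connection URL, table, and column names are illustrative assumptions, not
taken from the commit):

from datetime import datetime

import sqlalchemy as sa

# A textual UPDATE with named placeholders, bound to typed parameters,
# mirroring the soft-delete statement above.
query = sa.text(
    'UPDATE "doc"."mytable" '
    'SET "_sdc_deleted_at" = :deletedate '
    'WHERE "_sdc_table_version" < :version'
)
query = query.bindparams(
    sa.bindparam("deletedate", value=datetime(2023, 12, 16), type_=sa.TIMESTAMP()),
    sa.bindparam("version", value=2, type_=sa.BIGINT()),
)

# engine = sa.create_engine("crate://localhost:4200")
# with engine.connect() as connection:
#     connection.execute(query)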