diff --git a/target_postgres/sinks.py b/target_postgres/sinks.py
index f2f9e420..c5f65444 100644
--- a/target_postgres/sinks.py
+++ b/target_postgres/sinks.py
@@ -5,6 +5,7 @@
 import datetime
 import typing as t
 import uuid
+from io import StringIO
 
 import sqlalchemy as sa
 from singer_sdk.sinks import SQLSink
@@ -14,7 +15,6 @@
 
 if t.TYPE_CHECKING:
     from singer_sdk.connectors.sql import FullyQualifiedName
-    from sqlalchemy.sql import Executable
 
 
 class PostgresSink(SQLSink):
@@ -145,35 +145,90 @@ def bulk_insert_records(  # type: ignore[override]
             True if table exists, False if not, None if unsure or undetectable.
         """
         columns = self.column_representation(schema)
-        insert: str = t.cast(
-            str,
-            self.generate_insert_statement(
-                table.name,
-                columns,
-            ),
-        )
-        self.logger.info("Inserting with SQL: %s", insert)
+        copy_statement: str = self.generate_copy_statement(table.name, columns)
+        self.logger.info("Inserting with SQL: %s", copy_statement)
         # Only one record per PK, we want to take the last one
-        data_to_insert: list[dict[str, t.Any]] = []
+        data_to_insert: tuple[tuple[t.Any, ...], ...]
 
         if self.append_only is False:
-            insert_records: dict[tuple, dict] = {}  # pk tuple: record
+            copy_values: dict[tuple, tuple] = {}  # pk tuple: values
             for record in records:
-                insert_record = {
-                    column.name: record.get(column.name) for column in columns
-                }
+                values = tuple(record.get(column.name) for column in columns)
                 # No need to check for a KeyError here because the SDK already
                 # guarantees that all key properties exist in the record.
                 primary_key_tuple = tuple(record[key] for key in primary_keys)
-                insert_records[primary_key_tuple] = insert_record
-            data_to_insert = list(insert_records.values())
+                copy_values[primary_key_tuple] = values
+            data_to_insert = tuple(copy_values.values())
         else:
-            for record in records:
-                insert_record = {
-                    column.name: record.get(column.name) for column in columns
-                }
-                data_to_insert.append(insert_record)
-        connection.execute(insert, data_to_insert)
+            data_to_insert = tuple(
+                tuple(record.get(column.name) for column in columns)
+                for record in records
+            )
+
+        # Prepare to process the rows into CSV. Use each column's bind_processor to
+        # do most of the work, then do the final construction of the CSV rows
+        # ourselves to control exactly how values are converted and which are quoted.
+        column_processors = [
+            column.type.bind_processor(connection.dialect) or str for column in columns
+        ]
+
+        # Make translation tables for escaping string and array values.
+        str_translate_table = str.maketrans(
+            {
+                '"': '""',
+                "\\": "\\\\",
+            }
+        )
+        array_translate_table = str.maketrans(
+            {
+                '"': '\\""',
+                "\\": "\\\\",
+            }
+        )
+
+        def process_column_value(data: t.Any, proc: t.Callable) -> str:
+            # If the data is null, return an unquoted, empty value.
+            # Unquoted is important here, for PostgreSQL to interpret it as null.
+            if data is None:
+                return ""
+
+            # Pass the Python value through the bind_processor.
+            value = proc(data)
+
+            # If the value is a string, escape double quotes as "" and return
+            # a quoted value.
+            if isinstance(value, str):
+                # Escape double quotes as "".
+                return '"' + value.translate(str_translate_table) + '"'
+
+            # If the value is a list (for ARRAY), escape double quotes as \" and
+            # return a quoted value in literal array format.
+            if isinstance(value, list):
+                # For each member of value, escape double quotes as \".
+                return (
+                    '"{'
+                    + ",".join(
+                        '""' + v.translate(array_translate_table) + '""' for v in value
+                    )
+                    + '}"'
+                )
+
+            # Otherwise, return the string representation of the value.
+            return str(value)
+
+        buffer = StringIO()
+        for row in data_to_insert:
+            processed_row = ",".join(map(process_column_value, row, column_processors))
+
+            buffer.write(processed_row)
+            buffer.write("\n")
+        buffer.seek(0)
+
+        # Use copy_expert to run the copy statement.
+        # https://www.psycopg.org/docs/cursor.html#cursor.copy_expert
+        with connection.connection.cursor() as cur:  # type: ignore[attr-defined]
+            cur.copy_expert(sql=copy_statement, file=buffer)
+
         return True
 
     def upsert(
@@ -261,23 +316,24 @@ def column_representation(
         ]
         return columns
 
-    def generate_insert_statement(
+    def generate_copy_statement(
         self,
         full_table_name: str | FullyQualifiedName,
         columns: list[sa.Column],  # type: ignore[override]
-    ) -> str | Executable:
-        """Generate an insert statement for the given records.
+    ) -> str:
+        """Generate a copy statement for bulk copy.
 
         Args:
             full_table_name: the target table name.
             columns: the target table columns.
 
         Returns:
-            An insert statement.
+            A copy statement.
         """
-        metadata = sa.MetaData()
-        table = sa.Table(full_table_name, metadata, *columns)
-        return sa.insert(table)
+        columns_list = ", ".join(f'"{column.name}"' for column in columns)
+        sql: str = f'copy "{full_table_name}" ({columns_list}) from stdin with csv'
+
+        return sql
 
     def conform_name(self, name: str, object_type: str | None = None) -> str:
         """Conforming names of tables, schemas, column names."""
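As a companion to the diff, here is a minimal standalone sketch of the two pieces it adds: the statement shape built by generate_copy_statement and the CSV escaping rules from process_column_value. The "users" table and its id/name/tags columns are hypothetical, the bind_processor step is omitted, and the sketch runs without psycopg2 or SQLAlchemy:

    # Sketch of the COPY path above; "users" and its columns are hypothetical.
    from io import StringIO

    columns = ["id", "name", "tags"]

    # Same shape as generate_copy_statement("users", columns):
    columns_list = ", ".join(f'"{c}"' for c in columns)
    copy_statement = f'copy "users" ({columns_list}) from stdin with csv'
    print(copy_statement)  # copy "users" ("id", "name", "tags") from stdin with csv

    # Same escaping rules as process_column_value, minus the bind_processor step.
    str_translate_table = str.maketrans({'"': '""', "\\": "\\\\"})
    array_translate_table = str.maketrans({'"': '\\""', "\\": "\\\\"})

    def to_csv_value(value):
        if value is None:
            return ""  # unquoted empty field reads back as NULL under "with csv"
        if isinstance(value, str):
            return '"' + value.translate(str_translate_table) + '"'
        if isinstance(value, list):  # rendered as a PostgreSQL array literal
            return (
                '"{'
                + ",".join('""' + v.translate(array_translate_table) + '""' for v in value)
                + '}"'
            )
        return str(value)

    buffer = StringIO()
    for row in [(1, 'say "hi"', ["a", "b"]), (2, None, [])]:
        buffer.write(",".join(map(to_csv_value, row)) + "\n")

    print(buffer.getvalue(), end="")
    # 1,"say ""hi""","{""a"",""b""}"
    # 2,,"{}"

With "from stdin with csv", PostgreSQL reads an unquoted empty field as NULL and a quoted empty string as an empty string, which is why process_column_value leaves nulls unquoted but quotes every string, even empty ones. In the real sink this buffer is then streamed to the server via psycopg2's cursor.copy_expert, as shown in the diff.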