Skip to content

Commit

Permalink
Merge branch 'refs/heads/devel' into 1322-lancedb-usage-example-docs
Browse files Browse the repository at this point in the history
# Conflicts:
#	poetry.lock
  • Loading branch information
Pipboyguy committed May 7, 2024
2 parents 4beb3f7 + 30f0416 commit c66bf9a
Show file tree
Hide file tree
Showing 28 changed files with 464 additions and 358 deletions.
5 changes: 4 additions & 1 deletion dlt/common/data_writers/escape.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,10 @@ def format_datetime_literal(v: pendulum.DateTime, precision: int = 6, no_tz: boo
def format_bigquery_datetime_literal(
    v: pendulum.DateTime, precision: int = 6, no_tz: bool = False
) -> str:
    """Returns BigQuery-adjusted datetime literal by prefixing required `TIMESTAMP` indicator.

    Also works for Presto-based engines (the Athena destination assigns this formatter
    to its capabilities as well).

    Args:
        v: the datetime value to render.
        precision: fractional-second precision passed through to `format_datetime_literal`.
        no_tz: when True, the timezone part is omitted by `format_datetime_literal`.
    """
    # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#timestamp_literals
    return "TIMESTAMP " + format_datetime_literal(v, precision, no_tz)

Expand Down
5 changes: 4 additions & 1 deletion dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,10 @@ def should_truncate_table_before_load_on_staging_destination(self, table: TTable
return True


TDestinationReferenceArg = Union[str, "Destination", Callable[..., "Destination"], None]
# TODO: type Destination properly
TDestinationReferenceArg = Union[
str, "Destination[Any, Any]", Callable[..., "Destination[Any, Any]"], None
]


class Destination(ABC, Generic[TDestinationConfig, TDestinationClient]):
Expand Down
2 changes: 1 addition & 1 deletion dlt/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _init_logging(
logger = logging.getLogger(logger_name)
logger.propagate = False
logger.setLevel(level)
# get or create logging handler
# get or create logging handler, we log to stderr by default
handler = next(iter(logger.handlers), logging.StreamHandler())
logger.addHandler(handler)

Expand Down
52 changes: 47 additions & 5 deletions dlt/common/runners/stdout.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import sys
import queue
from contextlib import contextmanager
from subprocess import PIPE, CalledProcessError
from threading import Thread
from typing import Any, Generator, Iterator, List
from typing import Any, Generator, Iterator, List, Tuple, Literal

from dlt.common.runners.venv import Venv
from dlt.common.runners.synth_pickle import decode_obj, decode_last_obj, encode_obj
from dlt.common.typing import AnyFun

# file number of stdout (1) and stderr (2)
OutputStdStreamNo = Literal[1, 2]


@contextmanager
def exec_to_stdout(f: AnyFun) -> Iterator[Any]:
Expand All @@ -24,6 +28,47 @@ def exec_to_stdout(f: AnyFun) -> Iterator[Any]:
print(encode_obj(rv), flush=True)


def iter_std(
    venv: Venv, command: str, *script_args: Any
) -> Iterator[Tuple[OutputStdStreamNo, str]]:
    """Starts a process `command` with `script_args` in environment `venv` and returns iterator
    of (fileno, line) tuples where `fileno` is 1 for stdout and 2 for stderr. `line` is
    the content of a line with stripped new line character.

    Use -u in script_args for unbuffered python execution.

    Raises:
        CalledProcessError: when the process exits with a non-zero code (output/stderr
            fields are empty because all output was already yielded).
    """
    with venv.start_command(
        command, *script_args, stdout=PIPE, stderr=PIPE, bufsize=1, text=True
    ) as process:
        # queue also receives a None sentinel from each reader thread when its stream closes
        q_: queue.Queue[Tuple[OutputStdStreamNo, str]] = queue.Queue()

        def _r_q(std_: OutputStdStreamNo) -> None:
            # pump one std stream into the shared queue; None marks end of that stream
            stream_ = process.stderr if std_ == 2 else process.stdout
            for line in iter(stream_.readline, ""):
                q_.put((std_, line.rstrip("\n")))
            # close queue
            q_.put(None)

        # read both streams with threads, selectors do not work on windows
        t_out = Thread(target=_r_q, args=(1,), daemon=True)
        t_out.start()
        t_err = Thread(target=_r_q, args=(2,), daemon=True)
        t_err.start()

        # consume until BOTH streams signalled end-of-stream; stopping at the first
        # sentinel would drop output queued by the slower stream afterwards
        closed_streams = 0
        while closed_streams < 2:
            item = q_.get()
            if item is None:
                closed_streams += 1
                continue
            yield item

        # get exit code and make sure reader threads terminated
        exit_code = process.wait()
        t_out.join()
        t_err.join()

        # we fail iterator if exit code is not 0
        if exit_code != 0:
            raise CalledProcessError(exit_code, command, output="", stderr="")


def iter_stdout(venv: Venv, command: str, *script_args: Any) -> Iterator[str]:
# start a process in virtual environment, assume that text comes from stdout
with venv.start_command(
Expand All @@ -44,10 +89,7 @@ def _r_stderr() -> None:

# read stdout with
for line in iter(process.stdout.readline, ""):
if line.endswith("\n"):
yield line[:-1]
else:
yield line
yield line.rstrip("\n")

# get exit code
exit_code = process.wait()
Expand Down
4 changes: 3 additions & 1 deletion dlt/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@
from typing import _TypedDict

REPattern = _REPattern[str]
PathLike = os.PathLike[str]
else:
StrOrBytesPath = Any
from typing import _TypedDictMeta as _TypedDict

REPattern = _REPattern
PathLike = os.PathLike

AnyType: TypeAlias = Any
NoneType = type(None)
Expand Down Expand Up @@ -92,7 +94,7 @@
TVariantBase = TypeVar("TVariantBase", covariant=True)
TVariantRV = Tuple[str, Any]
VARIANT_FIELD_FORMAT = "v_%s"
TFileOrPath = Union[str, os.PathLike, IO[Any]]
TFileOrPath = Union[str, PathLike, IO[Any]]
TSortOrder = Literal["asc", "desc"]


Expand Down
6 changes: 5 additions & 1 deletion dlt/destinations/impl/athena/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.data_writers.escape import escape_athena_identifier
from dlt.common.data_writers.escape import (
escape_athena_identifier,
format_bigquery_datetime_literal,
)
from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE


Expand All @@ -11,6 +14,7 @@ def capabilities() -> DestinationCapabilitiesContext:
caps.preferred_staging_file_format = "parquet"
caps.supported_staging_file_formats = ["parquet", "jsonl"]
caps.escape_identifier = escape_athena_identifier
caps.format_datetime_literal = format_bigquery_datetime_literal
caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE)
caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0)
caps.max_identifier_length = 255
Expand Down
71 changes: 67 additions & 4 deletions dlt/destinations/impl/athena/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,20 @@
from dlt.common.utils import without_none
from dlt.common.data_types import TDataType
from dlt.common.schema import TColumnSchema, Schema, TSchemaTables, TTableSchema
from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition, TTableFormat
from dlt.common.schema.typing import (
TTableSchema,
TColumnType,
TWriteDisposition,
TTableFormat,
TSortOrder,
)
from dlt.common.schema.utils import table_schema_has_type, get_table_format
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import LoadJob, DoNothingFollowupJob, DoNothingJob
from dlt.common.destination.reference import TLoadJobState, NewLoadJob, SupportsStagingDestination
from dlt.common.storages import FileStorage
from dlt.common.data_writers.escape import escape_bigquery_identifier
from dlt.destinations.sql_jobs import SqlStagingCopyJob
from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob

from dlt.destinations.typing import DBApi, DBTransaction
from dlt.destinations.exceptions import (
Expand Down Expand Up @@ -154,6 +160,64 @@ def __init__(self) -> None:
DLTAthenaFormatter._INSTANCE = self


class AthenaMergeJob(SqlMergeJob):
    """Merge job specialized for Athena.

    Athena has no temporary tables and its DDL requires backtick-escaped identifiers,
    so temp-table creation/cleanup from the base `SqlMergeJob` is adjusted here.
    """

    @classmethod
    def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) -> str:
        """Return a deterministic, fully qualified table name in the staging dataset."""
        # reproducible name so we know which table to drop
        with sql_client.with_staging_dataset(staging=True):
            return sql_client.make_qualified_table_name(name_prefix)

    @classmethod
    def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str:
        """Render CTAS statement materializing `select_sql` as `temp_table_name`."""
        # regular table because Athena does not support temporary tables
        return f"CREATE TABLE {temp_table_name} AS {select_sql};"

    @classmethod
    def gen_insert_temp_table_sql(
        cls,
        table_name: str,
        staging_root_table_name: str,
        sql_client: SqlClientBase[Any],
        primary_keys: Sequence[str],
        unique_column: str,
        dedup_sort: Tuple[str, TSortOrder] = None,
        condition: str = None,
        condition_columns: Sequence[str] = None,
    ) -> Tuple[List[str], str]:
        """Same as base implementation but prepends a DROP of the (reproducibly named)
        temp table so a leftover from a previous run does not break the CTAS.
        """
        sql, temp_table_name = super().gen_insert_temp_table_sql(
            table_name,
            staging_root_table_name,
            sql_client,
            primary_keys,
            unique_column,
            dedup_sort,
            condition,
            condition_columns,
        )
        # DROP needs backtick as escape identifier
        sql.insert(0, f"""DROP TABLE IF EXISTS {temp_table_name.replace('"', '`')};""")
        return sql, temp_table_name

    @classmethod
    def gen_delete_temp_table_sql(
        cls,
        table_name: str,
        unique_column: str,
        key_table_clauses: Sequence[str],
        sql_client: SqlClientBase[Any],
    ) -> Tuple[List[str], str]:
        """Same as base implementation but prepends a DROP of the (reproducibly named)
        temp table so a leftover from a previous run does not break the CTAS.
        """
        sql, temp_table_name = super().gen_delete_temp_table_sql(
            table_name, unique_column, key_table_clauses, sql_client
        )
        # DROP needs backtick as escape identifier
        sql.insert(0, f"""DROP TABLE IF EXISTS {temp_table_name.replace('"', '`')};""")
        return sql, temp_table_name

    @classmethod
    def requires_temp_table_for_delete(cls) -> bool:
        # Athena needs the delete temp table path (base class returns False)
        return True


class AthenaSQLClient(SqlClientBase[Connection]):
capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()
dbapi: ClassVar[DBApi] = pyathena
Expand Down Expand Up @@ -417,8 +481,7 @@ def _create_replace_followup_jobs(
return super()._create_replace_followup_jobs(table_chain)

def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
    """Create a real merge followup job for the table chain via `AthenaMergeJob`
    (previously this fell back to append jobs)."""
    return [AthenaMergeJob.from_table_chain(table_chain, self.sql_client)]

def _is_iceberg_table(self, table: TTableSchema) -> bool:
table_format = table.get("table_format")
Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/impl/dremio/pydremio.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __init__(self, factory: CookieMiddlewareFactory, *args: Any, **kwargs: Any):
def received_headers(self, headers: Mapping[str, str]) -> None:
for key in headers:
if key.lower() == "set-cookie":
cookie = SimpleCookie() # type: ignore
cookie = SimpleCookie()
for item in headers.get(key):
cookie.load(item)

Expand Down
7 changes: 6 additions & 1 deletion dlt/destinations/impl/duckdb/configuration.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import os
import dataclasses
import threading
from pathvalidate import is_valid_filepath

from typing import Any, ClassVar, Dict, Final, List, Optional, Tuple, Type, Union

from pathvalidate import is_valid_filepath
from dlt.common import logger
from dlt.common.configuration import configspec
from dlt.common.configuration.specs import ConnectionStringCredentials
from dlt.common.configuration.specs.exceptions import InvalidConnectionString
from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration
from dlt.common.typing import TSecretValue
from dlt.destinations.impl.duckdb.exceptions import InvalidInMemoryDuckdbCredentials

try:
from duckdb import DuckDBPyConnection
Expand Down Expand Up @@ -117,6 +119,9 @@ def is_partial(self) -> bool:
return self.database == ":pipeline:"

def on_resolved(self) -> None:
if isinstance(self.database, str) and self.database == ":memory:":
raise InvalidInMemoryDuckdbCredentials()

# do not set any paths for external database
if self.database == ":external:":
return
Expand Down
11 changes: 11 additions & 0 deletions dlt/destinations/impl/duckdb/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from dlt.common.destination.exceptions import DestinationTerminalException


class InvalidInMemoryDuckdbCredentials(DestinationTerminalException):
    """Raised when `:memory:` is passed as the duckdb database path; users must pass
    an already-instantiated in-memory connection to the destination factory instead."""

    def __init__(self) -> None:
        # NOTE: closing parenthesis added to the `dlt.pipeline(...)` example so the
        # snippet shown to the user is valid Python
        super().__init__(
            "To use in-memory instance of duckdb, "
            "please instantiate it first and then pass to destination factory\n"
            '\nconn = duckdb.connect(":memory:")\n'
            'dlt.pipeline(pipeline_name="...", destination=dlt.destinations.duckdb(conn))'
        )
2 changes: 1 addition & 1 deletion dlt/destinations/impl/duckdb/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
Args:
credentials: Credentials to connect to the duckdb database. Can be an instance of `DuckDbCredentials` or
a path to a database file. Use `:memory:` to create an in-memory database or :pipeline: to create a duckdb
a path to a database file. Use :pipeline: to create a duckdb
in the working folder of the pipeline
create_indexes: Should unique indexes be created, defaults to False
**kwargs: Additional arguments passed to the destination config
Expand Down
22 changes: 15 additions & 7 deletions dlt/destinations/sql_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,18 @@ def gen_key_table_clauses(

@classmethod
def gen_delete_temp_table_sql(
cls, unique_column: str, key_table_clauses: Sequence[str], sql_client: SqlClientBase[Any]
cls,
table_name: str,
unique_column: str,
key_table_clauses: Sequence[str],
sql_client: SqlClientBase[Any],
) -> Tuple[List[str], str]:
"""Generate sql that creates delete temp table and inserts `unique_column` from root table for all records to delete. May return several statements.
Returns temp table name for cases where special names are required like SQLServer.
"""
sql: List[str] = []
temp_table_name = cls._new_temp_table_name("delete", sql_client)
temp_table_name = cls._new_temp_table_name("delete_" + table_name, sql_client)
select_statement = f"SELECT d.{unique_column} {key_table_clauses[0]}"
sql.append(cls._to_temp_table(select_statement, temp_table_name))
for clause in key_table_clauses[1:]:
Expand Down Expand Up @@ -281,6 +285,7 @@ def default_order_by(cls) -> str:
@classmethod
def gen_insert_temp_table_sql(
cls,
table_name: str,
staging_root_table_name: str,
sql_client: SqlClientBase[Any],
primary_keys: Sequence[str],
Expand All @@ -289,7 +294,7 @@ def gen_insert_temp_table_sql(
condition: str = None,
condition_columns: Sequence[str] = None,
) -> Tuple[List[str], str]:
temp_table_name = cls._new_temp_table_name("insert", sql_client)
temp_table_name = cls._new_temp_table_name("insert_" + table_name, sql_client)
if len(primary_keys) > 0:
# deduplicate
select_sql = cls.gen_select_from_dedup_sql(
Expand Down Expand Up @@ -417,7 +422,9 @@ def gen_merge_sql(
unique_column = escape_id(unique_columns[0])
# create temp table with unique identifier
create_delete_temp_table_sql, delete_temp_table_name = (
cls.gen_delete_temp_table_sql(unique_column, key_table_clauses, sql_client)
cls.gen_delete_temp_table_sql(
root_table["name"], unique_column, key_table_clauses, sql_client
)
)
sql.extend(create_delete_temp_table_sql)

Expand Down Expand Up @@ -470,6 +477,7 @@ def gen_merge_sql(
create_insert_temp_table_sql,
insert_temp_table_name,
) = cls.gen_insert_temp_table_sql(
root_table["name"],
staging_root_table_name,
sql_client,
primary_keys,
Expand Down Expand Up @@ -608,8 +616,8 @@ def gen_update_table_prefix(cls, table_name: str) -> str:

@classmethod
def requires_temp_table_for_delete(cls) -> bool:
    """Whether a temporary table is required to delete records.

    Must be `True` for destinations that don't support correlated subqueries.
    """
    return False
Loading

0 comments on commit c66bf9a

Please sign in to comment.