finish first batch of helpers
sh-rp committed Jan 12, 2024
1 parent cab5dcc commit f609fc8
Showing 4 changed files with 138 additions and 62 deletions.
10 changes: 9 additions & 1 deletion dlt/common/data_types/type_helpers.py
@@ -6,7 +6,7 @@
from enum import Enum

from dlt.common import pendulum, json, Decimal, Wei
from dlt.common.json import custom_pua_remove
from dlt.common.json import custom_pua_remove, json
from dlt.common.json._simplejson import custom_encode as json_custom_encode
from dlt.common.arithmetics import InvalidOperation
from dlt.common.data_types.typing import TDataType
@@ -105,6 +105,14 @@ def coerce_value(to_type: TDataType, from_type: TDataType, value: Any) -> Any:
return int(value.value)
return value

if to_type == "complex":
# try to coerce from text
if from_type == "text":
try:
return json.loads(value)
            except Exception:
pass

if to_type == "text":
if from_type == "complex":
return complex_to_str(value)
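The new branch above parses JSON-encoded text into a complex value and silently falls through on parse errors. A minimal sketch of the expected behavior, assuming coerce_value is imported from this module (illustrative usage, not part of the commit):

from dlt.common.data_types.type_helpers import coerce_value

# valid JSON text is parsed into a Python object when coercing text -> complex
assert coerce_value("complex", "text", '{"a": 1, "b": [2, 3]}') == {"a": 1, "b": [2, 3]}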
106 changes: 63 additions & 43 deletions dlt/destinations/impl/sink/sink.py
@@ -27,12 +27,13 @@

class SinkLoadJob(LoadJob, FollowupJob):
def __init__(
self, table: TTableSchema, file_path: str, config: SinkClientConfiguration
self, table: TTableSchema, file_path: str, config: SinkClientConfiguration, schema: Schema
) -> None:
super().__init__(FileStorage.get_file_name_from_file_path(file_path))
self._file_path = file_path
self._config = config
self._table = table
self._schema = schema
self.run()

def run(self) -> None:
@@ -41,10 +42,19 @@ def run(self) -> None:
def call_callable_with_items(self, items: TDataItems) -> None:
if not items:
return
if self._config.credentials.callable:
self._config.credentials.callable(
items[0] if self._config.batch_size == 1 else items, self._table
)

# coerce items into correct format specified by schema
coerced_items: TDataItems = []
for item in items:
coerced_item, table_update = self._schema.coerce_row(self._table["name"], None, item)
assert not table_update
coerced_items.append(coerced_item)

# send single item on batch size 1
if self._config.batch_size == 1:
coerced_items = coerced_items[0]

self._config.credentials.callable(coerced_items, self._table)

def state(self) -> TLoadJobState:
return "completed"
@@ -79,40 +89,50 @@ def run(self) -> None:
self.call_callable_with_items(current_batch)


class SinkInsertValueslLoadJob(SinkLoadJob):
def run(self) -> None:
from dlt.common import json

# stream items
with FileStorage.open_zipsafe_ro(self._file_path) as f:
current_batch: TDataItems = []
column_names: List[str] = []
for line in f:
line = line.strip()

# TODO respect inserts with multiline values

# extract column names
if line.startswith("INSERT INTO") and line.endswith(")"):
line = line[15:-1]
column_names = line.split(",")
continue

# not a valid values line
if not line.startswith("(") or not line.endswith(");"):
continue

# extract values
line = line[1:-2]
values = line.split(",")

# zip and send to callable
current_batch.append(dict(zip(column_names, values)))
if len(current_batch) == self._config.batch_size:
self.call_callable_with_items(current_batch)
current_batch = []

self.call_callable_with_items(current_batch)
# class SinkInsertValueslLoadJob(SinkLoadJob):
# def run(self) -> None:
# from dlt.common import json

# # stream items
# with FileStorage.open_zipsafe_ro(self._file_path) as f:
# header = f.readline().strip()
# values_mark = f.readline()

# # properly formatted file has a values marker at the beginning
# assert values_mark == "VALUES\n"

# # extract column names
# assert header.startswith("INSERT INTO") and header.endswith(")")
# header = header[15:-1]
# column_names = header.split(",")

# # build batches
# current_batch: TDataItems = []
# current_row: str = ""
# for line in f:
# current_row += line
# if line.endswith(");"):
# current_row = current_row[1:-2]
# elif line.endswith("),\n"):
# current_row = current_row[1:-3]
# else:
# continue

# values = current_row.split(",")
# values = [None if v == "NULL" else v for v in values]
# current_row = ""
# print(values)
# print(current_row)

# # zip and send to callable
# current_batch.append(dict(zip(column_names, values)))
# d = dict(zip(column_names, values))
# print(json.dumps(d, pretty=True))
# if len(current_batch) == self._config.batch_size:
# self.call_callable_with_items(current_batch)
# current_batch = []

# self.call_callable_with_items(current_batch)


class SinkClient(JobClientBase):
@@ -140,11 +160,11 @@ def update_stored_schema(

def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
if file_path.endswith("parquet"):
return SinkParquetLoadJob(table, file_path, self.config)
return SinkParquetLoadJob(table, file_path, self.config, self.schema)
if file_path.endswith("jsonl"):
return SinkJsonlLoadJob(table, file_path, self.config)
if file_path.endswith("insert_values"):
return SinkInsertValueslLoadJob(table, file_path, self.config)
return SinkJsonlLoadJob(table, file_path, self.config, self.schema)
# if file_path.endswith("insert_values"):
# return SinkInsertValueslLoadJob(table, file_path, self.config, self.schema)
return EmptyLoadJob.from_file_path(file_path, "completed")

def restore_file_load(self, file_path: str) -> LoadJob:
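For context, a hedged sketch of the callable these load jobs invoke, following the @dlt.sink usage in tests/load/sink/test_sink.py below. The function name is illustrative, the pipeline wiring is not shown in this diff, and the import paths follow those used elsewhere in the repo:

import dlt
from dlt.common.typing import TDataItems
from dlt.common.schema import TTableSchema

@dlt.sink(loader_file_format="jsonl", batch_size=10)
def my_sink(items: TDataItems, table: TTableSchema) -> None:
    # with batch_size > 1 the callable receives a list of rows coerced via schema.coerce_row;
    # with batch_size == 1 it receives a single row dict
    rows = [items] if isinstance(items, dict) else items
    print(f"{table['name']}: {len(rows)} row(s)")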
25 changes: 19 additions & 6 deletions tests/cases.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Any, Sequence, Tuple, Literal
from typing import Dict, List, Any, Sequence, Tuple, Literal, Union
import base64
from hexbytes import HexBytes
from copy import deepcopy
@@ -7,7 +7,7 @@

from dlt.common import Decimal, pendulum, json
from dlt.common.data_types import TDataType
from dlt.common.typing import StrAny
from dlt.common.typing import StrAny, TDataItems
from dlt.common.wei import Wei
from dlt.common.time import (
ensure_pendulum_datetime,
@@ -161,18 +161,23 @@ def table_update_and_row(


def assert_all_data_types_row(
db_row: List[Any],
db_row: Union[List[Any], TDataItems],
parse_complex_strings: bool = False,
allow_base64_binary: bool = False,
timestamp_precision: int = 6,
schema: TTableSchemaColumns = None,
expect_filtered_null_columns=False,
) -> None:
# content must equal
# print(db_row)
schema = schema or TABLE_UPDATE_COLUMNS_SCHEMA

# Include only columns requested in schema
db_mapping = {col_name: db_row[i] for i, col_name in enumerate(schema)}
if isinstance(db_row, dict):
db_mapping = db_row.copy()
else:
db_mapping = {col_name: db_row[i] for i, col_name in enumerate(schema)}

expected_rows = {key: value for key, value in TABLE_ROW_ALL_DATA_TYPES.items() if key in schema}
# prepare date to be compared: convert into pendulum instance, adjust microsecond precision
if "col4" in expected_rows:
@@ -226,8 +231,16 @@ if "col11" in db_mapping:
if "col11" in db_mapping:
db_mapping["col11"] = db_mapping["col11"].isoformat()

for expected, actual in zip(expected_rows.values(), db_mapping.values()):
assert expected == actual
if expect_filtered_null_columns:
for key, expected in expected_rows.items():
if expected is None:
assert db_mapping.get(key, None) is None
db_mapping[key] = None

for key, expected in expected_rows.items():
actual = db_mapping[key]
assert expected == actual, f"Expected {expected} but got {actual} for column {key}"

assert db_mapping == expected_rows


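The change above lets assert_all_data_types_row accept either a positional row or a dict row (as produced by the sink destination). A toy sketch of that dispatch, with illustrative names only:

def to_mapping(db_row, schema_cols):
    # dict rows are taken as-is; positional rows are zipped with the schema's column order
    if isinstance(db_row, dict):
        return db_row.copy()
    return {col: db_row[i] for i, col in enumerate(schema_cols)}

cols = ["col1", "col2"]
assert to_mapping([1, "a"], cols) == to_mapping({"col1": 1, "col2": "a"}, cols)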
59 changes: 47 additions & 12 deletions tests/load/sink/test_sink.py
@@ -15,21 +15,22 @@
delete_dataset,
)

SUPPORTED_LOADER_FORMATS = ["parquet", "jsonl", "insert_values"]
SUPPORTED_LOADER_FORMATS = ["parquet", "jsonl"]


def _run_through_sink(
items: TDataItems,
loader_file_format: TLoaderFileFormat,
columns=None,
filter_dlt_tables: bool = True,
batch_size: int = 10,
) -> List[Tuple[TDataItems, TTableSchema]]:
"""
runs a list of items through the sink destination and returns collected calls
"""
calls: List[Tuple[TDataItems, TTableSchema]] = []

@dlt.sink(loader_file_format=loader_file_format, batch_size=1)
@dlt.sink(loader_file_format=loader_file_format, batch_size=batch_size)
def test_sink(items: TDataItems, table: TTableSchema) -> None:
nonlocal calls
if table["name"].startswith("_dlt") and filter_dlt_tables:
@@ -52,26 +53,60 @@ def test_all_datatypes(loader_file_format: TLoaderFileFormat) -> None:
data_types = deepcopy(TABLE_ROW_ALL_DATA_TYPES)
column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA)

sink_calls = _run_through_sink(data_types, loader_file_format, columns=column_schemas)
sink_calls = _run_through_sink(
[data_types, data_types, data_types],
loader_file_format,
columns=column_schemas,
batch_size=1,
)

# inspect result
assert len(sink_calls) == 1
assert len(sink_calls) == 3

item = sink_calls[0][0]
# filter out _dlt columns
item = {k: v for k, v in item.items() if not k.startswith("_dlt")}
item = {k: v for k, v in item.items() if not k.startswith("_dlt")} # type: ignore

# null values are not saved in jsonl (TODO: is this correct?)
if loader_file_format == "jsonl":
data_types = {k: v for k, v in data_types.items() if v is not None}
# null values are not emitted
data_types = {k: v for k, v in data_types.items() if v is not None}

# check keys are the same
assert set(item.keys()) == set(data_types.keys())

# TODO: check actual types
# assert_all_data_types_row
assert_all_data_types_row(item, expect_filtered_null_columns=True)


@pytest.mark.parametrize("loader_file_format", SUPPORTED_LOADER_FORMATS)
def test_batch_size(loader_file_format: TLoaderFileFormat) -> None:
pass
@pytest.mark.parametrize("batch_size", [1, 10, 23])
def test_batch_size(loader_file_format: TLoaderFileFormat, batch_size: int) -> None:
items = [{"id": i, "value": str(i)} for i in range(100)]

sink_calls = _run_through_sink(items, loader_file_format, batch_size=batch_size)

if batch_size == 1:
assert len(sink_calls) == 100
# one item per call
assert sink_calls[0][0].items() > {"id": 0, "value": "0"}.items() # type: ignore
elif batch_size == 10:
assert len(sink_calls) == 10
# ten items in first call
assert len(sink_calls[0][0]) == 10
assert sink_calls[0][0][0].items() > {"id": 0, "value": "0"}.items()
elif batch_size == 23:
assert len(sink_calls) == 5
# 23 items in first call
assert len(sink_calls[0][0]) == 23
assert sink_calls[0][0][0].items() > {"id": 0, "value": "0"}.items()

# check all items are present
all_items = set()
for call in sink_calls:
item = call[0]
if batch_size == 1:
item = [item]
for entry in item:
all_items.add(entry["value"])

assert len(all_items) == 100
for i in range(100):
assert str(i) in all_items
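
The batch counts asserted above follow from how the jsonl/parquet load jobs group rows before invoking the callable. A small sketch of that grouping (illustrative helper, not part of the commit):

from typing import Any, Iterable, Iterator, List

def batched(rows: Iterable[Any], batch_size: int) -> Iterator[List[Any]]:
    batch: List[Any] = []
    for row in rows:
        batch.append(row)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        # flush the remainder, as the load jobs do once the file is exhausted
        yield batch

# 100 rows with batch_size 23 -> 5 calls: four full batches and one of 8
assert [len(b) for b in batched(range(100), 23)] == [23, 23, 23, 23, 8]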
