arrow example (#707)
* sends region name to s3 fsspec client

* shows warning in config missing exception when pipeline script not in working folder

* counts rows of arrow tables in writers to rotate files properly

* adds arrow + connector x example

* bumps to mypy 1.6.1 and fixes transformer decorator

* explains variant column and other fixes in docs

* reuses primary key for index in incremental

* removes unix ts autodetection by default, adds add/remove type detection helpers to schema

* passes column schema to arrow writer

* fixes tests

* adds blog post

---------

Co-authored-by: Adrian <Adrian>
rudolfix authored Oct 24, 2023
1 parent 024cd4d commit 7325c4f
Showing 37 changed files with 630 additions and 125 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_doc_snippets.yml
@@ -56,7 +56,7 @@ jobs:

- name: Install dependencies
# if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction -E duckdb -E weaviate --with docs --without airflow
run: poetry install --no-interaction -E duckdb -E weaviate -E parquet --with docs --without airflow

- name: Run linter and tests
run: make test-and-lint-snippets
10 changes: 10 additions & 0 deletions dlt/common/configuration/exceptions.py
@@ -1,6 +1,8 @@
import os
from typing import Any, Mapping, Type, Tuple, NamedTuple, Sequence

from dlt.common.exceptions import DltException, TerminalException
from dlt.common.utils import main_module_file_path


class LookupTrace(NamedTuple):
@@ -48,6 +50,14 @@ def __str__(self) -> str:
msg += f'\tfor field "{f}" config providers and keys were tried in following order:\n'
for tr in field_traces:
msg += f'\t\tIn {tr.provider} key {tr.key} was not found.\n'
# check if entry point is run with path. this is common problem so warn the user
main_path = main_module_file_path()
main_dir = os.path.dirname(main_path)
abs_main_dir = os.path.abspath(main_dir)
if abs_main_dir != os.getcwd():
# directory was specified
msg += "WARNING: dlt looks for .dlt folder in your current working directory and your cwd (%s) is different from directory of your pipeline script (%s).\n" % (os.getcwd(), abs_main_dir)
msg += "If you keep your secret files in the same folder as your pipeline script but run your script from some other folder, secrets/configs will not be found\n"
msg += "Please refer to https://dlthub.com/docs/general-usage/credentials for more information\n"
return msg

7 changes: 5 additions & 2 deletions dlt/common/configuration/specs/aws_credentials.py
@@ -1,7 +1,7 @@
from typing import Optional, Dict, Any

from dlt.common.exceptions import MissingDependencyException
from dlt.common.typing import TSecretStrValue
from dlt.common.typing import TSecretStrValue, DictStrAny
from dlt.common.configuration.specs import CredentialsConfiguration, CredentialsWithDefault, configspec
from dlt.common.configuration.specs.exceptions import InvalidBoto3Session
from dlt import version
@@ -19,13 +19,16 @@ class AwsCredentialsWithoutDefaults(CredentialsConfiguration):

def to_s3fs_credentials(self) -> Dict[str, Optional[str]]:
"""Dict of keyword arguments that can be passed to s3fs"""
return dict(
credentials: DictStrAny = dict(
key=self.aws_access_key_id,
secret=self.aws_secret_access_key,
token=self.aws_session_token,
profile=self.profile_name,
endpoint_url=self.endpoint_url,
)
if self.region_name:
credentials["client_kwargs"] = {"region_name": self.region_name}
return credentials

def to_native_representation(self) -> Dict[str, Optional[str]]:
"""Return a dict that can be passed as kwargs to boto3 session"""
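For context, a minimal sketch (not part of the diff) of how the s3fs-style dict with the region forwarded under client_kwargs could be consumed. Importing AwsCredentialsWithoutDefaults from the specs package, instantiating it directly and the placeholder key values are assumptions:

import fsspec

from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults

creds = AwsCredentialsWithoutDefaults()
creds.aws_access_key_id = "AKIA..."     # placeholder
creds.aws_secret_access_key = "..."     # placeholder
creds.region_name = "eu-central-1"      # now forwarded via client_kwargs

# to_s3fs_credentials() returns key/secret/token/profile/endpoint_url plus
# client_kwargs={"region_name": ...} when a region is configured
fs = fsspec.filesystem("s3", **creds.to_s3fs_credentials())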
18 changes: 16 additions & 2 deletions dlt/common/data_writers/buffered.py
@@ -1,4 +1,5 @@
import gzip
from functools import reduce
from typing import List, IO, Any, Optional, Type, TypeVar, Generic

from dlt.common.utils import uniq_id
@@ -58,6 +59,7 @@ def __init__(
self._current_columns: TTableSchemaColumns = None
self._file_name: str = None
self._buffered_items: List[TDataItem] = []
self._buffered_items_count: int = 0
self._writer: TWriter = None
self._file: IO[Any] = None
self._closed = False
@@ -79,10 +81,20 @@ def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> Non
if isinstance(item, List):
# items coming in a single list will be written together, no matter how many there are
self._buffered_items.extend(item)
# update row count, if item supports "num_rows" it will be used to count items
if len(item) > 0 and hasattr(item[0], "num_rows"):
self._buffered_items_count += sum(tbl.num_rows for tbl in item)
else:
self._buffered_items_count += len(item)
else:
self._buffered_items.append(item)
# update row count, if item supports "num_rows" it will be used to count items
if hasattr(item, "num_rows"):
self._buffered_items_count += item.num_rows
else:
self._buffered_items_count += 1
# flush if max buffer exceeded
if len(self._buffered_items) >= self.buffer_max_items:
if self._buffered_items_count >= self.buffer_max_items:
self._flush_items()
# rotate the file if max_bytes exceeded
if self._file:
@@ -118,7 +130,7 @@ def _rotate_file(self) -> None:
self._file_name = self.file_name_template % uniq_id(5) + "." + self._file_format_spec.file_extension

def _flush_items(self, allow_empty_file: bool = False) -> None:
if len(self._buffered_items) > 0 or allow_empty_file:
if self._buffered_items_count > 0 or allow_empty_file:
# we only open a writer when there are any items in the buffer and first flush is requested
if not self._writer:
# create new writer and write header
@@ -131,7 +143,9 @@ def _flush_items(self, allow_empty_file: bool = False) -> None:
# write buffer
if self._buffered_items:
self._writer.write_data(self._buffered_items)
# reset buffer and counter
self._buffered_items.clear()
self._buffered_items_count = 0

def _flush_and_close_file(self) -> None:
# if any buffered items exist, flush them
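A quick illustration (not part of the diff) of the counting rule above: objects exposing num_rows, such as Arrow tables, now count by rows rather than as a single buffered item, so buffer_max_items rotates files on row counts. The pyarrow usage here is only for the example:

import pyarrow as pa

tbl = pa.table({"id": list(range(10_000))})
items = [tbl, tbl]

# mirrors the duck-typed counting in write_data_item above
count = sum(i.num_rows if hasattr(i, "num_rows") else 1 for i in items)
assert count == 20_000  # two tables buffer as 20k rows, not as 2 items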
2 changes: 2 additions & 0 deletions dlt/common/data_writers/writers.py
@@ -274,6 +274,8 @@ def write_data(self, rows: Sequence[Any]) -> None:
self.writer.write_batch(row)
else:
raise ValueError(f"Unsupported type {type(row)}")
# count rows that got written
self.items_count += row.num_rows

@classmethod
def data_format(cls) -> TFileFormatSpec:
12 changes: 12 additions & 0 deletions dlt/common/schema/schema.py
@@ -393,6 +393,18 @@ def update_normalizers(self) -> None:
normalizers["json"] = normalizers["json"] or self._normalizers_config["json"]
self._configure_normalizers(normalizers)

def add_type_detection(self, detection: TTypeDetections) -> None:
"""Add type auto detection to the schema."""
if detection not in self.settings["detections"]:
self.settings["detections"].append(detection)
self._compile_settings()

def remove_type_detection(self, detection: TTypeDetections) -> None:
"""Adds type auto detection to the schema."""
if detection in self.settings["detections"]:
self.settings["detections"].remove(detection)
self._compile_settings()

def _infer_column(self, k: str, v: Any, data_type: TDataType = None, is_variant: bool = False) -> TColumnSchema:
column_schema = TColumnSchema(
name=k,
2 changes: 1 addition & 1 deletion dlt/common/schema/utils.py
@@ -698,4 +698,4 @@ def standard_hints() -> Dict[TColumnHint, List[TSimpleRegex]]:


def standard_type_detections() -> List[TTypeDetections]:
return ["timestamp", "iso_timestamp"]
return ["iso_timestamp"]
2 changes: 1 addition & 1 deletion dlt/destinations/filesystem/filesystem.py
@@ -131,7 +131,7 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
for search_prefix in truncate_prefixes:
if item.startswith(search_prefix):
# NOTE: deleting in chunks on s3 does not raise on access denied, non-existing files and probably other errors
logger.info(f"DEL {item}")
# logger.info(f"DEL {item}")
# print(f"DEL {item}")
self.fs_client.rm(item)
except FileNotFoundError:
4 changes: 2 additions & 2 deletions dlt/extract/decorators.py
@@ -384,7 +384,7 @@ def decorator(f: Callable[TResourceFunParams, Any]) -> Callable[TResourceFunPara
spec=spec, sections=resource_sections, sections_merge_style=ConfigSectionContext.resource_merge_style, include_defaults=spec is not None
)
is_inner_resource = is_inner_callable(f)
if conf_f != incr_f and is_inner_resource:
if conf_f != incr_f and is_inner_resource and not standalone:
raise ResourceInnerCallableConfigWrapDisallowed(resource_name, source_section)
# get spec for wrapped function
SPEC = get_fun_spec(conf_f)
@@ -494,7 +494,7 @@ def transformer(
selected: bool = True,
spec: Type[BaseConfiguration] = None,
standalone: Literal[True] = True
) -> Callable[..., DltResource]: # TODO: change back to Callable[TResourceFunParams, DltResource] when mypy 1.6 is fixed
) -> Callable[TResourceFunParams, DltResource]: # TODO: change back to Callable[TResourceFunParams, DltResource] when mypy 1.6 is fixed
...

def transformer(
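A sketch (assumed usage, not part of the diff) of the case the relaxed check enables: an inner function can now be wrapped as a resource when standalone=True. The factory pattern and the parameter names are illustrative only:

import dlt

def numbers_factory(table_name: str):
    # previously an inner (nested) callable could raise ResourceInnerCallableConfigWrapDisallowed;
    # with standalone=True the wrap is now allowed
    @dlt.resource(standalone=True, name=table_name)
    def numbers(limit: int = 3):
        yield from range(limit)
    return numbers

resource = numbers_factory("numbers")(limit=5)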
7 changes: 5 additions & 2 deletions dlt/extract/extract.py
@@ -164,14 +164,14 @@ def write_empty_file(self, table_name: str) -> None:
table_name = self.schema.naming.normalize_table_identifier(table_name)
self.storage.write_empty_file(self.extract_id, self.schema.name, table_name, None)

def _write_item(self, table_name: str, resource_name: str, items: TDataItems) -> None:
def _write_item(self, table_name: str, resource_name: str, items: TDataItems, columns: TTableSchemaColumns = None) -> None:
# normalize table name before writing so the name matches the name in the schema
# note: normalize function should be cached so there's almost no penalty on frequent calling
# note: column schema is not required for jsonl writer used here
table_name = self.schema.naming.normalize_identifier(table_name)
self.collector.update(table_name)
self.resources_with_items.add(resource_name)
self.storage.write_data_item(self.extract_id, self.schema.name, table_name, items, None)
self.storage.write_data_item(self.extract_id, self.schema.name, table_name, items, columns)

def _write_dynamic_table(self, resource: DltResource, item: TDataItem) -> None:
table_name = resource._table_name_hint_fun(item)
@@ -212,6 +212,9 @@ def write_table(self, resource: DltResource, items: TDataItems, meta: Any) -> No
]
super().write_table(resource, items, meta)

def _write_item(self, table_name: str, resource_name: str, items: TDataItems, columns: TTableSchemaColumns = None) -> None:
super()._write_item(table_name, resource_name, items, self.dynamic_tables[table_name][0]["columns"])

def _write_static_table(self, resource: DltResource, table_name: str, items: TDataItems) -> None:
existing_table = self.dynamic_tables.get(table_name)
if existing_table is not None:
87 changes: 54 additions & 33 deletions dlt/extract/incremental/transform.py
@@ -1,5 +1,5 @@
from datetime import datetime # noqa: I251
from typing import Optional, Tuple, Protocol, Mapping, Union, List
from datetime import datetime, date # noqa: I251
from typing import Optional, Tuple, List

try:
import pandas as pd
@@ -137,8 +137,9 @@ def __call__(
return row, start_out_of_range, end_out_of_range



class ArrowIncremental(IncrementalTransformer):
_dlt_index = "_dlt_index"

def unique_values(
self,
item: "TAnyArrowItem",
@@ -148,28 +149,34 @@ def unique_values(
if not unique_columns:
return []
item = item
indices = item["_dlt_index"].to_pylist()
indices = item[self._dlt_index].to_pylist()
rows = item.select(unique_columns).to_pylist()
return [
(index, digest128(json.dumps(row, sort_keys=True))) for index, row in zip(indices, rows)
]

def _deduplicate(self, tbl: "pa.Table", unique_columns: Optional[List[str]], aggregate: str, cursor_path: str) -> "pa.Table":
if unique_columns is None:
return tbl
group_cols = unique_columns + [cursor_path]
tbl = tbl.append_column("_dlt_index", pa.array(np.arange(tbl.num_rows)))
try:
tbl = tbl.filter(
pa.compute.is_in(
tbl['_dlt_index'],
tbl.group_by(group_cols).aggregate(
[("_dlt_index", "one"), (cursor_path, aggregate)]
)['_dlt_index_one']
)
)
except KeyError as e:
raise IncrementalPrimaryKeyMissing(self.resource_name, unique_columns[0], tbl) from e
def _deduplicate(self, tbl: "pa.Table", unique_columns: Optional[List[str]], aggregate: str, cursor_path: str) -> "pa.Table":
"""Creates unique index if necessary."""
# create unique index if necessary
if self._dlt_index not in tbl.schema.names:
tbl = tbl.append_column(self._dlt_index, pa.array(np.arange(tbl.num_rows)))
# the code below deduplicated groups that included the cursor column in the group id. that was just an
# artifact of the json incremental and there's no need to replicate it here

# if unique_columns is None:
# return tbl
# group_cols = unique_columns + [cursor_path]
# try:
# tbl = tbl.filter(
# pa.compute.is_in(
# tbl[self._dlt_index],
# tbl.group_by(group_cols).aggregate(
# [(self._dlt_index, "one"), (cursor_path, aggregate)]
# )[f'{self._dlt_index}_one']
# )
# )
# except KeyError as e:
# raise IncrementalPrimaryKeyMissing(self.resource_name, unique_columns[0], tbl) from e
return tbl

def __call__(
Expand All @@ -180,6 +187,25 @@ def __call__(
if is_pandas:
tbl = pa.Table.from_pandas(tbl)

primary_key = self.primary_key(tbl) if callable(self.primary_key) else self.primary_key
if primary_key:
# create a list of unique columns
if isinstance(primary_key, str):
unique_columns = [primary_key]
else:
unique_columns = list(primary_key)
# check if primary key components are in the table
for pk in unique_columns:
if pk not in tbl.schema.names:
raise IncrementalPrimaryKeyMissing(self.resource_name, pk, tbl)
# use primary key as unique index
if isinstance(primary_key, str):
self._dlt_index = primary_key
elif primary_key is None:
unique_columns = tbl.column_names
else: # deduplicating is disabled
unique_columns = None

start_out_of_range = end_out_of_range = False
if not tbl: # row is None or empty arrow table
return tbl, start_out_of_range, end_out_of_range
@@ -206,24 +232,19 @@ def __call__(
cursor_path = str(self.cursor_path)
# The new max/min value
try:
row_value = compute(tbl[cursor_path]).as_py()
orig_row_value = compute(tbl[cursor_path])
row_value = orig_row_value.as_py()
# dates are not represented as datetimes, but connector-x seems to represent
# datetimes as dates while keeping the exact time inside. probably a bug
# but it can be corrected this way
if isinstance(row_value, date) and not isinstance(row_value, datetime):
row_value = pendulum.from_timestamp(orig_row_value.cast(pa.int64()).as_py() / 1000)
except KeyError as e:
raise IncrementalCursorPathMissing(
self.resource_name, cursor_path, tbl,
f"Column name {str(cursor_path)} was not found in the arrow table. Note nested JSON paths are not supported for arrow tables and dataframes, the incremental cursor_path must be a column name."
) from e

primary_key = self.primary_key(tbl) if callable(self.primary_key) else self.primary_key
if primary_key:
if isinstance(primary_key, str):
unique_columns = [primary_key]
else:
unique_columns = list(primary_key)
elif primary_key is None:
unique_columns = tbl.column_names
else: # deduplicating is disabled
unique_columns = None

# If end_value is provided, filter to include table rows that are "less" than end_value
if self.end_value is not None:
tbl = tbl.filter(end_compare(tbl[cursor_path], self.end_value))
Expand All @@ -247,7 +268,7 @@ def __call__(
unique_values = [(i, uq_val) for i, uq_val in unique_values if uq_val in self.incremental_state['unique_hashes']]
remove_idx = pa.array(i for i, _ in unique_values)
# Filter the table
tbl = tbl.filter(pa.compute.invert(pa.compute.is_in(tbl["_dlt_index"], remove_idx)))
tbl = tbl.filter(pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx)))

if new_value_compare(row_value, last_value).as_py() and row_value != last_value: # Last value has changed
self.incremental_state['last_value'] = row_value
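Tying the arrow changes together, a minimal sketch (not part of the diff) of the pattern the new docs example targets: connector-x returns an Arrow table, the resource's primary_key doubles as the dedup index, and "updated_at" drives the incremental cursor. Connection string, table and column names are placeholders:

import connectorx as cx
import dlt
import pendulum

@dlt.resource(primary_key="id", write_disposition="merge")
def orders(updated_at=dlt.sources.incremental("updated_at", initial_value=pendulum.datetime(2023, 1, 1))):
    query = f"SELECT * FROM orders WHERE updated_at > '{updated_at.last_value}'"
    # return_type="arrow" yields a pyarrow.Table that dlt writes straight to parquet
    yield cx.read_sql("postgresql://user:pass@localhost:5432/db", query, return_type="arrow")

pipeline = dlt.pipeline(pipeline_name="cx_example", destination="duckdb", dataset_name="cx_data")
info = pipeline.run(orders, loader_file_format="parquet")
print(info)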
File renamed without changes.
