port schema conversion from polars and update writer and merge
ion-elgreco committed Nov 7, 2023
1 parent b3f478e commit c144c65
Showing 5 changed files with 164 additions and 63 deletions.
19 changes: 19 additions & 0 deletions python/deltalake/licenses/LICENSE.txt
@@ -0,0 +1,19 @@
Copyright (c) 2020 Ritchie Vink

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
110 changes: 81 additions & 29 deletions python/deltalake/schema.py
@@ -1,9 +1,9 @@
from typing import TYPE_CHECKING, Tuple, Union

import pyarrow as pa

if TYPE_CHECKING:
    import pandas as pd
    pass

from ._internal import ArrayType as ArrayType
from ._internal import Field as Field
@@ -17,34 +17,86 @@
DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"]


def delta_arrow_schema_from_pandas(
    data: "pd.DataFrame",
) -> Tuple[pa.Table, pa.Schema]:
    """
    Infers the schema for the delta table from the Pandas DataFrame.
    Necessary because of issues such as: https://github.com/delta-io/delta-rs/issues/686

    Args:
        data: Data to write.

    Returns:
        A PyArrow Table and the inferred schema for the Delta Table
    """
    table = pa.Table.from_pandas(data)
    schema = table.schema
    schema_out = []
    for field in schema:
        if isinstance(field.type, pa.TimestampType):
            f = pa.field(
                name=field.name,
                type=pa.timestamp("us"),
                nullable=field.nullable,
                metadata=field.metadata,
            )
            schema_out.append(f)
        else:
            schema_out.append(field)
    schema = pa.schema(schema_out, metadata=schema.metadata)
    table = table.cast(target_schema=schema)
    return table, schema


### Inspired by the Pola-rs repo - licensed with MIT License, see licenses folder. ###
def _convert_pa_schema_to_delta(
    schema: pa.Schema, large_dtypes: bool = False
) -> pa.Schema:
    """Convert a PyArrow schema to a schema compatible with Delta Lake. Converts unsigned
    types to their signed equivalents and all timestamps to `us` precision. The boolean
    flag `large_dtypes` controls whether the schema keeps large types.

    Args:
        schema: Source schema
        large_dtypes: If True, the large PyArrow types (e.g. large_string) are kept in the schema
    """
    dtype_map = {
        pa.uint8(): pa.int8(),
        pa.uint16(): pa.int16(),
        pa.uint32(): pa.int32(),
        pa.uint64(): pa.int64(),
    }
    if not large_dtypes:
        dtype_map = {
            **dtype_map,
            **{pa.large_string(): pa.string(), pa.large_binary(): pa.binary()},
        }

    def dtype_to_delta_dtype(dtype: pa.DataType) -> pa.DataType:
        # Handle nested types
        if isinstance(dtype, pa.LargeListType):
            return list_to_delta_dtype(dtype)
        elif isinstance(dtype, pa.StructType):
            return struct_to_delta_dtype(dtype)
        elif isinstance(dtype, pa.TimestampType):
            return pa.timestamp("us")
        try:
            return dtype_map[dtype]
        except KeyError:
            return dtype

    def list_to_delta_dtype(
        dtype: pa.LargeListType,
    ) -> Union[pa.LargeListType, pa.ListType]:
        nested_dtype = dtype.value_type
        nested_dtype_cast = dtype_to_delta_dtype(nested_dtype)
        if large_dtypes:
            return pa.large_list(nested_dtype_cast)
        else:
            return pa.list_(nested_dtype_cast)

    def struct_to_delta_dtype(dtype: pa.StructType) -> pa.StructType:
        fields = [dtype.field(i) for i in range(dtype.num_fields)]
        fields_cast = [pa.field(f.name, dtype_to_delta_dtype(f.type)) for f in fields]
        return pa.struct(fields_cast)

    return pa.schema([pa.field(f.name, dtype_to_delta_dtype(f.type)) for f in schema])


def convert_pyarrow_recordbatchreader(
    data: pa.RecordBatchReader, large_dtypes: bool
) -> Tuple[pa.RecordBatchReader, pa.Schema]:
    """Converts a PyArrow RecordBatchReader to a PyArrow RecordBatchReader with a compatible delta schema, returned together with that schema"""
    schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)
    data = pa.RecordBatchReader.from_batches(
        schema,
        data.read_all().cast(schema).to_batches(),
    )
    return data, schema


def convert_pyarrow_recordbatch(
    data: pa.RecordBatch, large_dtypes: bool
) -> Tuple[pa.RecordBatchReader, pa.Schema]:
    """Converts a PyArrow RecordBatch to a PyArrow RecordBatchReader with a compatible delta schema, returned together with that schema"""
    schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)
    data = pa.Table.from_batches([data]).cast(schema).to_reader()
    return data, schema


def convert_pyarrow_table(
    data: pa.Table, large_dtypes: bool
) -> Tuple[pa.RecordBatchReader, pa.Schema]:
    """Converts a PyArrow Table to a PyArrow RecordBatchReader with a compatible delta schema, returned together with that schema"""
    schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)
    data = data.cast(schema).to_reader()
    return data, schema
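
For reference, a minimal sketch of what the ported conversion does under this commit's module layout (the example schema is illustrative; _convert_pa_schema_to_delta is a private helper):

import pyarrow as pa

from deltalake.schema import _convert_pa_schema_to_delta

# An illustrative schema mixing unsigned, large, and nested types
source_schema = pa.schema(
    [
        pa.field("id", pa.uint32()),                  # unsigned -> signed
        pa.field("name", pa.large_string()),          # large -> normal when large_dtypes=False
        pa.field("ts", pa.timestamp("ns")),           # any unit -> "us"
        pa.field("tags", pa.large_list(pa.uint8())),  # nested value types converted too
    ]
)

converted = _convert_pa_schema_to_delta(source_schema, large_dtypes=False)
# converted: id int32, name string, ts timestamp[us], tags list<item: int8>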
31 changes: 24 additions & 7 deletions python/deltalake/table.py
@@ -594,7 +594,12 @@ def optimize(

def merge(
    self,
    source: Union[pyarrow.Table, pyarrow.RecordBatch, pyarrow.RecordBatchReader],
    source: Union[
        pyarrow.Table,
        pyarrow.RecordBatch,
        pyarrow.RecordBatchReader,
        "pandas.DataFrame",
    ],
    predicate: str,
    source_alias: Optional[str] = None,
    target_alias: Optional[str] = None,
@@ -617,17 +622,29 @@ def merge(
    invariants = self.schema().invariants
    checker = _DeltaDataChecker(invariants)

    from .schema import (
        convert_pyarrow_recordbatch,
        convert_pyarrow_recordbatchreader,
        convert_pyarrow_table,
    )

    if isinstance(source, pyarrow.RecordBatchReader):
        schema = source.schema
        source, schema = convert_pyarrow_recordbatchreader(
            source, large_dtypes=False
        )
    elif isinstance(source, pyarrow.RecordBatch):
        schema = source.schema
        source = [source]
        source, schema = convert_pyarrow_recordbatch(
            source, large_dtypes=False
        )  # TODO(ion): set to True once MERGE uses logical plan
    elif isinstance(source, pyarrow.Table):
        schema = source.schema
        source = source.to_reader()
        source, schema = convert_pyarrow_table(source, large_dtypes=False)
    elif isinstance(source, pandas.DataFrame):
        source, schema = convert_pyarrow_table(
            pyarrow.Table.from_pandas(source), large_dtypes=False
        )
    else:
        raise TypeError(
            f"{type(source).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch or Table are valid inputs for source."
            f"{type(source).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Table or Pandas DataFrame are valid inputs for source."
        )

    def validate_batch(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch:
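
With the widened source type above, merge can now be fed a pandas DataFrame directly. A hedged sketch of the call (table path, column names, and update expressions are illustrative):

import pandas as pd

from deltalake import DeltaTable

dt = DeltaTable("path/to/table")  # assumes an existing Delta table with id/value columns
source = pd.DataFrame({"id": [1, 2], "value": ["a", "b"]})

(
    dt.merge(
        source=source,
        predicate="target.id = source.id",
        source_alias="source",
        target_alias="target",
    )
    .when_matched_update(updates={"value": "source.value"})
    .when_not_matched_insert(updates={"id": "source.id", "value": "source.value"})
    .execute()
)

Internally the DataFrame is routed through convert_pyarrow_table(pyarrow.Table.from_pandas(source), large_dtypes=False), so the source schema the merge sees is already Delta-compatible.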
54 changes: 27 additions & 27 deletions python/deltalake/writer.py
@@ -34,8 +34,6 @@
import pyarrow.fs as pa_fs
from pyarrow.lib import RecordBatchReader

from deltalake.schema import delta_arrow_schema_from_pandas

from ._internal import DeltaDataChecker as _DeltaDataChecker
from ._internal import batch_distinct
from ._internal import write_new_deltalake as _write_new_deltalake
@@ -158,27 +156,40 @@ def write_deltalake(
        overwrite_schema: If True, allows updating the schema of the table.
        storage_options: options passed to the native delta filesystem. Unused if 'filesystem' is defined.
        partition_filters: the partition filters that will be used for partition overwrite.
        large_dtypes: If True, the table schema is checked against large_dtypes
        large_dtypes: If True, the data schema is kept with large types (e.g. large_string); has no effect on pandas DataFrame input
    """
    if _has_pandas and isinstance(data, pd.DataFrame):
        if schema is not None:
            data = pa.Table.from_pandas(data, schema=schema)
        else:
            data, schema = delta_arrow_schema_from_pandas(data)
    from .schema import (
        convert_pyarrow_recordbatch,
        convert_pyarrow_recordbatchreader,
        convert_pyarrow_table,
    )

    table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options)

    # We need to write against the latest table version
    if table:
        table.update_incremental()

    if schema is None:
        if isinstance(data, RecordBatchReader):
            schema = data.schema
        elif isinstance(data, Iterable):
            raise ValueError("You must provide schema if data is Iterable")
        else:
            schema = data.schema
    if isinstance(data, RecordBatchReader):
        data, schema = convert_pyarrow_recordbatchreader(data, large_dtypes)
    elif isinstance(data, pa.RecordBatch):
        data, schema = convert_pyarrow_recordbatch(data, large_dtypes)
    elif isinstance(data, pa.Table):
        data, schema = convert_pyarrow_table(data, large_dtypes)
    elif isinstance(data, ds.Dataset):
        data, schema = convert_pyarrow_table(data.to_table(), large_dtypes)
    elif _has_pandas and isinstance(data, pd.DataFrame):
        if schema is not None:
            data = pa.Table.from_pandas(data, schema=schema)
        else:
            data, schema = convert_pyarrow_table(pa.Table.from_pandas(data), False)
    elif isinstance(data, Iterable):
        if schema is None:
            raise ValueError("You must provide schema if data is Iterable")
    else:
        raise TypeError(
            f"{type(data).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Iterable[RecordBatch], Table, Dataset or Pandas DataFrame are valid inputs for data."
        )

    if filesystem is not None:
        raise NotImplementedError("Filesystem support is not yet implemented. #570")
@@ -225,7 +236,7 @@ def write_deltalake(
        current_version = -1

    dtype_map = {
        pa.large_string(): pa.string(),  # type: ignore
        pa.large_string(): pa.string(),
    }

    def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType:
@@ -327,19 +338,8 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:

        return batch

    if isinstance(data, RecordBatchReader):
        batch_iter = data
    elif isinstance(data, pa.RecordBatch):
        batch_iter = [data]
    elif isinstance(data, pa.Table):
        batch_iter = data.to_batches()
    elif isinstance(data, ds.Dataset):
        batch_iter = data.to_batches()
    else:
        batch_iter = data

    data = RecordBatchReader.from_batches(
        schema, (validate_batch(batch) for batch in batch_iter)
        schema, (validate_batch(batch) for batch in data)
    )

    if file_options is not None:
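
A short sketch of the writer-side effect of large_dtypes after this change (paths and data are illustrative):

import pyarrow as pa

from deltalake import write_deltalake

table = pa.table({"name": pa.array(["a", "b"], type=pa.large_string())})

# large_string is downcast to string before the data reaches the writer
write_deltalake("path/to/tbl_normal", table, large_dtypes=False)

# the Arrow schema keeps large_string through the write path
write_deltalake("path/to/tbl_large", table, large_dtypes=True)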
13 changes: 13 additions & 0 deletions python/stubs/pyarrow/__init__.pyi
@@ -19,10 +19,23 @@ type_for_alias: Any
date32: Any
date64: Any
decimal128: Any
int8: Any
int16: Any
int32: Any
int64: Any
uint8: Any
uint16: Any
uint32: Any
uint64: Any
float16: Any
float32: Any
float64: Any
large_string: Any
string: Any
large_binary: Any
binary: Any
large_list: Any
LargeListType: Any
dictionary: Any
timestamp: Any
TimestampType: Any
