diff --git a/python/deltalake/licenses/LICENSE.txt b/python/deltalake/licenses/LICENSE.txt
new file mode 100644
index 0000000000..06d01f6abf
--- /dev/null
+++ b/python/deltalake/licenses/LICENSE.txt
@@ -0,0 +1,19 @@
+Copyright (c) 2020 Ritchie Vink
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/python/deltalake/schema.py b/python/deltalake/schema.py
index e6854c3779..778c41a8b8 100644
--- a/python/deltalake/schema.py
+++ b/python/deltalake/schema.py
@@ -1,9 +1,9 @@
 from typing import TYPE_CHECKING, Tuple, Union
 
 import pyarrow as pa
 
 if TYPE_CHECKING:
-    import pandas as pd
+    pass
 
 from ._internal import ArrayType as ArrayType
 from ._internal import Field as Field
@@ -17,34 +17,86 @@
 DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"]
 
 
-def delta_arrow_schema_from_pandas(
-    data: "pd.DataFrame",
-) -> Tuple[pa.Table, pa.Schema]:
-    """
-    Infers the schema for the delta table from the Pandas DataFrame.
-    Necessary because of issues such as: https://github.com/delta-io/delta-rs/issues/686
-
-    Args:
-        data: Data to write.
+# Inspired by the Pola-rs repo - licensed under the MIT License, see the licenses folder.
+def _convert_pa_schema_to_delta(
+    schema: pa.Schema, large_dtypes: bool = False
+) -> pa.Schema:
+    """Convert a PyArrow schema to a schema compatible with Delta Lake. Converts unsigned
+    types to their signed equivalents and all timestamps to `us` precision. The boolean
+    flag large_dtypes controls whether large Arrow types are kept in the schema.
 
-    Returns:
-        A PyArrow Table and the inferred schema for the Delta Table
+    Args:
+        schema: Source schema
+        large_dtypes: If True, the large Arrow types (large_string, large_binary,
+            large_list) are kept; otherwise they are converted to their standard
+            counterparts
     """
+    dtype_map = {
+        pa.uint8(): pa.int8(),
+        pa.uint16(): pa.int16(),
+        pa.uint32(): pa.int32(),
+        pa.uint64(): pa.int64(),
+    }
+    if not large_dtypes:
+        dtype_map = {
+            **dtype_map,
+            **{pa.large_string(): pa.string(), pa.large_binary(): pa.binary()},
+        }
+
+    def dtype_to_delta_dtype(dtype: pa.DataType) -> pa.DataType:
+        # Handle nested types
+        if isinstance(dtype, pa.LargeListType):
+            return list_to_delta_dtype(dtype)
+        elif isinstance(dtype, pa.StructType):
+            return struct_to_delta_dtype(dtype)
+        elif isinstance(dtype, pa.TimestampType):
+            return pa.timestamp("us")
+        try:
+            return dtype_map[dtype]
+        except KeyError:
+            return dtype
 
-    table = pa.Table.from_pandas(data)
-    schema = table.schema
-    schema_out = []
-    for field in schema:
-        if isinstance(field.type, pa.TimestampType):
-            f = pa.field(
-                name=field.name,
-                type=pa.timestamp("us"),
-                nullable=field.nullable,
-                metadata=field.metadata,
-            )
-            schema_out.append(f)
+    def list_to_delta_dtype(
+        dtype: pa.LargeListType,
+    ) -> Union[pa.LargeListType, pa.ListType]:
+        nested_dtype = dtype.value_type
+        nested_dtype_cast = dtype_to_delta_dtype(nested_dtype)
+        if large_dtypes:
+            return pa.large_list(nested_dtype_cast)
         else:
-            schema_out.append(field)
-    schema = pa.schema(schema_out, metadata=schema.metadata)
-    table = table.cast(target_schema=schema)
-    return table, schema
+            return pa.list_(nested_dtype_cast)
+
+    def struct_to_delta_dtype(dtype: pa.StructType) -> pa.StructType:
+        fields = [dtype.field(i) for i in range(dtype.num_fields)]
+        fields_cast = [pa.field(f.name, dtype_to_delta_dtype(f.type)) for f in fields]
+        return pa.struct(fields_cast)
+
+    return pa.schema([pa.field(f.name, dtype_to_delta_dtype(f.type)) for f in schema])
+
+
+def convert_pyarrow_recordbatchreader(
+    data: pa.RecordBatchReader, large_dtypes: bool
+) -> Tuple[pa.RecordBatchReader, pa.Schema]:
+    """Converts a PyArrow RecordBatchReader to a RecordBatchReader with a Delta-compatible schema, returning the reader and that schema"""
+    schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)
+    data = pa.RecordBatchReader.from_batches(
+        schema,
+        data.read_all().cast(schema).to_batches(),
+    )
+    return data, schema
+
+
+def convert_pyarrow_recordbatch(
+    data: pa.RecordBatch, large_dtypes: bool
+) -> Tuple[pa.RecordBatchReader, pa.Schema]:
+    """Converts a PyArrow RecordBatch to a RecordBatchReader with a Delta-compatible schema, returning the reader and that schema"""
+    schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)
+    data = pa.Table.from_batches([data]).cast(schema).to_reader()
+    return data, schema
+
+
+def convert_pyarrow_table(
+    data: pa.Table, large_dtypes: bool
+) -> Tuple[pa.RecordBatchReader, pa.Schema]:
+    """Converts a PyArrow Table to a RecordBatchReader with a Delta-compatible schema, returning the reader and that schema"""
+    schema = _convert_pa_schema_to_delta(data.schema, large_dtypes=large_dtypes)
+    data = data.cast(schema).to_reader()
+    return data, schema
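For reference, a minimal sketch of what the new conversion does, assuming only pyarrow and the patched deltalake.schema module (the helper is private, so this is illustration rather than public API):

    import pyarrow as pa

    from deltalake.schema import _convert_pa_schema_to_delta

    # Unsigned integers become their signed equivalents, large_string collapses to
    # string when large_dtypes=False, and any timestamp unit is normalized to "us".
    source = pa.schema(
        [
            pa.field("id", pa.uint32()),
            pa.field("name", pa.large_string()),
            pa.field("ts", pa.timestamp("ns")),
        ]
    )
    print(_convert_pa_schema_to_delta(source))
    # id: int32
    # name: string
    # ts: timestamp[us]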
diff --git a/python/deltalake/table.py b/python/deltalake/table.py
index ad82a010fd..0a5e83c493 100644
--- a/python/deltalake/table.py
+++ b/python/deltalake/table.py
@@ -594,7 +594,12 @@ def optimize(
 
     def merge(
         self,
-        source: Union[pyarrow.Table, pyarrow.RecordBatch, pyarrow.RecordBatchReader],
+        source: Union[
+            pyarrow.Table,
+            pyarrow.RecordBatch,
+            pyarrow.RecordBatchReader,
+            "pandas.DataFrame",
+        ],
         predicate: str,
         source_alias: Optional[str] = None,
         target_alias: Optional[str] = None,
@@ -617,17 +622,29 @@ def merge(
         invariants = self.schema().invariants
         checker = _DeltaDataChecker(invariants)
 
+        from .schema import (
+            convert_pyarrow_recordbatch,
+            convert_pyarrow_recordbatchreader,
+            convert_pyarrow_table,
+        )
+
         if isinstance(source, pyarrow.RecordBatchReader):
-            schema = source.schema
+            source, schema = convert_pyarrow_recordbatchreader(
+                source, large_dtypes=False
+            )
         elif isinstance(source, pyarrow.RecordBatch):
-            schema = source.schema
-            source = [source]
+            source, schema = convert_pyarrow_recordbatch(
+                source, large_dtypes=False
+            )  # TODO(ion): set to True once MERGE uses logical plan
         elif isinstance(source, pyarrow.Table):
-            schema = source.schema
-            source = source.to_reader()
+            source, schema = convert_pyarrow_table(source, large_dtypes=False)
+        elif isinstance(source, pandas.DataFrame):
+            source, schema = convert_pyarrow_table(
+                pyarrow.Table.from_pandas(source), large_dtypes=False
+            )
         else:
             raise TypeError(
-                f"{type(source).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch or Table are valid inputs for source."
+                f"{type(source).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Table or Pandas DataFrame are valid inputs for source."
             )
 
         def validate_batch(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch:
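With the change above, DeltaTable.merge accepts a pandas DataFrame directly; it is converted through pyarrow.Table.from_pandas and the schema helpers before reaching the Rust side. A hedged usage sketch, where the table path is a placeholder and the builder calls (when_matched_update_all, when_not_matched_insert_all, execute) are the existing TableMerger API that this patch does not touch:

    import pandas as pd

    from deltalake import DeltaTable

    dt = DeltaTable("path/to/table")  # placeholder path
    source = pd.DataFrame({"id": [1, 2], "value": ["a", "b"]})

    (
        dt.merge(
            source=source,
            predicate="t.id = s.id",
            source_alias="s",
            target_alias="t",
        )
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .execute()
    )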
diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py
index ef4ae3a57b..ff9819fd6a 100644
--- a/python/deltalake/writer.py
+++ b/python/deltalake/writer.py
@@ -34,8 +34,6 @@
 import pyarrow.fs as pa_fs
 from pyarrow.lib import RecordBatchReader
 
-from deltalake.schema import delta_arrow_schema_from_pandas
-
 from ._internal import DeltaDataChecker as _DeltaDataChecker
 from ._internal import batch_distinct
 from ._internal import write_new_deltalake as _write_new_deltalake
@@ -158,13 +156,13 @@ def write_deltalake(
         overwrite_schema: If True, allows updating the schema of the table.
         storage_options: options passed to the native delta filesystem. Unused if 'filesystem' is defined.
         partition_filters: the partition filters that will be used for partition overwrite.
-        large_dtypes: If True, the table schema is checked against large_dtypes
+        large_dtypes: If True, the large Arrow types (large_string, large_binary, large_list) are kept in the data schema; has no effect on pandas DataFrame input
     """
-    if _has_pandas and isinstance(data, pd.DataFrame):
-        if schema is not None:
-            data = pa.Table.from_pandas(data, schema=schema)
-        else:
-            data, schema = delta_arrow_schema_from_pandas(data)
+    from .schema import (
+        convert_pyarrow_recordbatch,
+        convert_pyarrow_recordbatchreader,
+        convert_pyarrow_table,
+    )
 
     table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options)
 
@@ -172,13 +170,26 @@
     if table:
         table.update_incremental()
 
-    if schema is None:
-        if isinstance(data, RecordBatchReader):
-            schema = data.schema
-        elif isinstance(data, Iterable):
-            raise ValueError("You must provide schema if data is Iterable")
+    if isinstance(data, RecordBatchReader):
+        data, schema = convert_pyarrow_recordbatchreader(data, large_dtypes)
+    elif isinstance(data, pa.RecordBatch):
+        data, schema = convert_pyarrow_recordbatch(data, large_dtypes)
+    elif isinstance(data, pa.Table):
+        data, schema = convert_pyarrow_table(data, large_dtypes)
+    elif isinstance(data, ds.Dataset):
+        data, schema = convert_pyarrow_table(data.to_table(), large_dtypes)
+    elif _has_pandas and isinstance(data, pd.DataFrame):
+        if schema is not None:
+            data = pa.Table.from_pandas(data, schema=schema)
         else:
-            schema = data.schema
+            data, schema = convert_pyarrow_table(pa.Table.from_pandas(data), False)
+    elif isinstance(data, Iterable):
+        if schema is None:
+            raise ValueError("You must provide schema if data is Iterable")
+    else:
+        raise TypeError(
+            f"{type(data).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Iterable[RecordBatch], Table, Dataset or Pandas DataFrame are valid inputs for source."
+        )
 
     if filesystem is not None:
         raise NotImplementedError("Filesystem support is not yet implemented. #570")
@@ -225,7 +236,7 @@
     current_version = -1
 
     dtype_map = {
-        pa.large_string(): pa.string(),  # type: ignore
+        pa.large_string(): pa.string(),
    }
 
     def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType:
@@ -327,19 +338,8 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
         return batch
 
-    if isinstance(data, RecordBatchReader):
-        batch_iter = data
-    elif isinstance(data, pa.RecordBatch):
-        batch_iter = [data]
-    elif isinstance(data, pa.Table):
-        batch_iter = data.to_batches()
-    elif isinstance(data, ds.Dataset):
-        batch_iter = data.to_batches()
-    else:
-        batch_iter = data
-
     data = RecordBatchReader.from_batches(
-        schema, (validate_batch(batch) for batch in batch_iter)
+        schema, (validate_batch(batch) for batch in data)
     )
 
     if file_options is not None:
diff --git a/python/stubs/pyarrow/__init__.pyi b/python/stubs/pyarrow/__init__.pyi
index f8c9d152aa..10edfcf663 100644
--- a/python/stubs/pyarrow/__init__.pyi
+++ b/python/stubs/pyarrow/__init__.pyi
@@ -19,10 +19,23 @@ type_for_alias: Any
 date32: Any
 date64: Any
 decimal128: Any
+int8: Any
+int16: Any
 int32: Any
+int64: Any
+uint8: Any
+uint16: Any
+uint32: Any
+uint64: Any
 float16: Any
 float32: Any
 float64: Any
+large_string: Any
+string: Any
+large_binary: Any
+binary: Any
+large_list: Any
+LargeListType: Any
 dictionary: Any
 timestamp: Any
 TimestampType: Any
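End-to-end, write_deltalake now routes PyArrow and pandas inputs through the same converters, so unsigned and large Arrow types are normalized unless large_dtypes=True is passed. A small sketch with placeholder paths:

    import pyarrow as pa

    from deltalake import write_deltalake

    tbl = pa.table(
        {
            "id": pa.array([1, 2], type=pa.uint8()),
            "name": pa.array(["a", "b"], type=pa.large_string()),
        }
    )

    # Default behaviour: id is stored as int8 and name as string.
    write_deltalake("path/to/table", tbl, mode="overwrite")  # placeholder path

    # Keep the large Arrow types (unsigned ints are still converted to signed).
    write_deltalake("path/to/large_table", tbl, mode="overwrite", large_dtypes=True)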