From c35e82962f3fa3caf4f2337e4df009a6d76bf0d8 Mon Sep 17 00:00:00 2001 From: Mike Boss Date: Fri, 5 Apr 2024 15:19:24 +0200 Subject: [PATCH 1/5] get_inferred_type calls pa.array(self) which is the same as pa.array(types_sequence) --- src/datasets/arrow_writer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 82e72a91ecc..08fb680ce65 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -561,8 +561,9 @@ def write_batch( else: col_try_type = try_features[col] if try_features is not None and col in try_features else None typed_sequence = OptimizedTypedSequence(col_values, type=col_type, try_type=col_try_type, col=col) - arrays.append(pa.array(typed_sequence)) - inferred_features[col] = typed_sequence.get_inferred_type() + array = pa.array(typed_sequence) + arrays.append(array) + inferred_features[col] = generate_from_arrow_type(array.type) schema = inferred_features.arrow_schema if self.pa_writer is None else self.schema pa_table = pa.Table.from_arrays(arrays, schema=schema) self.write_table(pa_table, writer_batch_size) From 052443de97825c75f92335cc56adde36ab2652c9 Mon Sep 17 00:00:00 2001 From: Mike Boss Date: Fri, 5 Apr 2024 17:58:38 +0200 Subject: [PATCH 2/5] change cast_storage of image to directly convert pyarrow array instead of going to pylist --- src/datasets/features/image.py | 74 ++++++++++++++++++++++++++-------- 1 file changed, 58 insertions(+), 16 deletions(-) diff --git a/src/datasets/features/image.py b/src/datasets/features/image.py index 8d573bb1021..70b69dc071c 100644 --- a/src/datasets/features/image.py +++ b/src/datasets/features/image.py @@ -225,31 +225,73 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr `pa.StructArray`: Array in the Image arrow storage type, that is `pa.struct({"bytes": pa.binary(), "path": pa.string()})`. """ - if pa.types.is_string(storage.type): - bytes_array = pa.array([None] * len(storage), type=pa.binary()) - storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) - elif pa.types.is_binary(storage.type): - path_array = pa.array([None] * len(storage), type=pa.string()) - storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null()) - elif pa.types.is_struct(storage.type): - if storage.type.get_field_index("bytes") >= 0: - bytes_array = storage.field("bytes") - else: + if hasattr(storage, "type"): + if pa.types.is_string(storage.type): bytes_array = pa.array([None] * len(storage), type=pa.binary()) - if storage.type.get_field_index("path") >= 0: - path_array = storage.field("path") - else: + storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_binary(storage.type): + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_struct(storage.type): + if storage.type.get_field_index("bytes") >= 0: + bytes_array = storage.field("bytes") + else: + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + if storage.type.get_field_index("path") >= 0: + path_array = storage.field("path") + else: + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays( + [bytes_array, path_array], ["bytes", "path"], mask=storage.is_null() + ) + elif pa.types.is_list(storage.type): + from .features import Array2DExtensionType, Array3DExtensionType + + def get_shapes(arr): + shape = () + while isinstance(arr, pa.ListArray): + len_curr = len(arr) + arr = arr.flatten() + len_new = len(arr) + shape = shape + (len_new // len_curr,) + return shape + + def get_dtypes(arr): + dtype = storage.type + while hasattr(dtype, "value_type"): + dtype = dtype.value_type + return dtype + + shapes = [get_shapes(storage.take([i])) for i in range(len(storage))] + dtypes = [get_dtypes(storage.take([i])) for i in range(len(storage))] + + extension_types = [ + Array3DExtensionType(shape=shape, dtype=str(dtype)) + for shape, dtype in zip(shapes, dtypes, strict=True) + ] + arrays = [ + pa.ExtensionArray.from_storage(extension_type, storage.take([i])).to_numpy().squeeze(0) + for i, extension_type in enumerate(extension_types) + ] + + bytes_array = pa.array( + [encode_np_array(arr)["bytes"] if arr is not None else None for arr in arrays], + type=pa.binary(), + ) path_array = pa.array([None] * len(storage), type=pa.string()) - storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null()) - elif pa.types.is_list(storage.type): + storage = pa.StructArray.from_arrays( + [bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null() + ) + else: bytes_array = pa.array( - [encode_np_array(np.array(arr))["bytes"] if arr is not None else None for arr in storage.to_pylist()], + [encode_np_array(arr)["bytes"] if arr is not None else None for arr in storage], type=pa.binary(), ) path_array = pa.array([None] * len(storage), type=pa.string()) storage = pa.StructArray.from_arrays( [bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null() ) + return array_cast(storage, self.pa_type) def embed_storage(self, storage: pa.StructArray) -> pa.StructArray: From 3e80a3455e10f5b371da0ab873bc2f69704f2337 Mon Sep 17 00:00:00 2001 From: Mike Boss Date: Fri, 5 Apr 2024 18:53:14 +0200 Subject: [PATCH 3/5] faster Image cast_storage for pa.ListArray --- src/datasets/features/image.py | 59 ++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/src/datasets/features/image.py b/src/datasets/features/image.py index 70b69dc071c..a2c08d6c817 100644 --- a/src/datasets/features/image.py +++ b/src/datasets/features/image.py @@ -245,34 +245,20 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr [bytes_array, path_array], ["bytes", "path"], mask=storage.is_null() ) elif pa.types.is_list(storage.type): - from .features import Array2DExtensionType, Array3DExtensionType - - def get_shapes(arr): - shape = () - while isinstance(arr, pa.ListArray): - len_curr = len(arr) - arr = arr.flatten() - len_new = len(arr) - shape = shape + (len_new // len_curr,) - return shape - - def get_dtypes(arr): - dtype = storage.type - while hasattr(dtype, "value_type"): - dtype = dtype.value_type - return dtype - - shapes = [get_shapes(storage.take([i])) for i in range(len(storage))] - dtypes = [get_dtypes(storage.take([i])) for i in range(len(storage))] - - extension_types = [ - Array3DExtensionType(shape=shape, dtype=str(dtype)) - for shape, dtype in zip(shapes, dtypes, strict=True) - ] - arrays = [ - pa.ExtensionArray.from_storage(extension_type, storage.take([i])).to_numpy().squeeze(0) - for i, extension_type in enumerate(extension_types) - ] + from .features import Array3DExtensionType + + arrays = [] + for i, is_null in enumerate(storage.is_null()): + if is_null: + arrays.append(None) + else: + storage_part = storage.take([i]) + shape = get_shapes_from_listarray(storage_part) + dtype = get_dtypes_from_listarray(storage_part) + + extension_type = Array3DExtensionType(shape=shape, dtype=str(dtype)) + array = pa.ExtensionArray.from_storage(extension_type, storage_part) + arrays.append(array.to_numpy().squeeze(0)) bytes_array = pa.array( [encode_np_array(arr)["bytes"] if arr is not None else None for arr in arrays], @@ -327,6 +313,23 @@ def path_to_bytes(path): return array_cast(storage, self.pa_type) +def get_shapes_from_listarray(listarray: pa.ListArray): + shape = () + while isinstance(listarray, pa.ListArray): + len_curr = len(listarray) + listarray = listarray.flatten() + len_new = len(listarray) + shape = shape + (len_new // len_curr,) + return shape + + +def get_dtypes_from_listarray(listarray: pa.ListArray): + dtype = listarray.type + while hasattr(dtype, "value_type"): + dtype = dtype.value_type + return dtype + + def list_image_compression_formats() -> List[str]: if config.PIL_AVAILABLE: import PIL.Image From 7005edcc1cede87200a4751f06b8c014470097f8 Mon Sep 17 00:00:00 2001 From: Mike Boss Date: Fri, 5 Apr 2024 18:58:30 +0200 Subject: [PATCH 4/5] fix case where image is 2d --- src/datasets/features/image.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/datasets/features/image.py b/src/datasets/features/image.py index a2c08d6c817..05c02daa076 100644 --- a/src/datasets/features/image.py +++ b/src/datasets/features/image.py @@ -245,7 +245,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr [bytes_array, path_array], ["bytes", "path"], mask=storage.is_null() ) elif pa.types.is_list(storage.type): - from .features import Array3DExtensionType + from .features import Array2DExtensionType, Array3DExtensionType arrays = [] for i, is_null in enumerate(storage.is_null()): @@ -256,7 +256,10 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr shape = get_shapes_from_listarray(storage_part) dtype = get_dtypes_from_listarray(storage_part) - extension_type = Array3DExtensionType(shape=shape, dtype=str(dtype)) + if len(shape) == 2: + extension_type = Array2DExtensionType(shape=shape, dtype=str(dtype)) + else: + extension_type = Array3DExtensionType(shape=shape, dtype=str(dtype)) array = pa.ExtensionArray.from_storage(extension_type, storage_part) arrays.append(array.to_numpy().squeeze(0)) From 0551685a98d3c6ba8ab74a29054b574462b47dc9 Mon Sep 17 00:00:00 2001 From: Mike Boss Date: Mon, 8 Apr 2024 11:18:38 +0200 Subject: [PATCH 5/5] pa bool as py bool and revert accidental change --- src/datasets/features/image.py | 78 ++++++++++++++-------------------- 1 file changed, 33 insertions(+), 45 deletions(-) diff --git a/src/datasets/features/image.py b/src/datasets/features/image.py index 05c02daa076..284e3631ee9 100644 --- a/src/datasets/features/image.py +++ b/src/datasets/features/image.py @@ -225,55 +225,43 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr `pa.StructArray`: Array in the Image arrow storage type, that is `pa.struct({"bytes": pa.binary(), "path": pa.string()})`. """ - if hasattr(storage, "type"): - if pa.types.is_string(storage.type): + if pa.types.is_string(storage.type): + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_binary(storage.type): + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_struct(storage.type): + if storage.type.get_field_index("bytes") >= 0: + bytes_array = storage.field("bytes") + else: bytes_array = pa.array([None] * len(storage), type=pa.binary()) - storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) - elif pa.types.is_binary(storage.type): + if storage.type.get_field_index("path") >= 0: + path_array = storage.field("path") + else: path_array = pa.array([None] * len(storage), type=pa.string()) - storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null()) - elif pa.types.is_struct(storage.type): - if storage.type.get_field_index("bytes") >= 0: - bytes_array = storage.field("bytes") - else: - bytes_array = pa.array([None] * len(storage), type=pa.binary()) - if storage.type.get_field_index("path") >= 0: - path_array = storage.field("path") - else: - path_array = pa.array([None] * len(storage), type=pa.string()) - storage = pa.StructArray.from_arrays( - [bytes_array, path_array], ["bytes", "path"], mask=storage.is_null() - ) - elif pa.types.is_list(storage.type): - from .features import Array2DExtensionType, Array3DExtensionType - - arrays = [] - for i, is_null in enumerate(storage.is_null()): - if is_null: - arrays.append(None) + storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_list(storage.type): + from .features import Array2DExtensionType, Array3DExtensionType + + arrays = [] + for i, is_null in enumerate(storage.is_null()): + if not is_null.as_py(): + storage_part = storage.take([i]) + shape = get_shapes_from_listarray(storage_part) + dtype = get_dtypes_from_listarray(storage_part) + + if len(shape) == 2: + extension_type = Array2DExtensionType(shape=shape, dtype=str(dtype)) else: - storage_part = storage.take([i]) - shape = get_shapes_from_listarray(storage_part) - dtype = get_dtypes_from_listarray(storage_part) - - if len(shape) == 2: - extension_type = Array2DExtensionType(shape=shape, dtype=str(dtype)) - else: - extension_type = Array3DExtensionType(shape=shape, dtype=str(dtype)) - array = pa.ExtensionArray.from_storage(extension_type, storage_part) - arrays.append(array.to_numpy().squeeze(0)) - - bytes_array = pa.array( - [encode_np_array(arr)["bytes"] if arr is not None else None for arr in arrays], - type=pa.binary(), - ) - path_array = pa.array([None] * len(storage), type=pa.string()) - storage = pa.StructArray.from_arrays( - [bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null() - ) - else: + extension_type = Array3DExtensionType(shape=shape, dtype=str(dtype)) + array = pa.ExtensionArray.from_storage(extension_type, storage_part) + arrays.append(array.to_numpy().squeeze(0)) + else: + arrays.append(None) + bytes_array = pa.array( - [encode_np_array(arr)["bytes"] if arr is not None else None for arr in storage], + [encode_np_array(arr)["bytes"] if arr is not None else None for arr in arrays], type=pa.binary(), ) path_array = pa.array([None] * len(storage), type=pa.string())