Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Image cast storage faster #6786

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
5 changes: 3 additions & 2 deletions src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,9 @@ def write_batch(
else:
col_try_type = try_features[col] if try_features is not None and col in try_features else None
typed_sequence = OptimizedTypedSequence(col_values, type=col_type, try_type=col_try_type, col=col)
arrays.append(pa.array(typed_sequence))
inferred_features[col] = typed_sequence.get_inferred_type()
array = pa.array(typed_sequence)
arrays.append(array)
inferred_features[col] = generate_from_arrow_type(array.type)
schema = inferred_features.arrow_schema if self.pa_writer is None else self.schema
pa_table = pa.Table.from_arrays(arrays, schema=schema)
self.write_table(pa_table, writer_batch_size)
Expand Down
38 changes: 37 additions & 1 deletion src/datasets/features/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,33 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr
path_array = pa.array([None] * len(storage), type=pa.string())
storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null())
elif pa.types.is_list(storage.type):
from .features import Array2DExtensionType, Array3DExtensionType

arrays = []
for i, is_null in enumerate(storage.is_null()):
if not is_null.as_py():
storage_part = storage.take([i])
shape = get_shapes_from_listarray(storage_part)
dtype = get_dtypes_from_listarray(storage_part)

if len(shape) == 2:
extension_type = Array2DExtensionType(shape=shape, dtype=str(dtype))
else:
extension_type = Array3DExtensionType(shape=shape, dtype=str(dtype))
array = pa.ExtensionArray.from_storage(extension_type, storage_part)
arrays.append(array.to_numpy().squeeze(0))
else:
arrays.append(None)

bytes_array = pa.array(
[encode_np_array(np.array(arr))["bytes"] if arr is not None else None for arr in storage.to_pylist()],
[encode_np_array(arr)["bytes"] if arr is not None else None for arr in arrays],
type=pa.binary(),
)
path_array = pa.array([None] * len(storage), type=pa.string())
storage = pa.StructArray.from_arrays(
[bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null()
)

return array_cast(storage, self.pa_type)

def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
Expand Down Expand Up @@ -284,6 +303,23 @@ def path_to_bytes(path):
return array_cast(storage, self.pa_type)


def get_shapes_from_listarray(listarray: pa.ListArray):
shape = ()
while isinstance(listarray, pa.ListArray):
len_curr = len(listarray)
listarray = listarray.flatten()
len_new = len(listarray)
shape = shape + (len_new // len_curr,)
return shape


def get_dtypes_from_listarray(listarray: pa.ListArray):
dtype = listarray.type
while hasattr(dtype, "value_type"):
dtype = dtype.value_type
return dtype


def list_image_compression_formats() -> List[str]:
if config.PIL_AVAILABLE:
import PIL.Image
Expand Down
Loading