Skip to content

Commit

Permalink
Feat: Basic datasets getitem (#628)
Browse files Browse the repository at this point in the history
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
codingl2k1 and mergify[bot] authored Jul 28, 2023
1 parent 78f915a commit 02bebe1
Show file tree
Hide file tree
Showing 12 changed files with 440 additions and 46 deletions.
7 changes: 5 additions & 2 deletions doc/source/libraries/xorbits_data/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,15 @@ improve concurrency.
>>> return example
>>> # 10 processes applying `add_prefix` concurrently.
>>> dataset = dataset.map(add_prefix)
>>> # Currently, you have to fetch() to get the dataset info.
>>> dataset.fetch()
>>> dataset
Dataset({
features: ['text', 'label'],
num_rows: 8530
})
>>> dataset[1:3]["text"]
['Xorbits: the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .',
'Xorbits: effective but too-tepid biopic']


Datasets Outputs
----------------
Expand Down
11 changes: 11 additions & 0 deletions python/xorbits/_mars/services/meta/metas.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

import numpy as np

from ....datasets.backends.huggingface.core import (
HuggingfaceDatasetChunk,
HuggingfaceDatasetChunkData,
)
from ...core import OBJECT_CHUNK_TYPE, OBJECT_TYPE
from ...dataframe.core import (
CATEGORICAL_CHUNK_TYPE,
Expand Down Expand Up @@ -191,6 +195,13 @@ class ObjectChunkMeta(_ChunkMeta):
pass


@register_meta_type((HuggingfaceDatasetChunk, HuggingfaceDatasetChunkData))
@dataslots
@dataclass
class DatasetChunkMeta(_ChunkMeta):
shape: Tuple[int] = None


@register_meta_type(DATAFRAME_OR_SERIES_TYPE)
@dataslots
@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,22 @@ def _check_nsplits(self, tiled: TileableType):
% (c.op, c.index, tiled.nsplits)
)
for cid, shape in enumerate(itertools.product(*tiled.nsplits)):
tiled_chunk = tiled.chunks[cid]
if not hasattr(tiled_chunk, "shape"):
continue
chunk_shape = (
self._raw_chunk_shapes.get(tiled.chunks[cid].key)
or tiled.chunks[cid].shape
self._raw_chunk_shapes.get(tiled_chunk.key) or tiled_chunk.shape
)
if len(shape) != len(chunk_shape):
raise AssertionError(
"Operand %r: Shape in nsplits %r does not meet shape in chunk %r"
% (tiled.chunks[cid].op, shape, chunk_shape)
% (tiled_chunk.op, shape, chunk_shape)
)
for s1, s2 in zip(shape, chunk_shape):
if (not (np.isnan(s1) and np.isnan(s2))) and s1 != s2:
raise AssertionError(
"Operand %r: Shape in nsplits %r does not meet shape in chunk %r"
% (tiled.chunks[cid].op, shape, chunk_shape)
% (tiled_chunk.op, shape, chunk_shape)
)

def post_chunk_graph_execution(self):
Expand Down
28 changes: 28 additions & 0 deletions python/xorbits/_mars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,34 @@ def merge_chunks(chunk_results: List[Tuple[Tuple[int], Any]]) -> Any:
if len(result) == 1:
return result[0]
return result
elif type(v) is list:
result = []
for r in chunk_results:
result.extend(r[1])
return result
elif type(v) is dict:
# TODO(codingl2k1) : We should register a merge handler for each output type.
result = {}
chunk_results = [(k, v) for k, v in chunk_results if v]
if len(chunk_results) == 1:
return chunk_results[0][1]
for r in chunk_results:
d = r[1]
if not result:
if not all(
type(key) is str and type(value) is list for key, value in d.items()
):
raise TypeError(
"only support merge dict with type Dict[str, List]."
)
result.update(d)
else:
if d.keys() != result.keys():
raise TypeError(f"unsupported merge dict with different keys.")
else:
for key, value in d.items():
result[key].extend(value)
return result
elif hf_datasets is not None and isinstance(v, hf_datasets.Dataset):
result = [r[1] for r in chunk_results]
return hf_datasets.concatenate_datasets(result)
Expand Down
58 changes: 49 additions & 9 deletions python/xorbits/datasets/backends/huggingface/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Mapping, Optional, Sequence, Union
from typing import Any, Dict, Mapping, Optional, Sequence, Union

try:
# For type hint.
Expand All @@ -26,47 +26,87 @@
Version,
)
except ImportError:
from typing import Any

DownloadConfig = Any
DownloadMode = Any
Features = Any
Split = Any
VerificationMode = Any
Version = Any

from ...._mars.core import is_build_mode
from ...._mars.core.entity import (
OutputType,
register_fetch_class,
register_output_types,
)
from ...._mars.core.entity.utils import refresh_tileable_shape
from ...._mars.core.operand.objects import ObjectFetch
from ...._mars.serialization.serializables import FieldTypes, ListField
from ....utils import check_signature_compatible, get_non_default_kwargs
from ...dataset import Dataset, DatasetChunk, DatasetChunkData, DatasetData
from .getitem import getitem
from .loader import load_huggingface_dataset
from .map import map
from .rechunk import rechunk
from .to_dataframe import to_dataframe


class HuggingfaceDatasetChunkData(DatasetChunkData):
__slots__ = ()
type_name = "HuggingfaceDatasetChunkData"

@classmethod
def get_params_from_data(cls, data) -> Dict[str, Any]:
"""For updating chunk shape from data."""
return {"shape": data.shape}


class HuggingfaceDatasetChunk(DatasetChunk):
__slots__ = ()
_allow_data_type_ = (HuggingfaceDatasetChunkData,)
type_name = "HuggingfaceDatasetChunk"


class HuggingfaceDatasetData(DatasetData):
__slots__ = ()
type_name = "Huggingface Dataset"

_chunks = ListField(
"chunks",
FieldTypes.reference(HuggingfaceDatasetChunk),
on_serialize=lambda x: [it.data for it in x] if x is not None else x,
on_deserialize=lambda x: [HuggingfaceDatasetChunk(it) for it in x]
if x is not None
else x,
)

def __repr__(self):
try:
return f"Dataset({{\n features: {self.dtypes.index.values.tolist()},\n num_rows: {self.shape[0]}\n}})"
except: # noqa: E722 # nosec # pylint: disable=bare-except # pragma: no cover
if is_build_mode() or len(self._executed_sessions) == 0:
# in build mode, or not executed, just return representation
return f"Huggingface Dataset <op={type(self.op).__name__}, key={self.key}>"
else:
try:
return f"Dataset({{\n features: {self.dtypes.index.values.tolist()},\n num_rows: {self.shape[0]}\n}})"
except: # noqa: E722 # nosec # pylint: disable=bare-except # pragma: no cover
return (
f"Huggingface Dataset <op={type(self.op).__name__}, key={self.key}>"
)

def refresh_params(self):
refresh_tileable_shape(self)
# TODO(codingl2k1): update dtypes.

def rechunk(self, num_chunks: int, **kwargs):
return rechunk(self, num_chunks, **kwargs)

def map(self, fn, **kwargs):
return map(self, fn, **kwargs)

def to_dataframe(self):
return to_dataframe(self)
def to_dataframe(self, types_mapper=None):
return to_dataframe(self, types_mapper)

def __getitem__(self, item: Union[int, slice, str]):
return getitem(self, item)


class HuggingfaceDataset(Dataset):
Expand All @@ -81,7 +121,7 @@ def to_dataset(self):
register_output_types(
OutputType.huggingface_dataset,
(HuggingfaceDataset, HuggingfaceDatasetData),
(DatasetChunk, DatasetChunkData),
(HuggingfaceDatasetChunk, HuggingfaceDatasetChunkData),
)


Expand Down
134 changes: 134 additions & 0 deletions python/xorbits/datasets/backends/huggingface/getitem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

from ...._mars.core.entity import OutputType
from ...._mars.serialization.serializables import AnyField
from ...._mars.utils import has_unknown_shape, is_full_slice
from ...operand import DataOperand, DataOperandMixin


class HuggingfaceGetItem(DataOperand, DataOperandMixin):
hf_getitem_key = AnyField("hf_getitem_key")

def __call__(self, dataset):
return self.new_tileable([dataset], **dataset.params)

@classmethod
def _gen_copy_output(cls, inp, op: "HuggingfaceGetItem"):
out_chunks = []
for chunk in inp.chunks:
chunk_op = op.copy().reset_key()
out_chunk = chunk_op.new_chunk([chunk], index=chunk.index)
out_chunks.append(out_chunk)
return op.copy().new_tileable(op.inputs, chunks=out_chunks)

@classmethod
def _gen_empty_output(cls, inp, op: "HuggingfaceGetItem"):
first_chunk = inp.chunks[0]
chunk_op = op.copy().reset_key()
chunk_op.hf_getitem_key = slice(0, 0)
out_chunk = chunk_op.new_chunk([first_chunk], index=first_chunk.index)
return op.copy().new_tileable(op.inputs, chunks=[out_chunk])

@classmethod
def _find_chunk_index_by_key(cls, chunks, key):
for idx, chunk in enumerate(chunks):
if key < 0: # pragma: no cover
raise IndexError(f"Input key {key} is out of bound.")
if key >= chunk.shape[0]:
key -= chunk.shape[0]
else:
return idx, key
# pragma: no cover
raise IndexError(f"Input key {key} is out of bound.")

@classmethod
def tile(cls, op: "HuggingfaceGetItem"):
assert len(op.inputs) == 1
inp = op.inputs[0]

if isinstance(op.hf_getitem_key, str):
return cls._gen_copy_output(inp, op)
elif isinstance(op.hf_getitem_key, int):
if has_unknown_shape(*op.inputs):
yield
index, key = cls._find_chunk_index_by_key(inp.chunks, op.hf_getitem_key)
chunk = inp.chunks[index]
chunk_op = op.copy().reset_key()
chunk_op.hf_getitem_key = key
out_chunk = chunk_op.new_chunk([chunk], index=chunk.index)
return op.copy().new_tileable(op.inputs, chunks=[out_chunk])
elif isinstance(op.hf_getitem_key, slice):
if is_full_slice(op.hf_getitem_key):
return cls._gen_copy_output(inp, op)
else:
start = op.hf_getitem_key.start
stop = op.hf_getitem_key.stop
assert op.hf_getitem_key.step is None
if start >= stop:
# For empty slice, e.g. s[3:1], s[3:3], we translate the
# execution to the first chunk[0:0].
return cls._gen_empty_output(inp, op)
else:
if has_unknown_shape(*op.inputs):
yield
try:
start_index, start_key = cls._find_chunk_index_by_key(
inp.chunks, start
)
except IndexError:
return cls._gen_empty_output(inp, op)
try:
stop_index, stop_key = cls._find_chunk_index_by_key(
inp.chunks, stop
)
except IndexError:
stop_index = len(inp.chunks) - 1
stop_key = None
chunks = []
for index, chunk in enumerate(inp.chunks):
if start_index <= index <= stop_index:
chunk_op = op.copy().reset_key()
slice_start = start_key if index == start_index else None
slice_stop = stop_key if index == stop_index else None
chunk_op.hf_getitem_key = slice(
slice_start, slice_stop, None
)
out_chunk = chunk_op.new_chunk([chunk], index=chunk.index)
chunks.append(out_chunk)
elif index > stop_index:
break
return op.copy().new_tileable(op.inputs, chunks=chunks)
else: # pragma: no cover
raise NotImplementedError(
f"Not support getitem with key type: {type(op.hf_getitem_key)}"
)

@classmethod
def execute(cls, ctx, op: "HuggingfaceGetItem"):
inp = ctx[op.inputs[0].key]
out = op.outputs[0]
ctx[out.key] = inp.__getitem__(op.hf_getitem_key)


def getitem(dataset, key: Union[int, slice, str]):
if not isinstance(key, (str, int, slice)):
raise NotImplementedError(f"Not support getitem with key type: {type(key)}")
if isinstance(key, slice) and key.step is not None:
raise NotImplementedError(f"Not support getitem with slice and step: {key}")
op = HuggingfaceGetItem(output_types=[OutputType.object], hf_getitem_key=key)
return op(dataset).execute().fetch()
Loading

0 comments on commit 02bebe1

Please sign in to comment.