Feature dataset checkpoint strategy #194

Merged
merged 29 commits into from
Jan 8, 2024
Changes from 22 commits
Commits
d611cee
Add save and load methods to disk
plaguss Dec 22, 2023
cd00186
Add DatasetCheckpoint to deal with saves to disk during generation pi…
plaguss Dec 22, 2023
952a437
Add tests for new dataset methods and do_checkpoint strategy
plaguss Dec 22, 2023
710ccfa
Add checkpoint strategy to pipeline
plaguss Dec 22, 2023
f6718dc
Fix error with typing in python 3.8
plaguss Dec 22, 2023
eda97be
Merge branch 'main' of https://github.com/argilla-io/distilabel into …
plaguss Dec 22, 2023
769e20a
Add type
plaguss Dec 22, 2023
3487b06
Add type
plaguss Dec 22, 2023
d4bda8a
Add type
plaguss Dec 22, 2023
0dae781
Add extra test for do_checkpoint and move DatasetCheckpoint to distil…
plaguss Dec 22, 2023
5ff9180
Simplify code in pipeline as per code review
plaguss Dec 22, 2023
2435476
Merge incoming changes from github ui
plaguss Dec 22, 2023
51f4792
Add task filename as a constant variable
plaguss Dec 22, 2023
b867f31
Fix import in pipeline
plaguss Dec 22, 2023
8a409ef
Merge branch 'main' of https://github.com/argilla-io/distilabel into …
plaguss Dec 22, 2023
28b175b
Merge with main
plaguss Jan 2, 2024
979e1f9
Use dill instead of pickle to simplify process with custom classes
plaguss Jan 3, 2024
9459b15
Add extra check for shape of saved dataset
plaguss Jan 3, 2024
44f732d
Remove functionality of _build_dataset that prevented the generation …
plaguss Jan 3, 2024
d20ffa4
Add documentatio for the checkpoint strategy
plaguss Jan 4, 2024
293610d
Rename the default folder to standardize
plaguss Jan 4, 2024
dc8603a
Update src/distilabel/pipeline.py
plaguss Jan 4, 2024
b834ca9
Merge branch 'main' of https://github.com/argilla-io/distilabel into …
plaguss Jan 8, 2024
22b4cc4
Merge branch 'main' of https://github.com/argilla-io/distilabel into …
plaguss Jan 8, 2024
95dac2c
Merge with main and corrected typo
plaguss Jan 8, 2024
b8cc8cc
Update ultrafeedback from tests
plaguss Jan 8, 2024
8db52df
Update with comments from the review
plaguss Jan 8, 2024
f8e8606
Merge branch 'main' of https://github.com/argilla-io/distilabel into …
plaguss Jan 8, 2024
d9ad890
Remove dead code
plaguss Jan 8, 2024
@@ -0,0 +1,7 @@
from pathlib import Path
from distilabel.dataset import DatasetCheckpoint

# Assuming we want to save the dataset every 10% of the records generated.

freq = len(dataset) // 10
dataset_checkpoint = DatasetCheckpoint(path=Path.cwd() / "checkpoint_folder", save_frequency=freq)
@@ -0,0 +1,5 @@
new_ds = pipe.generate(
    dataset=dataset,
    num_generations=1,
    checkpoint_strategy=dataset_checkpoint,
)
@@ -0,0 +1,2 @@
from distilabel.dataset import CustomDataset
new_ds = CustomDataset.load_from_disk(dataset_checkpoint.path)
22 changes: 22 additions & 0 deletions docs/technical-reference/pipeline.md
@@ -116,6 +116,28 @@ Then, we will load the dataset and call the `generate` method of the pipeline. F

With a few lines of code, we have easily generated a dataset with 2 generations per input, using 4 different `LLM`s, and labelled the generations using GPT-4. You can check the full code [here](https://github.com/argilla-io/distilabel/blob/main/examples/pipeline-preference-dataset-llmpool.py).

### Dataset checkpoints

Long-running pipelines can fail before the final dataset is built, and it is often useful to inspect intermediate results, either to keep track of progress or to decide whether to stop early. The `checkpoint_strategy` argument of the `Pipeline.generate` method covers both cases by saving the dataset to disk while it is being generated. We start by defining a `DatasetCheckpoint`:

```python
--8<-- "docs/snippets/technical-reference/pipeline/pipeline_dataset_checkpoint_1.py"
```

By passing the checkpoint strategy to the `generate` method, the dataset is saved to disk automatically every `freq` generated records:

```python
--8<-- "docs/snippets/technical-reference/pipeline/pipeline_dataset_checkpoint_2.py"
```

The dataset can be reloaded from the checkpoint by calling the `CustomDataset.load_from_disk` method:

```python
--8<-- "docs/snippets/technical-reference/pipeline/pipeline_dataset_checkpoint_3.py"
```

Once the dataset has been reloaded, we can convert it with `to_argilla` and call `push_to_argilla` on the result to review it.
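
As a rough illustration of how the save frequency in the first snippet could be chosen, the helper below turns a desired number of checkpoints into a number of records. The name `checkpoint_frequency` is hypothetical and not part of distilabel. It clamps the result to 1, because an integer division such as `len(dataset) // 10` evaluates to 0 for datasets with fewer than 10 rows:

```python
def checkpoint_frequency(num_records: int, num_checkpoints: int = 10) -> int:
    """Return how many records to generate between two saves.

    Hypothetical helper for illustration only; it is not part of distilabel.
    """
    if num_checkpoints <= 0:
        raise ValueError("num_checkpoints must be a positive integer")
    # Clamp to 1 so that tiny datasets do not end up with a frequency of 0.
    return max(1, num_records // num_checkpoints)


print(checkpoint_frequency(1000))    # 100
print(checkpoint_frequency(5))       # 1
print(checkpoint_frequency(200, 4))  # 50
```

The returned value could then be passed as `save_frequency` when creating the `DatasetCheckpoint`.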

## pipeline

Considering recurring patterns in dataset creation, we can facilitate the process by utilizing the [`Pipeline`][distilabel.pipeline.Pipeline]. This is made simpler through the [`pipeline`][distilabel.pipeline.pipeline] function, which provides the necessary parameters for creating a `Pipeline`.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
    "tenacity >= 8",
    "importlib-resources >= 6.1.1; python_version < '3.9'",
    "multiprocess",
    "dill >= 0.3.7",
]
dynamic = ["version"]

77 changes: 76 additions & 1 deletion src/distilabel/dataset.py
@@ -13,10 +13,14 @@
 # limitations under the License.

 import warnings
-from typing import TYPE_CHECKING, Union
+from dataclasses import dataclass, field
+from os import PathLike
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Union

 from datasets import Dataset

+from distilabel.utils.dataset import load_task_from_disk, save_task_to_disk
 from distilabel.utils.imports import _ARGILLA_AVAILABLE

 if TYPE_CHECKING:
@@ -90,3 +94,74 @@ def to_argilla(self) -> "FeedbackDataset":
stacklevel=2,
)
return rg_dataset

    def save_to_disk(self, dataset_path: PathLike, **kwargs: Any) -> None:
        """Saves the dataset to disk, also saving the task.

        Args:
            dataset_path: Path to the dataset.
            **kwargs: Additional arguments to be passed to `datasets.Dataset.save_to_disk`.
        """
        super().save_to_disk(dataset_path, **kwargs)
        if self.task is not None:
            save_task_to_disk(dataset_path, self.task)

    @classmethod
    def load_from_disk(cls, dataset_path: PathLike, **kwargs: Any) -> "CustomDataset":
        """Loads a CustomDataset from disk, also reading the task.

        Args:
            dataset_path: Path to the dataset, as you would do with a standard Dataset.

        Returns:
            The loaded dataset.
        """
        ds = super().load_from_disk(dataset_path, **kwargs)  # `**` forwards the kwargs by name
        # Dynamically remaps the `datasets.Dataset` to be a `CustomDataset` instance
        ds.__class__ = cls
        task = load_task_from_disk(dataset_path)
        ds.task = task
        return ds
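
A note on forwarding keyword arguments in `load_from_disk`: the call to the parent loader needs the double-star form. The sketch below uses a stand-in function, `fake_load_from_disk` (hypothetical, with parameter names chosen only for illustration), to show that unpacking a dict with a single `*` passes its keys as positional arguments, whereas `**` passes the key/value pairs by name.

```python
from typing import Any, Dict, Optional, Tuple


def fake_load_from_disk(
    dataset_path: str,
    keep_in_memory: Optional[bool] = None,
    storage_options: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Optional[bool], Optional[Dict[str, Any]]]:
    """Stand-in loader that simply echoes the arguments it received."""
    return dataset_path, keep_in_memory, storage_options


kwargs = {"keep_in_memory": True}

# Single star: the dict is iterated, so its *keys* become positional arguments.
wrong = fake_load_from_disk("ckpt", *kwargs)
# Double star: each key/value pair becomes a keyword argument.
right = fake_load_from_disk("ckpt", **kwargs)

print(wrong)  # ('ckpt', 'keep_in_memory', None)
print(right)  # ('ckpt', True, None)
```

With the single star, the string `"keep_in_memory"` silently lands in the second positional parameter, so the mistake would not raise an error on its own.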


@dataclass
class DatasetCheckpoint:
    """A checkpoint class that contains the information of a checkpoint.

    Args:
        path (Path): The path to the checkpoint.
        save_frequency (int): The frequency at which the checkpoint should be saved.
            By default is set to -1 (no checkpoint is saved to disk, but the dataset
            is returned upon failure).
        extra_kwargs (dict[str, Any]): Additional kwargs to be passed to the `save_to_disk` method of the Dataset.

    Examples:
        >>> from distilabel.dataset import DatasetCheckpoint
        >>> # Save the dataset every 10% of the records generated.
        >>> checkpoint = DatasetCheckpoint(save_frequency=len(dataset) // 10)
        >>> # Afterwards, we can access the checkpoint at `checkpoint.path`.
    """

    path: Path = Path.cwd() / "ckpt"
    save_frequency: int = -1
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)

    # Internal field to keep track of how many checkpoints have been triggered so far.
    _total_checks: int = field(repr=False, default=0)

    def do_checkpoint(self, step: int) -> bool:
        """Determines if a checkpoint should be done.

        Args:
            step (int): The number of records generated.

        Returns:
            bool: Whether a checkpoint should be done.
        """
        if self.save_frequency <= 0:  # also avoids dividing by zero when the frequency is 0
            return False

        if (step - self._total_checks * self.save_frequency) // self.save_frequency:
            self._total_checks += 1
            return True
        return False
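
To make the arithmetic in `do_checkpoint` easier to follow, here is a small standalone simulation. `CheckpointSketch` and `saving_steps` are hypothetical names that re-implement only the counting logic so the example runs without distilabel. The sketch assumes that the pipeline calls `do_checkpoint` once per batch with the number of records processed so far, and it treats any non-positive frequency as "disabled".

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class CheckpointSketch:
    """Standalone stand-in that mirrors the arithmetic of `do_checkpoint`."""

    save_frequency: int = -1
    _total_checks: int = field(repr=False, default=0)

    def do_checkpoint(self, step: int) -> bool:
        # Assumption of this sketch: any non-positive frequency disables saving.
        if self.save_frequency <= 0:
            return False
        # At least `save_frequency` records were generated since the last save.
        if (step - self._total_checks * self.save_frequency) // self.save_frequency:
            self._total_checks += 1
            return True
        return False


def saving_steps(num_records: int, batch_size: int, save_frequency: int) -> List[int]:
    """Return the record counts at which a save would be triggered."""
    checkpoint = CheckpointSketch(save_frequency=save_frequency)
    steps = range(batch_size, num_records + 1, batch_size)
    return [step for step in steps if checkpoint.do_checkpoint(step)]


print(saving_steps(num_records=12, batch_size=2, save_frequency=4))   # [4, 8, 12]
print(saving_steps(num_records=12, batch_size=3, save_frequency=4))   # [6, 9, 12]
print(saving_steps(num_records=12, batch_size=2, save_frequency=-1))  # []
```

When the batch size does not divide the frequency, saves happen at the first batch boundary at or after each multiple of `save_frequency`, which is why the second call reports 6 and 9 rather than 4 and 8.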
68 changes: 48 additions & 20 deletions src/distilabel/pipeline.py
@@ -33,7 +33,7 @@

 from datasets import Dataset, Split

-from distilabel.dataset import CustomDataset
+from distilabel.dataset import CustomDataset, DatasetCheckpoint
 from distilabel.llm.base import LLM, LLMPool, ProcessLLM
 from distilabel.llm.utils import LLMOutput
 from distilabel.logger import get_logger
@@ -425,13 +425,13 @@ def _build_dataset(  # noqa: C901
                 "raw_generation_responses",
             ] + self.generator.task.output_args_names

-            if len(generations) < len(dataset):
-                generations.extend(
-                    [
-                        {key: None for key in generator_column_names}
-                        for _ in range(len(dataset) - len(generations))
-                    ]
-                )
+            # if len(generations) < len(dataset):
+            #     generations.extend(
+            #         [
+            #             {key: None for key in generator_column_names}
+            #             for _ in range(len(dataset) - len(generations))
+            #         ]
+            #     )

             # Add missing keys/columns with a `None` value
             for generation in generations:
@@ -527,12 +527,11 @@ def _generate(  # noqa: C901
         num_generations: int = 1,
         batch_size: int = 1,
         shuffle_before_labelling: bool = True,
-        enable_checkpoints: bool = True,
+        checkpoint_strategy: Optional[DatasetCheckpoint] = DatasetCheckpoint(),
         display_progress_bar: bool = False,
     ) -> CustomDataset:
         """Generates the outputs for the given dataset using the LLMs provided to the
         `Pipeline`."""
-
         if (
             self.labeller is not None
             and self.generator is not None
@@ -581,14 +580,16 @@ def _generate(  # noqa: C901
                         progress_callback_func=generation_progress_func,
                     )
                     generations.extend(batch_generations)
+
                 except Exception as e:
-                    if not enable_checkpoints:
+                    if not checkpoint_strategy:
                         raise RuntimeError(
-                            "`Pipeline.generate` failed during generation step. Setting `enable_checkpoints=True` is recommended!"
+                            "`Pipeline.generate` failed during generation step. Passing a `DatasetCheckpoint` is recommended!"
                         ) from e
                     logger.error(
                         f"`Pipeline.generate` failed during generation step with exception: {e}"
                     )
+
                     return self._build_dataset(
                         dataset,
                         generations=generations,
@@ -611,26 +612,51 @@ def _generate(  # noqa: C901
                         labels.append(batch_labels)  # type: ignore
                     else:
                         labels.extend(batch_labels)  # type: ignore
+
                 except Exception as e:
-                    if not enable_checkpoints:
+                    if not checkpoint_strategy:
                         raise RuntimeError(
-                            "`Pipeline.generate` failed during labelling step. Setting `enable_checkpoints=True` is recommended!"
+                            "`Pipeline.generate` failed during labelling step. Passing a `DatasetCheckpoint` is recommended!"
                         ) from e
                     logger.error(
                         f"`Pipeline.generate` failed during labelling step with exception: {e}"
                     )
+
                     return self._build_dataset(
                         dataset,
                         generations=generations,
                         labels=labels,
                         batch_size=batch_size,
                     )

+            if checkpoint_strategy and checkpoint_strategy.do_checkpoint(
+                batch_i * batch_size
+            ):
+                logger.info(f"Saving dataset up to batch {batch_i}...")
+                ds = self._build_dataset(
+                    dataset,
+                    generations=generations,
+                    labels=labels,
+                    batch_size=batch_size,
+                )
+                ds.save_to_disk(
+                    checkpoint_strategy.path,
+                    **checkpoint_strategy.extra_kwargs,
+                )
+
         _pipeline_progress.stop()

-        return self._build_dataset(
+        ds = self._build_dataset(
             dataset, generations=generations, labels=labels, batch_size=batch_size
         )
+        if checkpoint_strategy:
+            logger.info("Saving final dataset...")
+            ds.save_to_disk(
+                checkpoint_strategy.path,
+                **checkpoint_strategy.extra_kwargs,
+            )
+
+        return ds

     def dry_run(self, dataset: Dataset) -> CustomDataset:
         """Performs a dry run over the provided dataset, which consists on generating the
@@ -655,7 +681,7 @@ def dry_run(self, dataset: Dataset) -> CustomDataset:
                 # Default kwargs to make the process as simple as possible
                 num_generations=1,
                 batch_size=1,
-                enable_checkpoints=False,
+                checkpoint_strategy=None,
                 display_progress_bar=False,
             )
         except Exception as e:
@@ -670,7 +696,7 @@ def generate(
         num_generations: int = 1,
         batch_size: int = 1,
         shuffle_before_labelling: bool = True,
-        enable_checkpoints: bool = True,
+        checkpoint_strategy: Optional[DatasetCheckpoint] = DatasetCheckpoint(),
         display_progress_bar: bool = False,
         skip_dry_run: bool = False,
     ) -> CustomDataset:
@@ -684,7 +710,9 @@ def generate(
             shuffle_before_labelling: whether to shuffle the generations before labelling
                 or not. This is useful to avoid the labelling LLM to be biased by the order
                 of the generations. Defaults to `True`.
-            enable_checkpoints (bool, optional): whether to enable checkpoints or not. Defaults to `True`.
+            checkpoint_strategy (DatasetCheckpoint, optional): the checkpoint strategy.
+                If `None` is provided, no checkpoints will be saved. Defaults to `DatasetCheckpoint()`,
+                which won't save the dataset but returns the generated dataset upon failure.
             display_progress_bar (bool, optional): whether to display the progress bar or not. Defaults to `False`.
             skip_dry_run (bool, optional): whether to skip the dry run or not. Defaults to `False`.

@@ -694,7 +722,7 @@ def generate(
         Raises:
             RuntimeError: if the `Pipeline` fails during the generation or labelling steps.
             UserWarning: if the `Pipeline` fails during the generation or labelling steps and
-                `enable_checkpoints` is set to `False`.
+                `checkpoint_strategy` is set to `None`.

         Examples:
             >>> from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -725,7 +753,7 @@ def generate(
             dataset=dataset,
             num_generations=num_generations,
             batch_size=batch_size,
-            enable_checkpoints=enable_checkpoints,
+            checkpoint_strategy=checkpoint_strategy,
             shuffle_before_labelling=shuffle_before_labelling,
             display_progress_bar=display_progress_bar,
         )
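
The control flow that the `pipeline.py` changes introduce (save periodically, hand back partial results when a batch fails, save once more at the end) can be sketched without distilabel. Everything below is a simplified stand-in: `generate_with_checkpoints` and `flaky_upper` are hypothetical names, plain lists replace the dataset, and a JSON file replaces `save_to_disk`.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable, List


def generate_with_checkpoints(
    inputs: List[str],
    generate_fn: Callable[[List[str]], List[str]],
    checkpoint_path: Path,
    batch_size: int = 2,
    save_frequency: int = 4,
) -> List[str]:
    """Toy loop: save partial results periodically, return them on failure."""
    outputs: List[str] = []
    saves_done = 0
    for start in range(0, len(inputs), batch_size):
        batch = inputs[start : start + batch_size]
        try:
            outputs.extend(generate_fn(batch))
        except Exception as exc:
            # Instead of re-raising, hand back whatever was generated so far.
            print(f"generation failed: {exc}")
            return outputs
        # Save once `save_frequency` more records exist since the last save.
        saved_records = saves_done * save_frequency
        if save_frequency > 0 and len(outputs) - saved_records >= save_frequency:
            saves_done += 1
            checkpoint_path.write_text(json.dumps(outputs))
    # Final save so the file on disk matches the returned result.
    checkpoint_path.write_text(json.dumps(outputs))
    return outputs


def flaky_upper(batch: List[str]) -> List[str]:
    """Fake generator that fails when it sees the item "boom"."""
    if "boom" in batch:
        raise ValueError("model unavailable")
    return [item.upper() for item in batch]


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "ckpt.json"
    data = ["a", "b", "c", "d", "e", "boom", "g"]
    partial = generate_with_checkpoints(data, flaky_upper, path)
    on_disk = json.loads(path.read_text())

print(partial)  # ['A', 'B', 'C', 'D']
print(on_disk)  # ['A', 'B', 'C', 'D']
```

In this run the third batch fails as a whole, so the caller gets the four records produced before the failure, and the file written at the second batch holds the same four records.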
53 changes: 53 additions & 0 deletions src/distilabel/utils/dataset.py
@@ -0,0 +1,53 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import TYPE_CHECKING

import dill as pickle

if TYPE_CHECKING:
    from distilabel.tasks.base import Task


TASK_FILE_NAME = "task.pkl"


def save_task_to_disk(path: Path, task: "Task") -> None:
    """Saves a task to disk.

    Args:
        path: The path to the folder where the task file is written.
        task: The task.
    """
    task_path = Path(path) / TASK_FILE_NAME
    with open(task_path, "wb") as f:
        f.write(pickle.dumps(task))


def load_task_from_disk(path: Path) -> "Task":
    """Loads a task from disk.

    Args:
        path: The path to the folder that contains the task file.

    Returns:
        Task: The task.
    """
    task_path = Path(path) / TASK_FILE_NAME
    if not task_path.exists():
        raise FileNotFoundError(f"The task file does not exist: {task_path}")
    with open(task_path, "rb") as f:
        task = pickle.loads(f.read())
    return task
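
The two helpers above amount to a pickle round trip into a fixed file name inside the dataset folder. The sketch below reproduces that round trip with the standard-library `pickle` module instead of `dill`, so it runs without third-party packages. The names `save_task` and `load_task` are hypothetical, and `SimpleNamespace` is only a placeholder for a real task object.

```python
import pickle
import tempfile
from pathlib import Path
from types import SimpleNamespace
from typing import Any

TASK_FILE_NAME = "task.pkl"


def save_task(folder: Path, task: Any) -> None:
    """Write the pickled task into `folder`."""
    with open(Path(folder) / TASK_FILE_NAME, "wb") as f:
        pickle.dump(task, f)


def load_task(folder: Path) -> Any:
    """Read the pickled task back, failing loudly if the file is missing."""
    task_path = Path(folder) / TASK_FILE_NAME
    if not task_path.exists():
        raise FileNotFoundError(f"The task file does not exist: {task_path}")
    with open(task_path, "rb") as f:
        return pickle.load(f)


# `SimpleNamespace` is only a placeholder for a real task object.
task = SimpleNamespace(system_prompt="You are a helpful assistant.", num_ratings=4)

with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    save_task(folder, task)
    restored = load_task(folder)

    # A folder without a task file should raise instead of returning None.
    empty = folder / "empty"
    empty.mkdir()
    try:
        load_task(empty)
        missing_raises = False
    except FileNotFoundError:
        missing_raises = True

print(restored == task)  # True
print(missing_raises)    # True
```

Wrapping the folder in `Path(...)` lets callers pass either a string or a `Path`, which matters because `save_to_disk` is annotated as accepting any `PathLike`.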