Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature dataset checkpoint strategy #194

Merged
merged 29 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d611cee
Add save and load methods to disk
plaguss Dec 22, 2023
cd00186
Add DatasetCheckpoint to deal with saves to disk during generation pi…
plaguss Dec 22, 2023
952a437
Add tests for new dataset methods and do_checkpoint strategy
plaguss Dec 22, 2023
710ccfa
Add checkpoint strategy to pipeline
plaguss Dec 22, 2023
f6718dc
Fix error with typing in python 3.8
plaguss Dec 22, 2023
eda97be
Merge branch 'main' of https://github.com/argilla-io/distilabel into …
plaguss Dec 22, 2023
769e20a
Add type
plaguss Dec 22, 2023
3487b06
Add type
plaguss Dec 22, 2023
d4bda8a
Add type
plaguss Dec 22, 2023
0dae781
Add extra test for do_checkpoint and move DatasetCheckpoint to distil…
plaguss Dec 22, 2023
5ff9180
Simplify code in pipeline as per code review
plaguss Dec 22, 2023
2435476
Merge incoming changes from github ui
plaguss Dec 22, 2023
51f4792
Add task filename as a constant variable
plaguss Dec 22, 2023
b867f31
Fix import in pipeline
plaguss Dec 22, 2023
8a409ef
Merge branch 'main' of https://github.com/argilla-io/distilabel into …
plaguss Dec 22, 2023
28b175b
Merge with main
plaguss Jan 2, 2024
979e1f9
Use dill instead of pickle to simplify process with custom classes
plaguss Jan 3, 2024
9459b15
Add extra check for shape of saved dataset
plaguss Jan 3, 2024
44f732d
Remove functionality of _build_dataset that prevented the generation …
plaguss Jan 3, 2024
d20ffa4
Add documentation for the checkpoint strategy
plaguss Jan 4, 2024
293610d
Rename the default folder to standardize
plaguss Jan 4, 2024
dc8603a
Update src/distilabel/pipeline.py
plaguss Jan 4, 2024
b834ca9
Merge branch 'main' of https://github.com/argilla-io/distilabel into …
plaguss Jan 8, 2024
22b4cc4
Merge branch 'main' of https://github.com/argilla-io/distilabel into …
plaguss Jan 8, 2024
95dac2c
Merge with main and corrected typo
plaguss Jan 8, 2024
b8cc8cc
Update ultrafeedback from tests
plaguss Jan 8, 2024
8db52df
Update with comments from the review
plaguss Jan 8, 2024
f8e8606
Merge branch 'main' of https://github.com/argilla-io/distilabel into …
plaguss Jan 8, 2024
d9ad890
Remove dead code
plaguss Jan 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/distilabel/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
# limitations under the License.

import warnings
from os import PathLike
from typing import TYPE_CHECKING, Union
plaguss marked this conversation as resolved.
Show resolved Hide resolved

from datasets import Dataset

from distilabel.utils.dataset import load_task_from_disk, save_task_to_disk
from distilabel.utils.imports import _ARGILLA_AVAILABLE

if TYPE_CHECKING:
Expand Down Expand Up @@ -90,3 +92,31 @@ def to_argilla(self) -> "FeedbackDataset":
stacklevel=2,
)
return rg_dataset

def save_to_disk(self, dataset_path: PathLike, **kwargs) -> None:
    """Save the dataset to disk along with its associated task (if any).

    Args:
        dataset_path: Path where the dataset will be written.
        **kwargs: Additional arguments forwarded to `datasets.Dataset.save_to_disk`.
    """
    super().save_to_disk(dataset_path, **kwargs)
    if self.task is None:
        return
    # Persist the task next to the dataset so it can be restored on load.
    save_task_to_disk(dataset_path, self.task)

@classmethod
def load_from_disk(cls, dataset_path: PathLike, **kwargs):
    """Load a CustomDataset from disk, also reading the task.

    Args:
        dataset_path: Path to the dataset, as you would do with a standard Dataset.
        **kwargs: Additional arguments forwarded to `datasets.Dataset.load_from_disk`.

    Returns:
        The loaded dataset.
    """
    # BUG FIX: was `*kwargs`, which unpacks the dict's KEYS as positional
    # arguments; `**kwargs` forwards them as keyword arguments as intended.
    ds = super().load_from_disk(dataset_path, **kwargs)
    # Dynamically remap the `datasets.Dataset` to be a `CustomDataset` instance.
    ds.__class__ = cls
    ds.task = load_task_from_disk(dataset_path)
    return ds
72 changes: 57 additions & 15 deletions src/distilabel/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
get_progress_bars_for_pipeline,
use_progress_bar,
)
from distilabel.utils.dataset import DatasetCheckpoint
from distilabel.utils.dicts import combine_dicts
from distilabel.utils.types import is_future

Expand Down Expand Up @@ -528,7 +529,7 @@ def _generate( # noqa: C901
num_generations: int = 1,
batch_size: int = 1,
shuffle_before_labelling: bool = True,
enable_checkpoints: bool = True,
checkpoint_strategy: Optional[DatasetCheckpoint] = DatasetCheckpoint(),
display_progress_bar: bool = False,
) -> CustomDataset:
"""Generates the outputs for the given dataset using the LLMs provided to the
Expand All @@ -543,8 +544,9 @@ def _generate( # noqa: C901
shuffle_before_labelling (bool, optional): whether to shuffle the generations
before labelling or not. This is useful to avoid the labelling LLM to be
biased by the order of the generations. Defaults to `True`.
enable_checkpoints (bool, optional): whether to enable checkpoints or not.
Defaults to `True`.
checkpoint_strategy (DatasetCheckpoint, optional): the checkpoint strategy.
If `None` is provided, no checkpoints will be saved. Defaults to `DatasetCheckpoint()`,
which won't save the dataset but returns the generated dataset upon failure.
display_progress_bar (bool, optional): whether to display the progress bar
or not. Defaults to `False`.

Expand All @@ -554,7 +556,7 @@ def _generate( # noqa: C901
Raises:
RuntimeError: if the `Pipeline` fails during the generation or labelling steps.
UserWarning: if the `Pipeline` fails during the generation or labelling steps
and `enable_checkpoints` is set to `False`.
and `checkpoint_strategy` is set to `None`.

Examples:
>>> from distilabel.llm.huggingface import TransformersLLM
Expand Down Expand Up @@ -623,20 +625,39 @@ def _generate( # noqa: C901
progress_callback_func=generation_progress_func,
)
generations.extend(batch_generations)
# If we have both generator and labeller, save only after the labeller
if checkpoint_strategy and (self.labeller is None):
if checkpoint_strategy.do_checkpoint(batch_i * batch_size):
logger.info(f"Saving dataset up to batch {batch_i}...")
ds = self._build_dataset(
dataset,
generations=generations,
labels=labels,
batch_size=batch_size,
)
ds.save_to_disk(
checkpoint_strategy.path,
**checkpoint_strategy.extra_kwargs,
)
plaguss marked this conversation as resolved.
Show resolved Hide resolved

except Exception as e:
if not enable_checkpoints:
if not checkpoint_strategy:
raise RuntimeError(
"`Pipeline.generate` failed during generation step. Setting `enable_checkpoints=True` is recommended!"
"`Pipeline.generate` failed during generation step. Passing a `DatasetCheckpoint` is recommended!"
) from e
logger.error(
f"`Pipeline.generate` failed during generation step with exception: {e}"
)
return self._build_dataset(
ds = self._build_dataset(
dataset,
generations=generations,
labels=labels,
batch_size=batch_size,
)
ds.save_to_disk(
checkpoint_strategy.path, **checkpoint_strategy.extra_kwargs
)
return ds

inputs = self._include_generator_outputs_as_inputs(
inputs=inputs, outputs=batch_generations
Expand All @@ -653,20 +674,39 @@ def _generate( # noqa: C901
labels.append(batch_labels) # type: ignore
else:
labels.extend(batch_labels) # type: ignore

if checkpoint_strategy:
if checkpoint_strategy.do_checkpoint(batch_i * batch_size):
logger.info(f"Saving dataset up to batch {batch_i}...")
ds = self._build_dataset(
dataset,
generations=generations,
labels=labels,
batch_size=batch_size,
)
ds.save_to_disk(
checkpoint_strategy.path,
**checkpoint_strategy.extra_kwargs,
)

except Exception as e:
if not enable_checkpoints:
if not checkpoint_strategy:
raise RuntimeError(
"`Pipeline.generate` failed during labelling step. Setting `enable_checkpoints=True` is recommended!"
"`Pipeline.generate` failed during labelling step. Passing a `DatasetCheckpoint` is recommended!"
) from e
logger.error(
f"`Pipeline.generate` failed during labelling step with exception: {e}"
)
return self._build_dataset(
ds = self._build_dataset(
dataset,
generations=generations,
labels=labels,
batch_size=batch_size,
)
ds.save_to_disk(
checkpoint_strategy.path, **checkpoint_strategy.extra_kwargs
)
return ds

_pipeline_progress.stop()

Expand Down Expand Up @@ -697,7 +737,7 @@ def dry_run(self, dataset: Dataset) -> CustomDataset:
# Default kwargs to make the process as simple as possible
num_generations=1,
batch_size=1,
enable_checkpoints=False,
checkpoint_strategy=None,
display_progress_bar=False,
)
except Exception as e:
Expand All @@ -712,7 +752,7 @@ def generate(
num_generations: int = 1,
batch_size: int = 1,
shuffle_before_labelling: bool = True,
enable_checkpoints: bool = True,
checkpoint_strategy: Optional[DatasetCheckpoint] = DatasetCheckpoint(),
display_progress_bar: bool = False,
skip_dry_run: bool = False,
) -> CustomDataset:
Expand All @@ -726,7 +766,9 @@ def generate(
shuffle_before_labelling: whether to shuffle the generations before labelling
or not. This is useful to avoid the labelling LLM to be biased by the order
of the generations. Defaults to `True`.
enable_checkpoints (bool, optional): whether to enable checkpoints or not. Defaults to `True`.
checkpoint_strategy (DatasetCheckpoint, optional): the checkpoint strategy.
If `None` is provided, no checkpoints will be saved. Defaults to `DatasetCheckpoint()`,
which won't save the dataset but returns the generated dataset upon failure.
display_progress_bar (bool, optional): whether to display the progress bar or not. Defaults to `False`.
skip_dry_run (bool, optional): whether to skip the dry run or not. Defaults to `False`.

Expand All @@ -736,7 +778,7 @@ def generate(
Raises:
RuntimeError: if the `Pipeline` fails during the generation or labelling steps.
UserWarning: if the `Pipeline` fails during the generation or labelling steps and
`enable_checkpoints` is set to `False`.
`checkpoint_strategy` is set to `None`.

Examples:
>>> from distilabel.llm.huggingface import TransformersLLM
Expand Down Expand Up @@ -768,7 +810,7 @@ def generate(
dataset=dataset,
num_generations=num_generations,
batch_size=batch_size,
enable_checkpoints=enable_checkpoints,
checkpoint_strategy=checkpoint_strategy,
shuffle_before_labelling=shuffle_before_labelling,
display_progress_bar=display_progress_bar,
)
Expand Down
87 changes: 87 additions & 0 deletions src/distilabel/utils/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict

if TYPE_CHECKING:
from distilabel.tasks.base import Task


def save_task_to_disk(path: Path, task: "Task") -> None:
    """Saves a task to disk by pickling it as `task.pkl` inside *path*.

    Args:
        path: Directory in which the task file will be written.
        task: The task to serialize.
    """
    # Coerce to `Path` so plain strings / `os.PathLike` values (as accepted by
    # `CustomDataset.save_to_disk`) work too, not only `Path` instances.
    task_path = Path(path) / "task.pkl"
    with open(task_path, "wb") as f:
        pickle.dump(task, f)


def load_task_from_disk(path: Path) -> "Task":
    """Loads a previously saved task from disk.

    Args:
        path: Directory containing the task file (`task.pkl`).

    Returns:
        Task: The unpickled task.

    Raises:
        FileNotFoundError: If no `task.pkl` exists under *path*.
    """
    # Coerce to `Path` so plain strings / `os.PathLike` values work too.
    task_path = Path(path) / "task.pkl"
    if not task_path.exists():
        raise FileNotFoundError(f"The task file does not exist: {task_path}")
    # NOTE: `pickle.load` can execute arbitrary code; only load checkpoints
    # from trusted locations.
    with open(task_path, "rb") as f:
        return pickle.load(f)


@dataclass
class DatasetCheckpoint:
    """A checkpoint class that contains the information of a checkpoint.

    Args:
        path (Path): The path to the checkpoint.
        save_frequency (int): The frequency (in number of records) at which the
            checkpoint should be saved. By default is set to -1 (no checkpoint is
            saved to disk, but the dataset is returned upon failure).
        extra_kwargs (dict[str, Any]): Additional kwargs to be passed to the
            `save_to_disk` method of the Dataset.
    """

    # BUG FIX: a plain `Path.cwd() / ...` default is evaluated once at class
    # definition (import) time, pinning the import-time cwd for every instance.
    # `default_factory` resolves the cwd when each instance is created.
    path: Path = field(default_factory=lambda: Path.cwd() / "dataset_checkpoint")
    save_frequency: int = -1
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)

    # Internal field: number of checkpoints performed so far.
    _total_checks: int = field(repr=False, default=0)

    def do_checkpoint(self, step: int) -> bool:
        """Determines if a checkpoint should be done.

        Args:
            step (int): The number of records generated.

        Returns:
            bool: Whether a checkpoint should be done.
        """
        if self.save_frequency == -1:
            return False

        # Fires once `step` has advanced at least `save_frequency` records
        # past the last checkpointed step.
        if (step - self._total_checks * self.save_frequency) // self.save_frequency:
            self._total_checks += 1
            return True
        return False
57 changes: 57 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
from pathlib import Path

import pytest
from distilabel.dataset import CustomDataset
from distilabel.tasks import UltraFeedbackTask
from distilabel.utils.dataset import DatasetCheckpoint


@pytest.fixture
def custom_dataset():
    """A two-row `CustomDataset` carrying an UltraFeedback text-quality task."""
    dataset = CustomDataset.from_dict({"input": ["a", "b"], "generations": ["c", "d"]})
    dataset.task = UltraFeedbackTask.for_text_quality()
    return dataset


def test_dataset_save_to_disk(custom_dataset):
    """Saving writes both the dataset folder and the pickled task file."""
    with tempfile.TemporaryDirectory() as tmpdir:
        dataset_dir = Path(tmpdir) / "dataset_folder"
        custom_dataset.save_to_disk(dataset_dir)
        assert dataset_dir.is_dir()
        assert (dataset_dir / "task.pkl").is_file()


def test_dataset_load_disk(custom_dataset):
    """A save/load round trip restores both the dataset class and its task."""
    with tempfile.TemporaryDirectory() as tmpdir:
        dataset_dir = Path(tmpdir) / "dataset_folder"
        custom_dataset.save_to_disk(dataset_dir)
        reloaded = CustomDataset.load_from_disk(dataset_dir)
        assert isinstance(reloaded, CustomDataset)
        assert isinstance(reloaded.task, UltraFeedbackTask)


def test_do_checkpoint():
    """`do_checkpoint` fires every `save_frequency` records and tracks the count."""
    checkpoint = DatasetCheckpoint(save_frequency=2)
    # (step, expected decision, expected running total of checkpoints)
    expectations = [(0, False, 0), (2, True, 1), (3, False, 1), (4, True, 2)]
    for step, should_save, total in expectations:
        assert checkpoint.do_checkpoint(step) is should_save
        assert checkpoint._total_checks == total