Skip to content

Commit

Permalink
Add shuffle_before_labelling parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Dec 19, 2023
1 parent 94e0219 commit 263ef63
Showing 1 changed file with 35 additions and 10 deletions.
45 changes: 35 additions & 10 deletions src/distilabel/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def _get_batch_generations(
inputs: List[Dict[str, Any]],
num_generations: int,
num_batches: int,
shuffle_before_labelling: bool = True,
progress_callback_func: Union[Callable, None] = None,
) -> List[Dict[str, Any]]:
"""Gets the batch generations for the given inputs, capturing the futures if the
Expand All @@ -190,6 +191,10 @@ def _get_batch_generations(
inputs (List[Dict[str, Any]]): the inputs to be used for generation.
num_generations (int): the number of generations to be performed for each
input.
num_batches (int): the number of batches to be processed.
shuffle_before_labelling (bool, optional): whether to shuffle the generations
before labelling or not. This is useful to avoid the labelling LLM to be
biased by the order of the generations. Defaults to `True`.
progress_callback_func (Union[Callable, None], optional): the callback function
to be called when the progress of the generation process changes. Defaults
to None.
Expand Down Expand Up @@ -228,7 +233,10 @@ def _get_batch_generations(
)
else:
batch_generations = outputs
return self._process_batch_generations(batch_generations=batch_generations)
return self._process_batch_generations(
batch_generations=batch_generations,
shuffle_before_labelling=shuffle_before_labelling,
)

def _get_batch_labels(
self,
Expand Down Expand Up @@ -260,12 +268,16 @@ def _get_batch_labels(
def _process_batch_generations(
self,
batch_generations: List[List["LLMOutput"]],
shuffle_before_labelling: bool = True,
) -> List[Dict[str, Any]]:
"""Processes the batch generations, combining the outputs of the LLMs into a single
dictionary.
Args:
batch_generations (List[List["LLMOutput"]]): the batch generations to be processed.
shuffle_before_labelling (bool, optional): whether to shuffle the generations
before labelling or not. This is useful to avoid the labelling LLM to be
biased by the order of the generations. Defaults to `True`.
Returns:
List[Dict[str, Any]]: the processed batch generations.
Expand All @@ -277,7 +289,8 @@ def _process_batch_generations(
"generation_prompt": [],
"raw_generation_responses": [],
}
random.shuffle(generations)
if shuffle_before_labelling:
random.shuffle(generations)
for generation in generations:
processed_generation["generation_model"].append(
generation["model_name"]
Expand Down Expand Up @@ -533,26 +546,34 @@ def _generate( # noqa: C901
dataset: Dataset,
num_generations: int = 1,
batch_size: int = 1,
shuffle_before_labelling: bool = True,
enable_checkpoints: bool = True,
display_progress_bar: bool = False,
) -> CustomDataset:
"""Generates the outputs for the given dataset using the LLMs provided to the `Pipeline`.
"""Generates the outputs for the given dataset using the LLMs provided to the
`Pipeline`.
Args:
dataset (Dataset): the dataset to be used for generation.
num_generations (int, optional): the number of generations to be performed for each
input. Defaults to `1`.
batch_size (int, optional): the batch size to be used for generation. Defaults to `1`.
enable_checkpoints (bool, optional): whether to enable checkpoints or not. Defaults to `True`.
display_progress_bar (bool, optional): whether to display the progress bar or not. Defaults to `False`.
num_generations (int, optional): the number of generations to be performed
for each input. Defaults to `1`.
batch_size (int, optional): the batch size to be used for generation. Defaults
to `1`.
shuffle_before_labelling (bool, optional): whether to shuffle the generations
before labelling or not. This is useful to avoid the labelling LLM to be
biased by the order of the generations. Defaults to `True`.
enable_checkpoints (bool, optional): whether to enable checkpoints or not.
Defaults to `True`.
display_progress_bar (bool, optional): whether to display the progress bar
or not. Defaults to `False`.
Returns:
CustomDataset: the final dataset.
Raises:
RuntimeError: if the `Pipeline` fails during the generation or labelling steps.
UserWarning: if the `Pipeline` fails during the generation or labelling steps and
`enable_checkpoints` is set to `False`.
UserWarning: if the `Pipeline` fails during the generation or labelling steps
and `enable_checkpoints` is set to `False`.
Examples:
>>> from distilabel.llm.huggingface import TransformersLLM
Expand Down Expand Up @@ -706,6 +727,7 @@ def generate(
dataset: Dataset,
num_generations: int = 1,
batch_size: int = 1,
shuffle_before_labelling: bool = True,
enable_checkpoints: bool = True,
display_progress_bar: bool = False,
skip_dry_run: bool = False,
Expand All @@ -717,6 +739,9 @@ def generate(
num_generations (int, optional): the number of generations to be performed for each
input. Defaults to `1`.
batch_size (int, optional): the batch size to be used for generation. Defaults to `1`.
shuffle_before_labelling: whether to shuffle the generations before labelling
or not. This is useful to avoid the labelling LLM to be biased by the order
of the generations. Defaults to `True`.
enable_checkpoints (bool, optional): whether to enable checkpoints or not. Defaults to `True`.
display_progress_bar (bool, optional): whether to display the progress bar or not. Defaults to `False`.
skip_dry_run (bool, optional): whether to skip the dry run or not. Defaults to `False`.
Expand Down

0 comments on commit 263ef63

Please sign in to comment.