Skip to content

Commit

Permalink
Fix IndexError when overriding inputs and group_generations=False (
Browse files Browse the repository at this point in the history
…#1022)

* Fix processing num_generations when applying input mappings in steps process

* Add unit test

* Update comment

---------

Co-authored-by: Gabriel Martín Blázquez <[email protected]>
  • Loading branch information
plaguss and gabrielmbmb authored Oct 8, 2024
1 parent ebab004 commit 4cbcb90
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/distilabel/steps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,10 +667,19 @@ def process_applying_mappings(self, *args: List[Dict[str, Any]]) -> "StepOutput"
)

for output_rows in generator:
yield [
self._apply_mappings_and_restore_overriden(row, overriden_inputs[i])
for i, row in enumerate(output_rows)
]
restored = []
for i, row in enumerate(output_rows):
# Correct the index here because we don't know the num_generations from the llm
# ahead of time. For example, if we have `len(overriden_inputs)==5` and `len(row)==10`,
# from `num_generations==2` and `group_generations=False` in the LLM:
# The loop will use indices 0, 1, 2, 3, 4, 0, 1, 2, 3, 4
ntimes_i = i % len(overriden_inputs)
restored.append(
self._apply_mappings_and_restore_overriden(
row, overriden_inputs[ntimes_i]
)
)
yield restored

def _apply_input_mappings(
self, inputs: Tuple[List[Dict[str, Any]], ...]
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/steps/tasks/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,94 @@ def test_process(
result = next(task.process(input))
assert result == expected

def test_process_overriding_inputs(self) -> None:
llm = DummyAsyncLLM()
task = DummyTask(
name="task",
llm=llm,
group_generations=False,
num_generations=3,
input_mappings={"instruction": "instruction_2"},
)

result = next(
task.process_applying_mappings(
[
{
"instruction": "instruction that won't be used but overriden by input mapping",
"instruction_2": "instruction that will be used as input",
"additional_info": "info",
}
]
)
)

assert result == [
{
"additional_info": "info",
"distilabel_metadata": {
"raw_input_task": [
{
"content": "",
"role": "system",
},
{
"content": "instruction that will be used as input",
"role": "user",
},
],
"raw_output_task": "output",
},
"info_from_input": "info",
"instruction": "instruction that won't be used but overriden by input mapping",
"instruction_2": "instruction that will be used as input",
"model_name": "test",
"output": "output",
},
{
"additional_info": "info",
"distilabel_metadata": {
"raw_input_task": [
{
"content": "",
"role": "system",
},
{
"content": "instruction that will be used as input",
"role": "user",
},
],
"raw_output_task": "output",
},
"info_from_input": "info",
"instruction": "instruction that won't be used but overriden by input mapping",
"instruction_2": "instruction that will be used as input",
"model_name": "test",
"output": "output",
},
{
"additional_info": "info",
"distilabel_metadata": {
"raw_input_task": [
{
"content": "",
"role": "system",
},
{
"content": "instruction that will be used as input",
"role": "user",
},
],
"raw_output_task": "output",
},
"info_from_input": "info",
"instruction": "instruction that won't be used but overriden by input mapping",
"instruction_2": "instruction that will be used as input",
"model_name": "test",
"output": "output",
},
]

def test_process_with_runtime_parameters(self) -> None:
# 1. Runtime parameters provided
llm = DummyRuntimeLLM() # type: ignore
Expand Down

0 comments on commit 4cbcb90

Please sign in to comment.