Add checking valid inputs before calling _generate (#216)
* Add `TogetherInferenceLLM` and `_TOGETHER_AVAILABLE_FLAG`

* Update docstrings of `TogetherInferenceLLM`

* Add `TogetherInferenceLLM` in `distilabel.llm` init

* Access `TogetherInferenceLLM` output via dict

* Fix bug affecting `TextGenerationTask` in `_to_argilla_record`

* Add `TogetherInferenceLLM` documentation

* Add `examples/pipeline-together-inference.py`

* Add checking valid `inputs` for `LLM.generate`

* Add getting and filling invalid inputs

---------

Co-authored-by: Alvaro Bartolome <[email protected]>
gabrielmbmb and alvarobartt authored Jan 5, 2024
1 parent 53ca00c commit b1949b6
Showing 4 changed files with 213 additions and 69 deletions.
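In short, the behaviour added by this commit is: drop the inputs that are missing any of the task's required argument names, run generation only on the remaining ones, and back-fill the skipped positions with empty outputs so the results keep the original batch order. Below is a minimal standalone sketch of that flow; the helper names (`split_valid_inputs`, `fill_skipped`) and the plain-string outputs are illustrative stand-ins, not the distilabel API.

```python
from typing import Any, Dict, List, Optional, Tuple

required_args = ["input"]  # stands in for `task.input_args_names`


def split_valid_inputs(
    inputs: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], List[int]]:
    # Keep the inputs that contain every required argument and remember the
    # indices of those that do not, so they can be back-filled afterwards.
    valid: List[Dict[str, Any]] = []
    invalid_indices: List[int] = []
    for i, item in enumerate(inputs):
        if all(arg in item for arg in required_args):
            valid.append(item)
        else:
            invalid_indices.append(i)
    return valid, invalid_indices


def fill_skipped(
    generations: List[List[Optional[str]]],
    invalid_indices: List[int],
    num_generations: int,
) -> List[List[Optional[str]]]:
    # Insert a row of `None` placeholders at every skipped position so the
    # result lines up one-to-one with the original batch.
    filled = generations.copy()
    for idx in invalid_indices:
        filled.insert(idx, [None] * num_generations)
    return filled


inputs = [{"input": "valid"}, {"other": "missing the required argument"}]
valid, invalid = split_valid_inputs(inputs)
generations = [["generated text"] for _ in valid]  # stand-in for `_generate`
print(fill_skipped(generations, invalid, num_generations=1))
# [['generated text'], [None]]
```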
85 changes: 81 additions & 4 deletions src/distilabel/llm/base.py
@@ -29,17 +29,18 @@
Dict,
Generator,
List,
Tuple,
Union,
)

import multiprocess as mp

from distilabel.llm.utils import LLMOutput
from distilabel.logger import get_logger
from distilabel.tasks.prompt import Prompt
from distilabel.utils.futures import when_all_complete

if TYPE_CHECKING:
from distilabel.llm.utils import LLMOutput
from distilabel.tasks.base import Task
from distilabel.tasks.prompt import SupportedFormats

@@ -167,6 +168,70 @@ def _generate(
) -> List[List["LLMOutput"]]:
pass

def _get_valid_inputs(
self, inputs: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""Returns the valid inputs and the indices of the invalid inputs.
A valid input is an input that contains all the arguments required by the task.
Args:
inputs (List[Dict[str, Any]]): the inputs to be used for generation.
Returns:
Tuple[List[Dict[str, Any]], List[int]]: a tuple containing the valid inputs and
the indices of the invalid inputs.
"""

valid_inputs = []
not_valid_inputs_indices = []
for i, input in enumerate(inputs):
if not all(input_arg in input for input_arg in self.task.input_args_names):
logger.warn(
f"Missing {self.task.__class__.__name__} input argument in batch element {i}"
)
not_valid_inputs_indices.append(i)
continue

valid_inputs.append(input)

return valid_inputs, not_valid_inputs_indices

def _fill_missing_inputs(
self,
generations: List[List[LLMOutput]],
invalid_inputs_indices: List[int],
num_generations: int,
) -> List[List[LLMOutput]]:
"""Fills the `generations` list with empty `LLMOutput`s for the inputs that were
not valid for the associated task of this `LLM`.
Args:
generations (List[List[LLMOutput]]): the generations to be filled.
invalid_inputs_indices (List[int]): the indices of the inputs that were not
valid for the associated task of this `LLM`.
num_generations (int): the number of generations to be performed for each input.
Returns:
List[List[LLMOutput]]: the filled generations.
"""

filled_generations = generations.copy()
for idx in invalid_inputs_indices:
filled_generations.insert(
idx,
[
LLMOutput(
model_name=self.model_name,
prompt_used=None,
raw_output=None,
parsed_output=None,
)
for _ in range(num_generations)
],
)
return filled_generations

def generate(
self,
inputs: List[Dict[str, Any]],
@@ -190,18 +255,30 @@ def _progress():
if progress_callback_func is not None:
progress_callback_func(advance=num_generations * len(inputs))

valid_inputs, invalid_inputs_indices = self._get_valid_inputs(inputs)

if self.thread_pool_executor is not None:
futures = []
for input in inputs:
for input in valid_inputs:
future = self.thread_pool_executor.submit(
self._generate, [input], num_generations
)
futures.append(future)
future = when_all_complete(futures)
future = when_all_complete(
futures=futures,
callback=lambda generations: self._fill_missing_inputs(
generations, invalid_inputs_indices, num_generations
),
)
future.add_done_callback(lambda _: _progress())
return future

generations = self._generate(inputs, num_generations)
generations = self._generate(valid_inputs, num_generations)

generations = self._fill_missing_inputs(
generations, invalid_inputs_indices, num_generations
)

_progress()
return generations

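To see the effect at the `generate` level, here is a rough end-to-end sketch with a toy subclass modelled on the `DummyLLM` used in the tests further below. The `ToyLLM` class, its fixed outputs, and the exact parsed dictionary are assumptions for illustration only; the imports and the `generate_prompt(...).format_as("default")` call mirror what this commit's tests use.

```python
from typing import Any, Dict, List

from distilabel.llm.base import LLM
from distilabel.llm.utils import LLMOutput
from distilabel.tasks.text_generation.base import TextGenerationTask


class ToyLLM(LLM):
    # Minimal illustrative subclass: every prompt gets the same fixed answer.
    @property
    def model_name(self) -> str:
        return "toy"

    def _generate(
        self, inputs: List[Dict[str, Any]], num_generations: int = 1
    ) -> List[List[LLMOutput]]:
        return [
            [
                LLMOutput(
                    model_name=self.model_name,
                    prompt_used=self.task.generate_prompt(**input_).format_as("default"),
                    raw_output="output",
                    parsed_output={"generations": "output"},
                )
                for _ in range(num_generations)
            ]
            for input_ in inputs
        ]


llm = ToyLLM(task=TextGenerationTask())
results = llm.generate(
    inputs=[{"input": "valid"}, {"not_input": "invalid"}], num_generations=1
)
# results[0] holds a real generation; results[1] is the back-filled placeholder
# whose prompt_used, raw_output and parsed_output are all None.
```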
7 changes: 6 additions & 1 deletion src/distilabel/pipeline.py
@@ -281,7 +281,12 @@ def _process_batch_generations(
try:
processed_generation.update(
**combine_dicts(
*[generation["parsed_output"] for generation in generations]
*[
generation["parsed_output"]
if generation["parsed_output"] is not None
else {}
for generation in generations
]
)
)
except Exception as e:
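The `parsed_output is not None` guard above is needed because a back-filled placeholder `LLMOutput` now carries `parsed_output=None`, which the dict combination cannot merge. A rough sketch of the intent, using a simplified stand-in for `combine_dicts` (the helper below is hypothetical, not the library implementation):

```python
from typing import Any, Dict, List, Optional


def merge_parsed_outputs(
    parsed_outputs: List[Optional[Dict[str, Any]]]
) -> Dict[str, Any]:
    # Simplified stand-in for `combine_dicts`: merge key/value pairs across
    # generations, treating a `None` parsed output (coming from an invalid
    # input placeholder) as an empty dict instead of failing on it.
    combined: Dict[str, Any] = {}
    for parsed in parsed_outputs:
        combined.update(parsed if parsed is not None else {})
    return combined


print(merge_parsed_outputs([{"generations": "some text"}, None]))
# {'generations': 'some text'}
```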
13 changes: 9 additions & 4 deletions src/distilabel/utils/futures.py
@@ -15,14 +15,16 @@
from __future__ import annotations

from concurrent.futures import Future, wait
from typing import List
from typing import Callable, List, Optional

from typing_extensions import TypeVar

T = TypeVar("T")


def when_all_complete(futures: List[Future[T]]) -> Future[T]:
def when_all_complete(
futures: List[Future[T]], callback: Optional[Callable[[List[T]], List[T]]] = None
) -> Future[List[T]]:
"""Returns a `Future` that will be completed when all the provided `futures` are
completed, and it will contain the results of the `futures`.
@@ -34,7 +36,7 @@ def when_all_complete(futures: List[Future[T]]) -> Future[T]:
completed, and it will contain the results of the `futures`.
"""
all_done_future = Future()
results = [None] * len(futures)
results: List[T] = [None] * len(futures) # type: ignore

def check_all_done(future: Future) -> None:
# This is done to preserve the order of the results with respect to the order
@@ -44,7 +46,10 @@ def check_all_done(future: Future) -> None:

_, not_done = wait(futures, return_when="FIRST_COMPLETED")
if len(not_done) == 0:
all_done_future.set_result(results)
final_results = results
if callback is not None:
final_results = callback(results)
all_done_future.set_result(final_results)

for future in futures:
future.add_done_callback(check_all_done)
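A small usage sketch of the extended `when_all_complete`, run outside the library: the worker function and the callback below are illustrative assumptions, while the function itself and its new `callback` parameter are exactly the ones added in this diff.

```python
from concurrent.futures import ThreadPoolExecutor

from distilabel.utils.futures import when_all_complete


def square(x: int) -> int:
    return x * x


with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(square, i) for i in range(5)]
    # The callback receives the ordered list of results and may transform it
    # before the combined future resolves, e.g. to back-fill placeholders the
    # way `_fill_missing_inputs` does in `llm/base.py`.
    combined = when_all_complete(
        futures=futures,
        callback=lambda results: [r + 1 for r in results],
    )
    print(combined.result())  # [1, 2, 5, 10, 17]
```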
177 changes: 117 additions & 60 deletions tests/llm/test_base.py
@@ -13,11 +13,12 @@
# limitations under the License.

import re
from typing import Any, Dict, List, Set
from typing import Any, Dict, List, Set, Tuple

import pytest
from distilabel.llm.base import LLM, LLMPool, ProcessLLM
from distilabel.llm.utils import LLMOutput
from distilabel.tasks.base import Task
from distilabel.tasks.preference.ultrafeedback import UltraFeedbackTask
from distilabel.tasks.prompt import Prompt
from distilabel.tasks.text_generation.base import TextGenerationTask
@@ -58,71 +59,127 @@ def generate_prompt(self, input: str) -> "Prompt":
)


def test_llmpool_errors_if_llms_less_than_two() -> None:
with pytest.raises(ValueError, match="The `llms` argument must contain at least 2"):
LLMPool(llms=[None]) # type: ignore


def test_llmpool_errors_if_llm_not_instance_of_processllm() -> None:
with pytest.raises(
ValueError, match="The `llms` argument must contain only `ProcessLLM`s."
):
LLMPool(llms=[None, None]) # type: ignore


@pytest.mark.parametrize(
"tasks",
[
(TextGenerationTask(), TextGenerationTask()),
(TextGenerationTask(), DummySubtask()),
(TextGenerationTask(), TextGenerationTask(), DummySubtask()),
(TextGenerationTask(), DummySubtask(), DummySubtask()),
],
)
def test_llmpool_with_subclass_of_tasks(tasks) -> None:
LLMPool(
llms=[
ProcessLLM(task=t, load_llm_fn=lambda task: DummyLLM(task=task))
for t in tasks
class TestLLM:
def test_get_valid_inputs(self) -> None:
llm = DummyLLM(task=TextGenerationTask())

inputs = [
{"input": "I'm valid for text generation task"},
{"random": "I'm not valid"},
]
valid_inputs, invalid_inputs_indices = llm._get_valid_inputs(inputs=inputs)
assert valid_inputs == [{"input": "I'm valid for text generation task"}]
assert invalid_inputs_indices == [1]

def test_fill_missing_inputs(self) -> None:
llm = DummyLLM(task=TextGenerationTask())

generations = [
[
LLMOutput(
model_name=llm.model_name,
prompt_used=llm.task.generate_prompt(
input="I'm valid for text generation task"
).format_as("default"),
raw_output="dummy",
parsed_output="dummy",
),
LLMOutput(
model_name=llm.model_name,
prompt_used=llm.task.generate_prompt(
input="I'm valid too"
).format_as("default"),
raw_output="dummy",
parsed_output="dummy",
),
]
]
)

filled_generations = llm._fill_missing_inputs(
generations=generations,
invalid_inputs_indices=[1],
num_generations=2,
)

def test_llmpool_errors_if_llms_do_not_have_same_task() -> None:
llm1 = ProcessLLM(
task=TextGenerationTask(), load_llm_fn=lambda task: DummyLLM(task=task)
)
llm2 = ProcessLLM(
task=UltraFeedbackTask.for_honesty(),
load_llm_fn=lambda task: DummyLLM(task=task),
assert filled_generations == generations + [
[
LLMOutput(
model_name=llm.model_name,
prompt_used=None,
raw_output=None,
parsed_output=None,
)
for _ in range(2)
]
]


class TestLLMPool:
def test_llmpool_errors_if_llms_less_than_two(self) -> None:
with pytest.raises(
ValueError, match="The `llms` argument must contain at least 2"
):
LLMPool(llms=[None]) # type: ignore

def test_llmpool_errors_if_llm_not_instance_of_processllm(self) -> None:
with pytest.raises(
ValueError, match="The `llms` argument must contain only `ProcessLLM`s."
):
LLMPool(llms=[None, None]) # type: ignore

@pytest.mark.parametrize(
"tasks",
[
(TextGenerationTask(), TextGenerationTask()),
(TextGenerationTask(), DummySubtask()),
(TextGenerationTask(), TextGenerationTask(), DummySubtask()),
(TextGenerationTask(), DummySubtask(), DummySubtask()),
],
)
with pytest.raises(
ValueError,
match=re.escape(
"All the `ProcessLLM` in `llms` must share the same task (either as the instance or the parent class)."
),
):
LLMPool(llms=[llm1, llm2])


@pytest.mark.parametrize(
"num_generations, num_llms, expected", [(2, 4, {0, 1}), (4, 4, {1}), (9, 4, {2, 3})]
)
def test_llmpool_get_num_generations_per_llm(
num_generations: int, num_llms: int, expected: Set[int]
) -> None:
llms = []
for _ in range(num_llms):
llms.append(
ProcessLLM(
task=TextGenerationTask(), load_llm_fn=lambda task: DummyLLM(task=task)
)
def test_llmpool_with_subclass_of_tasks(self, tasks: Tuple[Task]) -> None:
LLMPool(
llms=[
ProcessLLM(task=t, load_llm_fn=lambda task: DummyLLM(task=task))
for t in tasks
]
)

pool = LLMPool(llms=llms)
def test_llmpool_errors_if_llms_do_not_have_same_task(self) -> None:
llm1 = ProcessLLM(
task=TextGenerationTask(), load_llm_fn=lambda task: DummyLLM(task=task)
)
llm2 = ProcessLLM(
task=UltraFeedbackTask.for_honesty(),
load_llm_fn=lambda task: DummyLLM(task=task),
)
with pytest.raises(
ValueError,
match=re.escape(
"All the `ProcessLLM` in `llms` must share the same task (either as the instance or the parent class)."
),
):
LLMPool(llms=[llm1, llm2])

num_generations_per_llm = pool._get_num_generations_per_llm(
num_generations=num_generations
@pytest.mark.parametrize(
"num_generations, num_llms, expected",
[(2, 4, {0, 1}), (4, 4, {1}), (9, 4, {2, 3})],
)
def test_llmpool_get_num_generations_per_llm(
self, num_generations: int, num_llms: int, expected: Set[int]
) -> None:
llms = []
for _ in range(num_llms):
llms.append(
ProcessLLM(
task=TextGenerationTask(),
load_llm_fn=lambda task: DummyLLM(task=task),
)
)

pool = LLMPool(llms=llms)

num_generations_per_llm = pool._get_num_generations_per_llm(
num_generations=num_generations
)

assert set(num_generations_per_llm.values()) == expected
assert set(num_generations_per_llm.values()) == expected
