From 854e33e6ab0786bc18de37f018e5e55b0b0fedcc Mon Sep 17 00:00:00 2001 From: Agus <56895847+plaguss@users.noreply.github.com> Date: Fri, 5 Jan 2024 10:22:52 +0100 Subject: [PATCH] Relax `LLMPool` check to match parent `Task` instead (#210) * Update llm tasks' check to allow different subclasses of the same class * Update src/distilabel/llm/base.py Co-authored-by: Alvaro Bartolome * Add extra tests for the LLM subtasks --------- Co-authored-by: plaguss Co-authored-by: Alvaro Bartolome --- src/distilabel/llm/base.py | 8 ++++++-- tests/llm/test_base.py | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/src/distilabel/llm/base.py b/src/distilabel/llm/base.py index b6064fce21..d53646e068 100644 --- a/src/distilabel/llm/base.py +++ b/src/distilabel/llm/base.py @@ -637,9 +637,13 @@ def __init__(self, llms: List[ProcessLLM]) -> None: if not all(isinstance(llm, ProcessLLM) for llm in llms): raise ValueError("The `llms` argument must contain only `ProcessLLM`s.") - if not all(llm.task == llms[0].task for llm in llms): + # Note: The following piece of code is used to check that all the `ProcessLLM`s + # have the same task or a subclass of it. + mros = [(type(llm.task), len(type(llm.task).mro())) for llm in llms] + min_common_class = min(mros, key=lambda x: x[1])[0] + if not all(isinstance(llm.task, min_common_class) for llm in llms): raise ValueError( - "The `llms` argument must contain `ProcessLLM`s with the same task." + "All the `ProcessLLM` in `llms` must share the same task (either as the instance or the parent class)." ) self.llms = llms diff --git a/tests/llm/test_base.py b/tests/llm/test_base.py index 4c15211ca0..cfb97bfc8e 100644 --- a/tests/llm/test_base.py +++ b/tests/llm/test_base.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re from typing import Any, Dict, List, Set import pytest from distilabel.llm.base import LLM, LLMPool, ProcessLLM from distilabel.llm.utils import LLMOutput from distilabel.tasks.preference.ultrafeedback import UltraFeedbackTask +from distilabel.tasks.prompt import Prompt from distilabel.tasks.text_generation.base import TextGenerationTask @@ -44,6 +46,18 @@ def _generate( return outputs +class DummySubtask(TextGenerationTask): + system_prompt: str = "You are a helpful assistant." + + def generate_prompt(self, input: str) -> "Prompt": + return Prompt( + system_prompt=self.system_prompt, + formatted_prompt="Instruction {instruction}\nResponse".format( + instruction=input + ), + ) + + def test_llmpool_errors_if_llms_less_than_two() -> None: with pytest.raises(ValueError, match="The `llms` argument must contain at least 2"): LLMPool(llms=[None]) # type: ignore @@ -56,6 +70,24 @@ def test_llmpool_errors_if_llm_not_instance_of_processllm() -> None: LLMPool(llms=[None, None]) # type: ignore +@pytest.mark.parametrize( + "tasks", + [ + (TextGenerationTask(), TextGenerationTask()), + (TextGenerationTask(), DummySubtask()), + (TextGenerationTask(), TextGenerationTask(), DummySubtask()), + (TextGenerationTask(), DummySubtask(), DummySubtask()), + ], +) +def test_llmpool_with_subclass_of_tasks(tasks) -> None: + LLMPool( + llms=[ + ProcessLLM(task=t, load_llm_fn=lambda task: DummyLLM(task=task)) + for t in tasks + ] + ) + + def test_llmpool_errors_if_llms_do_not_have_same_task() -> None: llm1 = ProcessLLM( task=TextGenerationTask(), load_llm_fn=lambda task: DummyLLM(task=task) @@ -66,7 +98,9 @@ def test_llmpool_errors_if_llms_do_not_have_same_task() -> None: ) with pytest.raises( ValueError, - match="The `llms` argument must contain `ProcessLLM`s with the same task.", + match=re.escape( + "All the `ProcessLLM` in `llms` must share the same task (either as the instance or the parent class)." + ), ): LLMPool(llms=[llm1, llm2])