Skip to content

Commit

Permalink
Relax LLMPool check to match parent Task instead (#210)
Browse files Browse the repository at this point in the history
* Update llm tasks' check to allow different subclasses of the same class

* Update src/distilabel/llm/base.py

Co-authored-by: Alvaro Bartolome <[email protected]>

* Add extra tests for the LLM subtasks

---------

Co-authored-by: plaguss <[email protected]>
Co-authored-by: Alvaro Bartolome <[email protected]>
  • Loading branch information
3 people authored Jan 5, 2024
1 parent f29c969 commit 854e33e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
8 changes: 6 additions & 2 deletions src/distilabel/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,9 +637,13 @@ def __init__(self, llms: List[ProcessLLM]) -> None:
if not all(isinstance(llm, ProcessLLM) for llm in llms):
raise ValueError("The `llms` argument must contain only `ProcessLLM`s.")

if not all(llm.task == llms[0].task for llm in llms):
# Note: the check below takes the task class with the shortest MRO among all the
# `ProcessLLM`s and requires every task to be an instance of it, i.e. each task must
# be of that same class or of a subclass of it.
mros = [(type(llm.task), len(type(llm.task).mro())) for llm in llms]
min_common_class = min(mros, key=lambda x: x[1])[0]
if not all(isinstance(llm.task, min_common_class) for llm in llms):
raise ValueError(
"The `llms` argument must contain `ProcessLLM`s with the same task."
"All the `ProcessLLM` in `llms` must share the same task (either as the instance or the parent class)."
)

self.llms = llms
Expand Down
36 changes: 35 additions & 1 deletion tests/llm/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Any, Dict, List, Set

import pytest
from distilabel.llm.base import LLM, LLMPool, ProcessLLM
from distilabel.llm.utils import LLMOutput
from distilabel.tasks.preference.ultrafeedback import UltraFeedbackTask
from distilabel.tasks.prompt import Prompt
from distilabel.tasks.text_generation.base import TextGenerationTask


Expand All @@ -44,6 +46,18 @@ def _generate(
return outputs


class DummySubtask(TextGenerationTask):
    """Minimal `TextGenerationTask` subclass with its own prompt layout.

    Used by the tests in this module as a task whose class differs from
    `TextGenerationTask` while still inheriting from it.
    """

    system_prompt: str = "You are a helpful assistant."

    def generate_prompt(self, input: str) -> "Prompt":
        """Return a `Prompt` wrapping `input` in the dummy instruction template.

        Args:
            input: the instruction text to insert into the template.

        Returns:
            A `Prompt` carrying this task's `system_prompt` and the formatted
            instruction.
        """
        template = "Instruction {instruction}\nResponse"
        formatted = template.format(instruction=input)
        return Prompt(system_prompt=self.system_prompt, formatted_prompt=formatted)


def test_llmpool_errors_if_llms_less_than_two() -> None:
with pytest.raises(ValueError, match="The `llms` argument must contain at least 2"):
LLMPool(llms=[None]) # type: ignore
Expand All @@ -56,6 +70,24 @@ def test_llmpool_errors_if_llm_not_instance_of_processllm() -> None:
LLMPool(llms=[None, None]) # type: ignore


@pytest.mark.parametrize(
    "tasks",
    [
        (TextGenerationTask(), TextGenerationTask()),
        (TextGenerationTask(), DummySubtask()),
        (TextGenerationTask(), TextGenerationTask(), DummySubtask()),
        (TextGenerationTask(), DummySubtask(), DummySubtask()),
    ],
)
def test_llmpool_with_subclass_of_tasks(tasks) -> None:
    """Building an `LLMPool` from a task class and its subclasses must not raise."""
    process_llms = []
    for pool_task in tasks:
        # The lambda receives the task as its own argument, so it does not
        # close over the loop variable.
        process_llm = ProcessLLM(
            task=pool_task, load_llm_fn=lambda task: DummyLLM(task=task)
        )
        process_llms.append(process_llm)

    LLMPool(llms=process_llms)


def test_llmpool_errors_if_llms_do_not_have_same_task() -> None:
llm1 = ProcessLLM(
task=TextGenerationTask(), load_llm_fn=lambda task: DummyLLM(task=task)
Expand All @@ -66,7 +98,9 @@ def test_llmpool_errors_if_llms_do_not_have_same_task() -> None:
)
with pytest.raises(
ValueError,
match="The `llms` argument must contain `ProcessLLM`s with the same task.",
match=re.escape(
"All the `ProcessLLM` in `llms` must share the same task (either as the instance or the parent class)."
),
):
LLMPool(llms=[llm1, llm2])

Expand Down

0 comments on commit 854e33e

Please sign in to comment.