diff --git a/examples/pipeline-prometheus.py b/examples/pipeline-prometheus.py index 0a886f61ad..4a8f94e187 100644 --- a/examples/pipeline-prometheus.py +++ b/examples/pipeline-prometheus.py @@ -18,7 +18,7 @@ from datasets import Dataset from distilabel.llm import TransformersLLM from distilabel.pipeline import Pipeline -from distilabel.tasks.critique.prometheus import PrometheusTask +from distilabel.tasks import PrometheusTask from transformers import AutoTokenizer, LlamaForCausalLM if __name__ == "__main__": diff --git a/src/distilabel/tasks/__init__.py b/src/distilabel/tasks/__init__.py index 589d893d8b..8723e2c422 100644 --- a/src/distilabel/tasks/__init__.py +++ b/src/distilabel/tasks/__init__.py @@ -14,6 +14,7 @@ from distilabel.tasks.base import Task from distilabel.tasks.critique.base import CritiqueTask +from distilabel.tasks.critique.prometheus import PrometheusTask from distilabel.tasks.critique.ultracm import UltraCMTask from distilabel.tasks.preference.judgelm import JudgeLMTask from distilabel.tasks.preference.ultrafeedback import UltraFeedbackTask @@ -27,6 +28,7 @@ __all__ = [ "Task", "CritiqueTask", + "PrometheusTask", "UltraCMTask", "JudgeLMTask", "UltraFeedbackTask", diff --git a/src/distilabel/tasks/critique/base.py b/src/distilabel/tasks/critique/base.py index cf3bd016b7..8c79828888 100644 --- a/src/distilabel/tasks/critique/base.py +++ b/src/distilabel/tasks/critique/base.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import List, Literal +from typing import ClassVar, List, Literal from typing_extensions import TypedDict @@ -29,7 +29,7 @@ class CritiqueTask(Task): task_description (Union[str, None], optional): the description of the task. Defaults to `None`. """ - __type__: Literal["labelling"] = "labelling" + __type__: ClassVar[Literal["labelling"]] = "labelling" @property def input_args_names(self) -> List[str]: diff --git a/src/distilabel/tasks/preference/base.py b/src/distilabel/tasks/preference/base.py index 68b1a4c95d..f830341927 100644 --- a/src/distilabel/tasks/preference/base.py +++ b/src/distilabel/tasks/preference/base.py @@ -14,7 +14,7 @@ import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional from distilabel.tasks.base import Task from distilabel.utils.argilla import ( @@ -40,6 +40,8 @@ class PreferenceTask(Task): task_description (Union[str, None], optional): the description of the task. Defaults to `None`. """ + __type__: ClassVar[Literal["labelling"]] = "labelling" + @property def input_args_names(self) -> List[str]: """Returns the names of the input arguments of the task.""" diff --git a/src/distilabel/tasks/preference/judgelm.py b/src/distilabel/tasks/preference/judgelm.py index a170c1286a..e9d8aa6bf4 100644 --- a/src/distilabel/tasks/preference/judgelm.py +++ b/src/distilabel/tasks/preference/judgelm.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import ClassVar, List, Literal, TypedDict +from typing import ClassVar, List, TypedDict from distilabel.tasks.base import get_template from distilabel.tasks.preference.base import PreferenceTask @@ -51,7 +51,6 @@ class JudgeLMTask(PreferenceTask): system_prompt: str = "You are a helpful and precise assistant for checking the quality of the answer." __jinja2_template__: ClassVar[str] = _JUDGELM_TEMPLATE - __type__: Literal["labelling"] = "labelling" def generate_prompt(self, input: str, generations: List[str]) -> Prompt: """Generates a prompt following the JudgeLM specification. diff --git a/src/distilabel/tasks/preference/ultrafeedback.py b/src/distilabel/tasks/preference/ultrafeedback.py index 2cd7c485ea..39254105d5 100644 --- a/src/distilabel/tasks/preference/ultrafeedback.py +++ b/src/distilabel/tasks/preference/ultrafeedback.py @@ -20,7 +20,6 @@ ClassVar, Dict, List, - Literal, Optional, TypedDict, ) @@ -76,7 +75,6 @@ class UltraFeedbackTask(PreferenceTask): "honesty", "instruction-following", ] - __type__: Literal["labelling"] = "labelling" def generate_prompt(self, input: str, generations: List[str]) -> Prompt: """Generates a prompt following the ULTRAFEEDBACK specification. diff --git a/src/distilabel/tasks/preference/ultrajudge.py b/src/distilabel/tasks/preference/ultrajudge.py index e2fdbf751f..421fa57333 100644 --- a/src/distilabel/tasks/preference/ultrajudge.py +++ b/src/distilabel/tasks/preference/ultrajudge.py @@ -14,7 +14,7 @@ import re from dataclasses import dataclass, field -from typing import Any, ClassVar, Dict, List, Literal, TypedDict +from typing import Any, ClassVar, Dict, List, TypedDict from distilabel.tasks.base import Prompt, get_template from distilabel.tasks.preference.base import PreferenceTask @@ -79,7 +79,6 @@ class UltraJudgeTask(PreferenceTask): __jinja2_template__: ClassVar[str] = field( default=_ULTRAJUDGE_TEMPLATE, init=False, repr=False ) - __type__: Literal["labelling"] = "labelling" @property def output_args_names(self) -> List[str]: