Skip to content

Commit

Permalink
Fix PrometheusTask could not be imported (#190)
Browse files Browse the repository at this point in the history
* Fix `__type__` was not a `ClassVar`

`__type__` attribute in `Task` base classes was not a `ClassVar` causing
dataclasses inheriting from them to not be able to have non-default
attributes.

* Add missing `PrometheusTask` import

* Update `PrometheusTask` import
  • Loading branch information
gabrielmbmb authored Dec 22, 2023
1 parent e682616 commit 7fe94b9
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion examples/pipeline-prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from datasets import Dataset
from distilabel.llm import TransformersLLM
from distilabel.pipeline import Pipeline
from distilabel.tasks.critique.prometheus import PrometheusTask
from distilabel.tasks import PrometheusTask
from transformers import AutoTokenizer, LlamaForCausalLM

if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions src/distilabel/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from distilabel.tasks.base import Task
from distilabel.tasks.critique.base import CritiqueTask
from distilabel.tasks.critique.prometheus import PrometheusTask
from distilabel.tasks.critique.ultracm import UltraCMTask
from distilabel.tasks.preference.judgelm import JudgeLMTask
from distilabel.tasks.preference.ultrafeedback import UltraFeedbackTask
Expand All @@ -27,6 +28,7 @@
__all__ = [
"Task",
"CritiqueTask",
"PrometheusTask",
"UltraCMTask",
"JudgeLMTask",
"UltraFeedbackTask",
Expand Down
4 changes: 2 additions & 2 deletions src/distilabel/tasks/critique/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import List, Literal
from typing import ClassVar, List, Literal

from typing_extensions import TypedDict

Expand All @@ -29,7 +29,7 @@ class CritiqueTask(Task):
task_description (Union[str, None], optional): the description of the task. Defaults to `None`.
"""

__type__: Literal["labelling"] = "labelling"
__type__: ClassVar[Literal["labelling"]] = "labelling"

@property
def input_args_names(self) -> List[str]:
Expand Down
4 changes: 3 additions & 1 deletion src/distilabel/tasks/preference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional

from distilabel.tasks.base import Task
from distilabel.utils.argilla import (
Expand All @@ -40,6 +40,8 @@ class PreferenceTask(Task):
task_description (Union[str, None], optional): the description of the task. Defaults to `None`.
"""

__type__: ClassVar[Literal["labelling"]] = "labelling"

@property
def input_args_names(self) -> List[str]:
"""Returns the names of the input arguments of the task."""
Expand Down
3 changes: 1 addition & 2 deletions src/distilabel/tasks/preference/judgelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import ClassVar, List, Literal, TypedDict
from typing import ClassVar, List, TypedDict

from distilabel.tasks.base import get_template
from distilabel.tasks.preference.base import PreferenceTask
Expand Down Expand Up @@ -51,7 +51,6 @@ class JudgeLMTask(PreferenceTask):
system_prompt: str = "You are a helpful and precise assistant for checking the quality of the answer."

__jinja2_template__: ClassVar[str] = _JUDGELM_TEMPLATE
__type__: Literal["labelling"] = "labelling"

def generate_prompt(self, input: str, generations: List[str]) -> Prompt:
"""Generates a prompt following the JudgeLM specification.
Expand Down
2 changes: 0 additions & 2 deletions src/distilabel/tasks/preference/ultrafeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
ClassVar,
Dict,
List,
Literal,
Optional,
TypedDict,
)
Expand Down Expand Up @@ -76,7 +75,6 @@ class UltraFeedbackTask(PreferenceTask):
"honesty",
"instruction-following",
]
__type__: Literal["labelling"] = "labelling"

def generate_prompt(self, input: str, generations: List[str]) -> Prompt:
"""Generates a prompt following the ULTRAFEEDBACK specification.
Expand Down
3 changes: 1 addition & 2 deletions src/distilabel/tasks/preference/ultrajudge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import re
from dataclasses import dataclass, field
from typing import Any, ClassVar, Dict, List, Literal, TypedDict
from typing import Any, ClassVar, Dict, List, TypedDict

from distilabel.tasks.base import Prompt, get_template
from distilabel.tasks.preference.base import PreferenceTask
Expand Down Expand Up @@ -79,7 +79,6 @@ class UltraJudgeTask(PreferenceTask):
__jinja2_template__: ClassVar[str] = field(
default=_ULTRAJUDGE_TEMPLATE, init=False, repr=False
)
__type__: Literal["labelling"] = "labelling"

@property
def output_args_names(self) -> List[str]:
Expand Down

0 comments on commit 7fe94b9

Please sign in to comment.