Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 230 feature add sentence transformers support for the to argilla method #262

Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
426e418
chore: initial outline adding vectors
davidberenstein1957 Jan 16, 2024
d82660c
chore: added sentence-transformers to extras
davidberenstein1957 Jan 16, 2024
13e9b80
chore: aded `add_vectors_to_argilla_dataset` method to generation tasks
davidberenstein1957 Jan 16, 2024
962d0c7
chore: updated to_argilla methods with call to `Task.add_vectors_to_a…
davidberenstein1957 Jan 17, 2024
c13c2ee
chore: updated warning
davidberenstein1957 Jan 17, 2024
69119c8
chore: resolved linting issues
davidberenstein1957 Jan 17, 2024
415b93d
Update pyproject.toml
davidberenstein1957 Jan 17, 2024
1b38e79
chore: move logic to `CustomDataset` class
davidberenstein1957 Jan 18, 2024
c7256ff
tests: added basic tests for `to_argilla` method
davidberenstein1957 Jan 18, 2024
1729d03
docs: updated references to `vector_strategy`
davidberenstein1957 Jan 18, 2024
79d3105
Merge branch 'main' into feat/230-feature-add-sentence-transformers-s…
davidberenstein1957 Jan 22, 2024
316a287
chore: reformatted tests
davidberenstein1957 Jan 22, 2024
49df397
chore: remove cache step
davidberenstein1957 Jan 22, 2024
e65bf43
chore: limit to using 5 fields and defaulting to the first 5
davidberenstein1957 Jan 22, 2024
de9be16
tests: resolved failing tests
davidberenstein1957 Jan 23, 2024
0b1262b
tests: remove faulty type hint
davidberenstein1957 Jan 23, 2024
75dcd98
chore: configered argilla version
davidberenstein1957 Jan 23, 2024
4facd32
chore: processed comments code reveiw
davidberenstein1957 Jan 23, 2024
2202e47
chore: updated field selected based on column names
davidberenstein1957 Jan 24, 2024
f63dc6e
chore: added a extra check for dataset_columns is False or []
davidberenstein1957 Jan 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ openai = ["openai >= 1.0.0"]
vllm = ["vllm >= 0.2.1"]
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
together = ["together"]
argilla = ["argilla >= 1.18.0"]
argilla = ["argilla >= 1.22.0", "sentence-transformers >= 2.0.0"]
tests = ["pytest >= 7.4.0"]
docs = [
"mkdocs-material >= 9.5.0",
Expand Down
49 changes: 48 additions & 1 deletion src/distilabel/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,25 @@
import importlib_resources
else:
import importlib.resources as importlib_resources

import warnings
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union

from jinja2 import Template

from distilabel.tasks.prompt import Prompt
from distilabel.utils.imports import _ARGILLA_AVAILABLE

if _ARGILLA_AVAILABLE:
from argilla.client.feedback.integrations.sentencetransformers import (
SentenceTransformersExtractor,
)

if TYPE_CHECKING:
from argilla import FeedbackDataset, FeedbackRecord
from argilla.client.feedback.integrations.sentencetransformers import (
SentenceTransformersExtractor,
)


def get_template(template_name: str) -> str:
Expand Down Expand Up @@ -119,6 +128,44 @@ def to_argilla_record(
" `FeedbackDataset` you will need to implement this method first."
)

def add_vectors_to_argilla_dataset(
self,
dataset: Union["FeedbackRecord", List["FeedbackRecord"], "FeedbackDataset"],
vector_strategy: Union[bool, "SentenceTransformersExtractor"],
) -> Union["FeedbackRecord", List["FeedbackRecord"], "FeedbackDataset"]:
if _ARGILLA_AVAILABLE and vector_strategy:
try:
if isinstance(vector_strategy, SentenceTransformersExtractor):
ste = vector_strategy

elif vector_strategy is True:
ste = SentenceTransformersExtractor(
model="en",
show_progress=True,
)
else:
raise ValueError(
"The `vector_strategy` must be either `True` or a `SentenceTransformersExtractor` instance."
)

dataset = ste.update_dataset(dataset=dataset)
except Exception as e:
warnings.warn(
f"An error occurred while adding vectors to the dataset: {e}",
stacklevel=2,
)

elif not _ARGILLA_AVAILABLE and vector_strategy:
warnings.warn(
"An error occurred while adding vectors to the dataset: "
"The `argilla`/`sentence-transformers` packages are not installed or the installed version is not compatible with the"
" required version. If you want to add vectors to your dataset, please run `pip install 'distilabel[vectors]'`.",
stacklevel=2,
)
else:
pass
plaguss marked this conversation as resolved.
Show resolved Hide resolved
return dataset

# Renamed to _to_argilla_record instead of renaming `to_argilla_record` to protected, as that would
# imply more breaking changes.
def _to_argilla_record( # noqa: C901
Expand Down
5 changes: 5 additions & 0 deletions src/distilabel/tasks/critique/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

if TYPE_CHECKING:
from argilla import FeedbackDataset, FeedbackRecord
from argilla.client.feedback.integrations.sentencetransformers import (
SentenceTransformersExtractor,
)


@dataclass
Expand Down Expand Up @@ -52,13 +55,15 @@ def to_argilla_dataset(
score_column: str = "score",
critique_column: str = "critique",
score_values: Optional[List[int]] = None,
vector_strategy: Union[bool, "SentenceTransformersExtractor"] = True,
) -> "FeedbackDataset":
return super().to_argilla_dataset(
dataset_row=dataset_row,
generations_column=generations_column,
ratings_column=score_column,
rationale_column=critique_column,
ratings_values=score_values or [1, 2, 3, 4, 5],
vector_strategy=vector_strategy,
)

def to_argilla_record(
Expand Down
7 changes: 6 additions & 1 deletion src/distilabel/tasks/critique/ultracm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@

import re
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union

from distilabel.tasks.base import get_template
from distilabel.tasks.critique.base import CritiqueTask, CritiqueTaskOutput
from distilabel.tasks.prompt import Prompt

if TYPE_CHECKING:
from argilla import FeedbackDataset
from argilla.client.feedback.integrations.sentencetransformers import (
SentenceTransformersExtractor,
)

_ULTRACM_TEMPLATE = get_template("ultracm.jinja2")

Expand Down Expand Up @@ -102,11 +105,13 @@ def to_argilla_dataset(
score_column: str = "score",
critique_column: str = "critique",
score_values: Optional[List[int]] = None,
vector_strategy: Union[bool, "SentenceTransformersExtractor"] = True,
) -> "FeedbackDataset":
return super().to_argilla_dataset(
dataset_row=dataset_row,
generations_column=generations_column,
score_column=score_column,
critique_column=critique_column,
score_values=score_values or [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
vector_strategy=vector_strategy,
)
13 changes: 11 additions & 2 deletions src/distilabel/tasks/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Protocol
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Protocol, Union

from distilabel.tasks.base import Task
from distilabel.utils.argilla import (
infer_fields_from_dataset_row,
model_metadata_from_dataset_row,
Expand All @@ -26,6 +27,9 @@

if TYPE_CHECKING:
from argilla import FeedbackDataset, FeedbackRecord
from argilla.client.feedback.integrations.sentencetransformers import (
SentenceTransformersExtractor,
)


class TaskProtocol(Protocol):
Expand All @@ -46,6 +50,7 @@ def to_argilla_dataset(
ratings_column: str = "rating",
rationale_column: str = "rationale",
ratings_values: Optional[List[int]] = None,
vector_strategy: Union[bool, "SentenceTransformersExtractor"] = True,
) -> "FeedbackDataset":
# First we infer the fields from the input_args_names, but we could also
# create those manually instead using `rg.TextField(...)`
Expand Down Expand Up @@ -116,11 +121,15 @@ def to_argilla_dataset(
)
# Then we just return the `FeedbackDataset` with the fields, questions, and metadata properties
# defined above.
return rg.FeedbackDataset(
dataset = rg.FeedbackDataset(
fields=fields,
questions=questions,
metadata_properties=metadata_properties, # Note that these are always optional
)
dataset = Task.add_vectors_to_argilla_dataset(
dataset=dataset, vector_strategy=vector_strategy
)
return dataset

def _merge_rationales(
self, rationales: List[str], generations_column: str = "generations"
Expand Down
6 changes: 6 additions & 0 deletions src/distilabel/tasks/preference/ultrafeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
List,
Optional,
TypedDict,
Union,
)

from distilabel.tasks.base import get_template
Expand All @@ -30,6 +31,9 @@

if TYPE_CHECKING:
from argilla import FeedbackDataset, FeedbackRecord
from argilla.client.feedback.integrations.sentence_transformers import (
SentenceTransformersExtractor,
)

_ULTRAFEEDBACK_TEMPLATE = get_template("ultrafeedback.jinja2")

Expand Down Expand Up @@ -132,13 +136,15 @@ def to_argilla_dataset(
ratings_column: str = "rating",
rationale_column: str = "rationale",
ratings_values: Optional[List[int]] = None,
vector_strategy: Union[bool, "SentenceTransformersExtractor"] = True,
) -> "FeedbackDataset":
return super().to_argilla_dataset(
dataset_row=dataset_row,
generations_column=generations_column,
ratings_column=ratings_column,
rationale_column=rationale_column,
ratings_values=ratings_values or [1, 2, 3, 4, 5],
vector_strategy=vector_strategy,
)

# Override the default `to_argilla_record` method to provide the `ratings_values` of
Expand Down
11 changes: 10 additions & 1 deletion src/distilabel/tasks/text_generation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@
if _ARGILLA_AVAILABLE:
import argilla as rg


if TYPE_CHECKING:
from argilla import FeedbackDataset, FeedbackRecord
from argilla.client.feedback.integrations.sentencetransformers import (
SentenceTransformersExtractor,
)


@dataclass
Expand Down Expand Up @@ -154,6 +158,7 @@ def to_argilla_dataset(
self,
dataset_row: Dict[str, Any],
generations_column: Optional[str] = "generations",
vector_strategy: Union[bool, "SentenceTransformersExtractor"] = True,
) -> "FeedbackDataset":
# First we infer the fields from the input_args_names, but we could also
# create those manually instead using `rg.TextField(...)`
Expand Down Expand Up @@ -201,11 +206,15 @@ def to_argilla_dataset(
)
# Then we just return the `FeedbackDataset` with the fields, questions, and metadata properties
# defined above.
return rg.FeedbackDataset(
dataset = rg.FeedbackDataset(
fields=fields,
questions=questions,
metadata_properties=metadata_properties, # Note that these are always optional
)
dataset = Task.add_vectors_to_argilla_dataset(
dataset=dataset, vector_strategy=vector_strategy
)
return dataset

def to_argilla_record(self, dataset_row: Dict[str, Any]) -> "FeedbackRecord":
"""Converts a dataset row to an Argilla `FeedbackRecord`."""
Expand Down
21 changes: 17 additions & 4 deletions src/distilabel/tasks/text_generation/self_instruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
import re
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from distilabel.tasks.base import get_template
from distilabel.tasks.base import Task, get_template
from distilabel.tasks.prompt import Prompt
from distilabel.tasks.text_generation.base import TextGenerationTask
from distilabel.utils.argilla import (
Expand All @@ -31,6 +31,9 @@

if TYPE_CHECKING:
from argilla import FeedbackDataset, FeedbackRecord
from argilla.client.feedback.integrations.sentencetransformers import (
SentenceTransformersExtractor,
)

_SELF_INSTRUCT_TEMPLATE = get_template("self-instruct.jinja2")

Expand Down Expand Up @@ -103,7 +106,11 @@ def parse_output(self, output: str) -> Dict[str, List[str]]:
pattern = re.compile(r"\d+\.\s*(.*?)\n")
return {"instructions": pattern.findall(output)}

def to_argilla_dataset(self, dataset_row: Dict[str, Any]) -> "FeedbackDataset":
def to_argilla_dataset(
self,
dataset_row: Dict[str, Any],
vector_strategy: Union[bool, "SentenceTransformersExtractor"] = True,
) -> "FeedbackDataset":
# First we infer the fields from the input_args_names, but we could also
# create those manually instead using `rg.TextField(...)`
fields = infer_fields_from_dataset_row(
Expand Down Expand Up @@ -149,11 +156,17 @@ def to_argilla_dataset(self, dataset_row: Dict[str, Any]) -> "FeedbackDataset":
) # type: ignore
# Then we just return the `FeedbackDataset` with the fields, questions, and metadata properties
# defined above.
return rg.FeedbackDataset(
dataset = rg.FeedbackDataset(
fields=fields,
questions=questions, # type: ignore
metadata_properties=metadata_properties, # Note that these are always optional
)
dataset: (
FeedbackRecord | List[FeedbackRecord] | FeedbackDataset
davidberenstein1957 marked this conversation as resolved.
Show resolved Hide resolved
) = Task.add_vectors_to_argilla_dataset(
dataset=dataset, vector_strategy=vector_strategy
)
return dataset

def to_argilla_record(
self,
Expand Down
4 changes: 3 additions & 1 deletion src/distilabel/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def _check_package_is_available(


_ARGILLA_AVAILABLE = _check_package_is_available(
"argilla", min_version="1.16.0", greater_or_equal=True
"argilla", min_version="1.22.0", greater_or_equal=True
) and _check_package_is_available(
"sentence-transformers", min_version="2.0.0", greater_or_equal=True
)
_OPENAI_AVAILABLE = _check_package_is_available(
"openai", min_version="1.0.0", greater_or_equal=True
Expand Down