diff --git a/src/distilabel/tasks/base.py b/src/distilabel/tasks/base.py index 1c9ec3f5a8..251c25d214 100644 --- a/src/distilabel/tasks/base.py +++ b/src/distilabel/tasks/base.py @@ -113,7 +113,7 @@ def to_argilla_dataset( def to_argilla_record( self, dataset_row: Dict[str, Any], *args: Any, **kwargs: Any - ) -> "FeedbackRecord": + ) -> Union["FeedbackRecord", List["FeedbackRecord"]]: raise NotImplementedError( "`to_argilla_record` is not implemented, if you want to export your dataset as an Argilla" " `FeedbackDataset` you will need to implement this method first." diff --git a/src/distilabel/tasks/text_generation/base.py b/src/distilabel/tasks/text_generation/base.py index 6a069773b7..2d4c8464c6 100644 --- a/src/distilabel/tasks/text_generation/base.py +++ b/src/distilabel/tasks/text_generation/base.py @@ -212,6 +212,9 @@ def to_argilla_record(self, dataset_row: Dict[str, Any]) -> "FeedbackRecord": arg_value = dataset_row[arg_name] if isinstance(arg_value, list): for idx, value in enumerate(arg_value, start=1): + # TODO: value formatting was included here due to some issues + # with `SelfInstructTask` but these list-parsing may not be needed + # anymore. value = ( value.strip() if isinstance(value, str) diff --git a/src/distilabel/tasks/text_generation/self_instruct.py b/src/distilabel/tasks/text_generation/self_instruct.py index 349682d59f..f787d46726 100644 --- a/src/distilabel/tasks/text_generation/self_instruct.py +++ b/src/distilabel/tasks/text_generation/self_instruct.py @@ -12,12 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import re
+import warnings
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from distilabel.tasks.base import get_template
 from distilabel.tasks.prompt import Prompt
 from distilabel.tasks.text_generation.base import TextGenerationTask
+from distilabel.utils.argilla import infer_fields_from_dataset_row
+from distilabel.utils.imports import _ARGILLA_AVAILABLE
+
+if _ARGILLA_AVAILABLE:
+    import argilla as rg
+
+if TYPE_CHECKING:
+    from argilla.client.feedback.dataset.local.dataset import FeedbackDataset
+    from argilla.client.feedback.schemas.records import FeedbackRecord
 
 _SELF_INSTRUCT_TEMPLATE = get_template("self-instruct.jinja2")
 
@@ -79,6 +90,131 @@ def generate_prompt(self, input: str) -> Prompt:
             formatted_prompt=self.template.render(**render_kwargs),
         )
 
+    @property
+    def output_args_names(self) -> List[str]:
+        """Name of the output argument, i.e. the key returned by `parse_output`."""
+        return ["instructions"]
+
     def parse_output(self, output: str) -> Dict[str, List[str]]:
         """Parses the output of the model into the desired format."""
-        return {"generations": output.split("\n")}
+        # Each instruction is expected as a numbered item (`1. ...`); an item ends
+        # either at a line break or at the end of the output, so the last
+        # instruction is kept even when the output has no trailing newline.
+        pattern = re.compile(r"\d+\.\s+(.*?)(?:\n|$)")
+        return {"instructions": pattern.findall(output)}
+
+    def to_argilla_dataset(self, dataset_row: Dict[str, Any]) -> "FeedbackDataset":
+        """Creates the Argilla `FeedbackDataset` to rate each generated instruction.
+
+        Args:
+            dataset_row: a row of the dataset, used to infer the input fields and
+                the metadata properties.
+
+        Returns:
+            A `FeedbackDataset` with the input fields plus an `instruction` field,
+            an `instruction-rating` question, and the `length-*` metadata properties.
+        """
+        # First we infer the fields from the input_args_names, but we could also
+        # create those manually instead using `rg.TextField(...)`
+        fields = infer_fields_from_dataset_row(
+            field_names=self.input_args_names,
+            dataset_row=dataset_row,
+        )
+        # Once the input fields have been defined, then we also include the instruction
+        # field which will be filled with each of the generated instructions.
+        fields.append(rg.TextField(name="instruction", title="instruction"))  # type: ignore
+        # Then we add a default `RatingQuestion` which asks the users to provide a
+        # rating for each of the generations, differing from the scenario where the inputs
+        # are the fields and the outputs the ones used to formulate the questions. As such,
+        # in this scenario we won't have suggestions, as the questions will be related to the
+        # combination of inputs and outputs.
+        questions = [
+            rg.RatingQuestion(  # type: ignore
+                name="instruction-rating",
+                title="How would you rate the generated instruction?",
+                values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+            )
+        ]
+        # Finally, we define some metadata properties that can be potentially used
+        # while exploring the dataset within Argilla to get more insights on the data.
+        metadata_properties = []
+        for arg_name in self.input_args_names:
+            arg_value = dataset_row[arg_name]
+            if isinstance(arg_value, list):
+                for idx in range(1, len(arg_value) + 1):
+                    metadata_properties.append(
+                        rg.IntegerMetadataProperty(name=f"length-{arg_name}-{idx}")  # type: ignore
+                    )
+            elif isinstance(arg_value, str):
+                metadata_properties.append(
+                    rg.IntegerMetadataProperty(name=f"length-{arg_name}")  # type: ignore
+                )
+            else:
+                warnings.warn(
+                    f"Unsupported input type ({type(arg_value)}), skipping...",
+                    UserWarning,
+                    stacklevel=2,
+                )
+        metadata_properties.append(
+            rg.IntegerMetadataProperty(name="length-instruction")  # type: ignore
+        )  # type: ignore
+        # Then we just return the `FeedbackDataset` with the fields, questions, and metadata properties
+        # defined above.
+        return rg.FeedbackDataset(
+            fields=fields,
+            questions=questions,  # type: ignore
+            metadata_properties=metadata_properties,  # Note that these are always optional
+        )
+
+    def to_argilla_record(
+        self,
+        dataset_row: Dict[str, Any],
+        instructions_column: Optional[str] = "instructions",
+    ) -> List["FeedbackRecord"]:
+        """Converts a dataset row to a list of Argilla `FeedbackRecord`s.
+
+        Args:
+            dataset_row: a row of the dataset.
+            instructions_column: name of the column holding the generated
+                instructions, iterated as a list of lists of instructions.
+                Defaults to `"instructions"`.
+
+        Returns:
+            One `FeedbackRecord` per generated instruction, each one with the
+            input fields plus the `instruction` field.
+        """
+        # The input fields and their metadata are the same for every instruction,
+        # so they are built (and unsupported inputs are warned about) only once.
+        input_fields, input_metadata = {}, {}
+        for arg_name in self.input_args_names:
+            arg_value = dataset_row[arg_name]
+            if isinstance(arg_value, list):
+                for idx, value in enumerate(arg_value, start=1):
+                    # Non-string values are exported as empty strings.
+                    value = value.strip() if isinstance(value, str) else ""
+                    input_fields[f"{arg_name}-{idx}"] = value
+                    input_metadata[f"length-{arg_name}-{idx}"] = len(value)
+            elif isinstance(arg_value, str):
+                value = arg_value.strip()
+                input_fields[arg_name] = value
+                input_metadata[f"length-{arg_name}"] = len(value)
+            else:
+                warnings.warn(
+                    f"Unsupported input type ({type(arg_value)}), skipping...",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+        records = []
+        for instructions in dataset_row[instructions_column]:  # type: ignore
+            for instruction in instructions:
+                records.append(
+                    rg.FeedbackRecord(
+                        fields={**input_fields, "instruction": instruction},
+                        metadata={
+                            **input_metadata,
+                            "length-instruction": len(instruction),
+                        },
+                    )
+                )
+        return records