Skip to content

Commit

Permalink
Add custom to_argilla_{dataset,record} to SelfInstructTask (#169)
Browse files Browse the repository at this point in the history
* Update `to_argilla_record` return type-hint

* Add TODO for future self

* Add custom `to_argilla_{dataset,record}` in `SelfInstructTask`

* Fix loop over instructions in `SelfInstructTask`

* Remove `print` from `parse_output`
  • Loading branch information
alvarobartt authored Dec 19, 2023
1 parent e06453d commit 62993ac
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/distilabel/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def to_argilla_dataset(

def to_argilla_record(
self, dataset_row: Dict[str, Any], *args: Any, **kwargs: Any
) -> "FeedbackRecord":
) -> Union["FeedbackRecord", List["FeedbackRecord"]]:
raise NotImplementedError(
"`to_argilla_record` is not implemented, if you want to export your dataset as an Argilla"
" `FeedbackDataset` you will need to implement this method first."
Expand Down
3 changes: 3 additions & 0 deletions src/distilabel/tasks/text_generation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ def to_argilla_record(self, dataset_row: Dict[str, Any]) -> "FeedbackRecord":
arg_value = dataset_row[arg_name]
if isinstance(arg_value, list):
for idx, value in enumerate(arg_value, start=1):
# TODO: value formatting was included here due to some issues
# with `SelfInstructTask` but these list-parsing may not be needed
# anymore.
value = (
value.strip()
if isinstance(value, str)
Expand Down
105 changes: 103 additions & 2 deletions src/distilabel/tasks/text_generation/self_instruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import warnings
from dataclasses import dataclass
from typing import Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from distilabel.tasks.base import get_template
from distilabel.tasks.prompt import Prompt
from distilabel.tasks.text_generation.base import TextGenerationTask
from distilabel.utils.argilla import infer_fields_from_dataset_row
from distilabel.utils.imports import _ARGILLA_AVAILABLE

if _ARGILLA_AVAILABLE:
import argilla as rg

if TYPE_CHECKING:
from argilla.client.feedback.dataset.local.dataset import FeedbackDataset
from argilla.client.feedback.schemas.records import FeedbackRecord

_SELF_INSTRUCT_TEMPLATE = get_template("self-instruct.jinja2")

Expand Down Expand Up @@ -79,6 +90,96 @@ def generate_prompt(self, input: str) -> Prompt:
formatted_prompt=self.template.render(**render_kwargs),
)

@property
def output_args_names(self) -> List[str]:
return ["instructions"]

def parse_output(self, output: str) -> Dict[str, List[str]]:
"""Parses the output of the model into the desired format."""
return {"generations": output.split("\n")}
pattern = re.compile(r"\d+\.\s+(.*?)\n")
return {"instructions": pattern.findall(output)}

def to_argilla_dataset(self, dataset_row: Dict[str, Any]) -> "FeedbackDataset":
# First we infer the fields from the input_args_names, but we could also
# create those manually instead using `rg.TextField(...)`
fields = infer_fields_from_dataset_row(
field_names=self.input_args_names,
dataset_row=dataset_row,
)
# Once the input fields have been defined, then we also include the instruction
# field which will be fulfilled with each of the instructions generated.
fields.append(rg.TextField(name="instruction", title="instruction")) # type: ignore
# Then we add a default `RatingQuestion` which asks the users to provide a
# rating for each of the generations, differing from the scenario where the inputs
# are the fields and the outputs the ones used to formulate the quesstions. So on,
# in this scenario we won't have suggestions, as the questions will be related to the
# combination of inputs and outputs.
questions = [
rg.RatingQuestion( # type: ignore
name="instruction-rating",
title="How would you rate the generated instruction?",
values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
)
]
# Finally, we define some metadata properties that can be potentially used
# while exploring the dataset within Argilla to get more insights on the data.
metadata_properties = []
for arg_name in self.input_args_names:
if isinstance(dataset_row[arg_name], list):
for idx in range(1, len(dataset_row[arg_name]) + 1):
metadata_properties.append(
rg.IntegerMetadataProperty(name=f"length-{arg_name}-{idx}") # type: ignore
)
elif isinstance(dataset_row[arg_name], str):
metadata_properties.append(
rg.IntegerMetadataProperty(name=f"length-{arg_name}") # type: ignore
)
else:
warnings.warn(
f"Unsupported input type ({type(dataset_row[arg_name])}), skipping...",
UserWarning,
stacklevel=2,
)
metadata_properties.append(
rg.IntegerMetadataProperty(name="length-instruction") # type: ignore
) # type: ignore
# Then we just return the `FeedbackDataset` with the fields, questions, and metadata properties
# defined above.
return rg.FeedbackDataset(
fields=fields,
questions=questions, # type: ignore
metadata_properties=metadata_properties, # Note that these are always optional
)

def to_argilla_record(
self,
dataset_row: Dict[str, Any],
instructions_column: Optional[str] = "instructions",
) -> List["FeedbackRecord"]:
"""Converts a dataset row to a list of Argilla `FeedbackRecord`s."""
records = []
for instructions in dataset_row[instructions_column]: # type: ignore
for instruction in instructions:
fields, metadata = {}, {}
for arg_name in self.input_args_names:
arg_value = dataset_row[arg_name]
if isinstance(arg_value, list):
for idx, value in enumerate(arg_value, start=1):
value = value.strip() if isinstance(value, str) else ""
fields[f"{arg_name}-{idx}"] = value
if value is not None:
metadata[f"length-{arg_name}-{idx}"] = len(value)
elif isinstance(arg_value, str):
fields[arg_name] = arg_value.strip() if arg_value else ""
if arg_value is not None:
metadata[f"length-{arg_name}"] = len(arg_value.strip())
else:
warnings.warn(
f"Unsupported input type ({type(arg_value)}), skipping...",
UserWarning,
stacklevel=2,
)
fields["instruction"] = instruction
metadata["length-instruction"] = len(instruction)
records.append(rg.FeedbackRecord(fields=fields, metadata=metadata))
return records

0 comments on commit 62993ac

Please sign in to comment.