Skip to content

Commit

Permalink
Fix loop over instructions in SelfInstructTask
Browse files Browse the repository at this point in the history
  • Loading branch information
alvarobartt committed Dec 19, 2023
1 parent 8cbbcb4 commit 46b5210
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions src/distilabel/tasks/text_generation/self_instruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def output_args_names(self) -> List[str]:
def parse_output(self, output: str) -> Dict[str, List[str]]:
    """Parse a numbered list produced by the model into instructions.

    Every item of the form ``<digits>. <text>`` is extracted; ``<text>`` runs
    up to the next newline or, for the final item, to the end of the string.

    Args:
        output: the raw text generated by the model.

    Returns:
        A dictionary with a single key, ``"instructions"``, mapping to the
        list of extracted instruction strings (empty if none are found).
    """
    # Accept end-of-string as a terminator as well as a newline, so the last
    # item is not dropped when the generation lacks a trailing newline.
    pattern = re.compile(r"\d+\.\s+(.*?)(?:\n|$)")
    return {"instructions": pattern.findall(output)}

def to_argilla_dataset(self, dataset_row: Dict[str, Any]) -> "FeedbackDataset":
Expand Down Expand Up @@ -141,7 +142,7 @@ def to_argilla_dataset(self, dataset_row: Dict[str, Any]) -> "FeedbackDataset":
stacklevel=2,
)
metadata_properties.append(
rg.IntegerMetadataProperty(name="length-instruction")
rg.IntegerMetadataProperty(name="length-instruction") # type: ignore
) # type: ignore
# Then we just return the `FeedbackDataset` with the fields, questions, and metadata properties
# defined above.
Expand All @@ -158,27 +159,28 @@ def to_argilla_record(
) -> List["FeedbackRecord"]:
"""Converts a dataset row to a list of Argilla `FeedbackRecord`s."""
records = []
for instruction in dataset_row[instructions_column]: # type: ignore
fields, metadata = {}, {}
for arg_name in self.input_args_names:
arg_value = dataset_row[arg_name]
if isinstance(arg_value, list):
for idx, value in enumerate(arg_value, start=1):
value = value.strip() if isinstance(value, str) else ""
fields[f"{arg_name}-{idx}"] = value
if value is not None:
metadata[f"length-{arg_name}-{idx}"] = len(value)
elif isinstance(arg_value, str):
fields[arg_name] = arg_value.strip() if arg_value else ""
if arg_value is not None:
metadata[f"length-{arg_name}"] = len(arg_value.strip())
else:
warnings.warn(
f"Unsupported input type ({type(arg_value)}), skipping...",
UserWarning,
stacklevel=2,
)
fields["instruction"] = instruction
metadata["length-instruction"] = len(instruction)
records.append(rg.FeedbackRecord(fields=fields, metadata=metadata))
for instructions in dataset_row[instructions_column]: # type: ignore
for instruction in instructions:
fields, metadata = {}, {}
for arg_name in self.input_args_names:
arg_value = dataset_row[arg_name]
if isinstance(arg_value, list):
for idx, value in enumerate(arg_value, start=1):
value = value.strip() if isinstance(value, str) else ""
fields[f"{arg_name}-{idx}"] = value
if value is not None:
metadata[f"length-{arg_name}-{idx}"] = len(value)
elif isinstance(arg_value, str):
fields[arg_name] = arg_value.strip() if arg_value else ""
if arg_value is not None:
metadata[f"length-{arg_name}"] = len(arg_value.strip())
else:
warnings.warn(
f"Unsupported input type ({type(arg_value)}), skipping...",
UserWarning,
stacklevel=2,
)
fields["instruction"] = instruction
metadata["length-instruction"] = len(instruction)
records.append(rg.FeedbackRecord(fields=fields, metadata=metadata))
return records

0 comments on commit 46b5210

Please sign in to comment.