Skip to content

Commit

Permalink
Fix writing distilabel_metadata column when LLM error (#1003)
Browse files Browse the repository at this point in the history
* fix metadata writeout when llm error

* linter reformat

---------

Co-authored-by: Gabriel Martín Blázquez <[email protected]>
  • Loading branch information
zye1996 and gabrielmbmb authored Sep 30, 2024
1 parent a49242d commit 3244c05
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 11 deletions.
8 changes: 1 addition & 7 deletions src/distilabel/pipeline/step_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,7 @@ def _impute_step_outputs(self, batch: "_Batch") -> List[Dict[str, Any]]:
Args:
batch: The batch to impute.
"""
result = []
for row in batch.data[0]:
data = row.copy()
for output in self.step.outputs:
data[output] = None
result.append(data)
return result
return self.step.impute_step_outputs(batch.data[0])

def _send_batch(self, batch: _Batch) -> None:
"""Sends a batch to the `output_queue`."""
Expand Down
9 changes: 6 additions & 3 deletions src/distilabel/pipeline/write_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,12 @@ def _write(self, step_name: str) -> None:
self._buffer_last_schema[step_name] = table.schema
else:
if not last_schema.equals(table.schema):
new_schema = pa.unify_schemas([last_schema, table.schema])
self._buffer_last_schema[step_name] = new_schema
table = table.cast(new_schema)
if set(last_schema.names) == set(table.schema.names):
table = table.select(last_schema.names)
else:
new_schema = pa.unify_schemas([last_schema, table.schema])
self._buffer_last_schema[step_name] = new_schema
table = table.cast(new_schema)

next_file_number = self._buffers_last_file[step_name]
self._buffers_last_file[step_name] = next_file_number + 1
Expand Down
14 changes: 14 additions & 0 deletions src/distilabel/steps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,20 @@ def save_artifact(
)
write_json(filename=metadata_path, data=metadata or {})

def impute_step_outputs(
    self, step_output: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Return copies of the given rows with every output column of the step set to `None`.

    Each row is shallow-copied, so the dictionaries in `step_output` are left
    untouched. Every column named in `self.outputs` is written as `None` in the
    copy, overwriting any value the row already had for that column. Columns not
    listed in `self.outputs` are carried over unchanged.

    Args:
        step_output: The rows to impute, one dictionary per row.

    Returns:
        A new list with one imputed dictionary per input row, in the same order.
    """
    # `self.outputs` is the same for every row, so build the null-valued
    # template once instead of re-reading the attribute inside the loop.
    null_outputs = dict.fromkeys(self.outputs)
    # Merging after `**row` keeps the row's key order and appends only the
    # output columns that were missing, matching per-key assignment on a copy.
    return [{**row, **null_outputs} for row in step_output]

def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
dump = super()._model_dump(obj, **kwargs)
dump["runtime_parameters_info"] = self.get_runtime_parameters_info()
Expand Down
24 changes: 23 additions & 1 deletion src/distilabel/steps/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,28 @@ def unload(self) -> None:
self._logger.debug("Executing task unload logic.")
self.llm.unload()

@override
def impute_step_outputs(
    self, step_output: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Impute the outputs of the task in case the LLM failed to generate a response.

    Each row is shallow-copied and every column named in `self.outputs` is set
    to `None` in the copy (overwriting any existing value). The copy is then
    passed through `self._maybe_add_raw_input_output` with `None` as both the
    raw input and the raw output, honouring the task's `add_raw_input` and
    `add_raw_output` flags, and whatever that helper returns becomes the
    imputed row.

    Args:
        step_output: The rows to impute, one dictionary per row.

    Returns:
        A new list with one imputed row per input row, in the same order.
    """
    # `self.outputs` is the same for every row, so build the null-valued
    # template once instead of re-reading the attribute inside the loop.
    null_outputs = dict.fromkeys(self.outputs)
    return [
        # There is no LLM input or output to record for a failed batch, hence
        # the two `None` positional arguments.
        self._maybe_add_raw_input_output(
            {**row, **null_outputs},
            None,
            None,
            add_raw_output=self.add_raw_output,
            add_raw_input=self.add_raw_input,
        )
        for row in step_output
    ]

@abstractmethod
def format_output(
self,
Expand Down Expand Up @@ -201,7 +223,7 @@ def _maybe_add_raw_input_output(
if add_raw_output:
meta[f"raw_output_{self.name}"] = raw_output
if add_raw_input:
meta[f"raw_input_{self.name}"] = self.format_input(input)
meta[f"raw_input_{self.name}"] = self.format_input(input) if input else None
if meta:
output[DISTILABEL_METADATA_KEY] = meta

Expand Down

0 comments on commit 3244c05

Please sign in to comment.