From 3244c05e353f9d00cff6c88102fe82d255fe8ea3 Mon Sep 17 00:00:00 2001 From: zye1996 Date: Mon, 30 Sep 2024 03:25:34 -0700 Subject: [PATCH] Fix writing `distilabel_metadata` column when `LLM` error (#1003) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix metadata writeout when llm error * linter reformat --------- Co-authored-by: Gabriel Martín Blázquez --- src/distilabel/pipeline/step_wrapper.py | 8 +------- src/distilabel/pipeline/write_buffer.py | 9 ++++++--- src/distilabel/steps/base.py | 14 ++++++++++++++ src/distilabel/steps/tasks/base.py | 24 +++++++++++++++++++++++- 4 files changed, 44 insertions(+), 11 deletions(-) diff --git a/src/distilabel/pipeline/step_wrapper.py b/src/distilabel/pipeline/step_wrapper.py index 661f99b1f7..2ddfb10daa 100644 --- a/src/distilabel/pipeline/step_wrapper.py +++ b/src/distilabel/pipeline/step_wrapper.py @@ -277,13 +277,7 @@ def _impute_step_outputs(self, batch: "_Batch") -> List[Dict[str, Any]]: Args: batch: The batch to impute. """ - result = [] - for row in batch.data[0]: - data = row.copy() - for output in self.step.outputs: - data[output] = None - result.append(data) - return result + return self.step.impute_step_outputs(batch.data[0]) def _send_batch(self, batch: _Batch) -> None: """Sends a batch to the `output_queue`.""" diff --git a/src/distilabel/pipeline/write_buffer.py b/src/distilabel/pipeline/write_buffer.py index a71ffdd9b2..e53d6faebf 100644 --- a/src/distilabel/pipeline/write_buffer.py +++ b/src/distilabel/pipeline/write_buffer.py @@ -130,9 +130,12 @@ def _write(self, step_name: str) -> None: self._buffer_last_schema[step_name] = table.schema else: if not last_schema.equals(table.schema): - new_schema = pa.unify_schemas([last_schema, table.schema]) - self._buffer_last_schema[step_name] = new_schema - table = table.cast(new_schema) + if set(last_schema.names) == set(table.schema.names): + table = table.select(last_schema.names) + else: + new_schema = pa.unify_schemas([last_schema, table.schema]) + self._buffer_last_schema[step_name] = new_schema + table = table.cast(new_schema) next_file_number = self._buffers_last_file[step_name] self._buffers_last_file[step_name] = next_file_number + 1 diff --git a/src/distilabel/steps/base.py b/src/distilabel/steps/base.py index cc0c0b2e1a..a3040133a0 100644 --- a/src/distilabel/steps/base.py +++ b/src/distilabel/steps/base.py @@ -582,6 +582,20 @@ def save_artifact( ) write_json(filename=metadata_path, data=metadata or {}) + def impute_step_outputs( + self, step_output: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Imputes the output columns of the step that are not present in the step output. + """ + result = [] + for row in step_output: + data = row.copy() + for output in self.outputs: + data[output] = None + result.append(data) + return result + def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: dump = super()._model_dump(obj, **kwargs) dump["runtime_parameters_info"] = self.get_runtime_parameters_info() diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index a0afb74c32..33b0330466 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -117,6 +117,28 @@ def unload(self) -> None: self._logger.debug("Executing task unload logic.") self.llm.unload() + @override + def impute_step_outputs( + self, step_output: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Imputes the outputs of the task in case the LLM failed to generate a response. + """ + result = [] + for row in step_output: + data = row.copy() + for output in self.outputs: + data[output] = None + data = self._maybe_add_raw_input_output( + data, + None, + None, + add_raw_output=self.add_raw_output, + add_raw_input=self.add_raw_input, + ) + result.append(data) + return result + @abstractmethod def format_output( self, @@ -201,7 +223,7 @@ def _maybe_add_raw_input_output( if add_raw_output: meta[f"raw_output_{self.name}"] = raw_output if add_raw_input: - meta[f"raw_input_{self.name}"] = self.format_input(input) + meta[f"raw_input_{self.name}"] = self.format_input(input) if input else None if meta: output[DISTILABEL_METADATA_KEY] = meta