diff --git a/pr-1066/404.html b/pr-1066/404.html
index 7bd7fcfc33..85a0b8a1f0 100644
--- a/pr-1066/404.html
+++ b/pr-1066/404.html
@@ -2597,6 +2597,27 @@
+
class OpenAILLM(AsyncLLM):
"""OpenAI LLM implementation running the async API client.
Attributes:
@@ -22582,396 +22630,402 @@
"top_p": top_p,
"stop": stop,
}
-
- if response_format is not None:
- kwargs["response_format"] = response_format
-
- if structured_output:
- kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore
+        # Check if it's a vision generation task; in that case the "stop" parameter
+        # cannot be used, as it makes the API raise an error.
+ if isinstance(
+ [row for row in input if row["role"] == "user"][0]["content"], list
+ ):
+ kwargs.pop("stop")
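
For context, a minimal sketch of the input shape this check detects (texts and image URL are illustrative): with vision inputs the user message's "content" is a list of parts rather than a plain string.

    text_input = [{"role": "user", "content": "What is 2 + 2?"}]
    vision_input = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ]

    # Mirrors the check above: True for `vision_input`, so "stop" would be dropped.
    is_vision = isinstance(
        [row for row in vision_input if row["role"] == "user"][0]["content"], list
    )
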
- completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore
- if structured_output:
- return prepare_output(
- [completion.model_dump_json()],
- **self._get_llm_statistics(completion._raw_response),
- )
-
- return self._generations_from_openai_completion(completion)
-
- def _generations_from_openai_completion(
- self, completion: "OpenAIChatCompletion"
- ) -> "GenerateOutput":
- """Get the generations from the OpenAI Chat Completion object.
-
- Args:
- completion: the completion object to get the generations from.
-
- Returns:
- A list of strings containing the generated responses for the input.
- """
- generations = []
- for choice in completion.choices:
- if (content := choice.message.content) is None:
- self._logger.warning( # type: ignore
- f"Received no response using OpenAI client (model: '{self.model}')."
- f" Finish reason was: {choice.finish_reason}"
- )
- generations.append(content)
-
- return prepare_output(generations, **self._get_llm_statistics(completion))
-
- def offline_batch_generate(
- self,
- inputs: Union[List["FormattedInput"], None] = None,
- num_generations: int = 1,
- max_new_tokens: int = 128,
- frequency_penalty: float = 0.0,
- presence_penalty: float = 0.0,
- temperature: float = 1.0,
- top_p: float = 1.0,
- stop: Optional[Union[str, List[str]]] = None,
- response_format: Optional[str] = None,
- **kwargs: Any,
- ) -> List["GenerateOutput"]:
- """Uses the OpenAI batch API to generate `num_generations` responses for the given
- inputs.
-
- Args:
- inputs: a list of inputs in chat format to generate responses for.
- num_generations: the number of generations to create per input. Defaults to
- `1`.
- max_new_tokens: the maximum number of new tokens that the model will generate.
- Defaults to `128`.
- frequency_penalty: the repetition penalty to use for the generation. Defaults
- to `0.0`.
- presence_penalty: the presence penalty to use for the generation. Defaults to
- `0.0`.
- temperature: the temperature to use for the generation. Defaults to `0.1`.
- top_p: the top-p value to use for the generation. Defaults to `1.0`.
- stop: a string or a list of strings to use as a stop sequence for the generation.
- Defaults to `None`.
- response_format: the format of the response to return. Must be one of
- "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
- for more information on how to use the JSON model from OpenAI. Defaults to `text`.
-
- Returns:
- A list of lists of strings containing the generated responses for each input
- in `inputs`.
-
- Raises:
- DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
- is not finished yet.
- ValueError: if no job IDs were found to retrieve the results from.
- """
- if self.jobs_ids:
- return self._check_and_get_batch_results()
-
- if inputs:
- self.jobs_ids = self._create_jobs(
- inputs=inputs,
- **{
- "model": self.model,
- "max_tokens": max_new_tokens,
- "n": num_generations,
- "frequency_penalty": frequency_penalty,
- "presence_penalty": presence_penalty,
- "temperature": temperature,
- "top_p": top_p,
- "stop": stop,
- "response_format": response_format,
- },
- )
- raise DistilabelOfflineBatchGenerationNotFinishedException(
- jobs_ids=self.jobs_ids
- )
-
- raise ValueError("No `inputs` were provided and no `jobs_ids` were found.")
-
- def _check_and_get_batch_results(self) -> List["GenerateOutput"]:
- """Checks the status of the batch jobs and retrieves the results from the OpenAI
- Batch API.
+ if response_format is not None:
+ kwargs["response_format"] = response_format
+
+ if structured_output:
+ kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore
+
+ completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore
+ if structured_output:
+ return prepare_output(
+ [completion.model_dump_json()],
+ **self._get_llm_statistics(completion._raw_response),
+ )
+
+ return self._generations_from_openai_completion(completion)
+
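
A minimal call sketch for `agenerate` (model name and prompt are placeholders; the import path is taken from the file location shown at the end of the diff, and `OPENAI_API_KEY` is assumed to be set in the environment):

    import asyncio

    from distilabel.models.llms import OpenAILLM

    llm = OpenAILLM(model="gpt-4o-mini")  # placeholder model name
    llm.load()

    # `agenerate` is a coroutine, so drive it with asyncio for a one-off call.
    result = asyncio.run(
        llm.agenerate(
            input=[{"role": "user", "content": "Say hi."}],
            max_new_tokens=32,
        )
    )
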
+ def _generations_from_openai_completion(
+ self, completion: "OpenAIChatCompletion"
+ ) -> "GenerateOutput":
+ """Get the generations from the OpenAI Chat Completion object.
+
+ Args:
+ completion: the completion object to get the generations from.
+
+ Returns:
+ A list of strings containing the generated responses for the input.
+ """
+ generations = []
+ for choice in completion.choices:
+ if (content := choice.message.content) is None:
+ self._logger.warning( # type: ignore
+ f"Received no response using OpenAI client (model: '{self.model}')."
+ f" Finish reason was: {choice.finish_reason}"
+ )
+ generations.append(content)
+
+ return prepare_output(generations, **self._get_llm_statistics(completion))
+
+ def offline_batch_generate(
+ self,
+ inputs: Union[List["FormattedInput"], None] = None,
+ num_generations: int = 1,
+ max_new_tokens: int = 128,
+ frequency_penalty: float = 0.0,
+ presence_penalty: float = 0.0,
+ temperature: float = 1.0,
+ top_p: float = 1.0,
+ stop: Optional[Union[str, List[str]]] = None,
+ response_format: Optional[str] = None,
+ **kwargs: Any,
+ ) -> List["GenerateOutput"]:
+ """Uses the OpenAI batch API to generate `num_generations` responses for the given
+ inputs.
+
+ Args:
+ inputs: a list of inputs in chat format to generate responses for.
+ num_generations: the number of generations to create per input. Defaults to
+ `1`.
+ max_new_tokens: the maximum number of new tokens that the model will generate.
+ Defaults to `128`.
+            frequency_penalty: the frequency penalty to use for the generation. Defaults
+ to `0.0`.
+ presence_penalty: the presence penalty to use for the generation. Defaults to
+ `0.0`.
+            temperature: the temperature to use for the generation. Defaults to `1.0`.
+ top_p: the top-p value to use for the generation. Defaults to `1.0`.
+ stop: a string or a list of strings to use as a stop sequence for the generation.
+ Defaults to `None`.
+            response_format: the format of the response to return. Must be one of
+                "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
+                for more information on how to use JSON mode with OpenAI. Defaults to `None`.
+
+ Returns:
+ A list of lists of strings containing the generated responses for each input
+ in `inputs`.
+
+ Raises:
+ DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
+ is not finished yet.
+ ValueError: if no job IDs were found to retrieve the results from.
+ """
+ if self.jobs_ids:
+ return self._check_and_get_batch_results()
+
+ if inputs:
+ self.jobs_ids = self._create_jobs(
+ inputs=inputs,
+ **{
+ "model": self.model,
+ "max_tokens": max_new_tokens,
+ "n": num_generations,
+ "frequency_penalty": frequency_penalty,
+ "presence_penalty": presence_penalty,
+ "temperature": temperature,
+ "top_p": top_p,
+ "stop": stop,
+ "response_format": response_format,
+ },
+ )
+ raise DistilabelOfflineBatchGenerationNotFinishedException(
+ jobs_ids=self.jobs_ids
+ )
- Returns:
- A list of lists of strings containing the generated responses for each input.
-
- Raises:
- ValueError: if no job IDs were found to retrieve the results from.
- DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
- is not finished yet.
- RuntimeError: if the only batch job found failed.
- """
- if not self.jobs_ids:
- raise ValueError("No job IDs were found to retrieve the results from.")
-
- outputs = []
- for batch_id in self.jobs_ids:
- batch = self._get_openai_batch(batch_id)
-
- if batch.status in ("validating", "in_progress", "finalizing"):
- raise DistilabelOfflineBatchGenerationNotFinishedException(
- jobs_ids=self.jobs_ids
- )
-
- if batch.status in ("failed", "expired", "cancelled", "cancelling"):
- self._logger.error( # type: ignore
- f"OpenAI API batch with ID '{batch_id}' failed with status '{batch.status}'."
- )
- if len(self.jobs_ids) == 1:
- self.jobs_ids = None
- raise RuntimeError(
- f"The only OpenAI API Batch that was created with ID '{batch_id}'"
- f" failed with status '{batch.status}'."
- )
-
- continue
-
- outputs.extend(self._retrieve_batch_results(batch))
-
- # sort by `custom_id` to return the results in the same order as the inputs
- outputs = sorted(outputs, key=lambda x: int(x["custom_id"]))
- return [self._parse_output(output) for output in outputs]
+ raise ValueError("No `inputs` were provided and no `jobs_ids` were found.")
+
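
A usage sketch of the two-phase flow this method implements (model and inputs are placeholders; the exception import path and its `jobs_ids` attribute are assumed from how it is raised above):

    from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
    from distilabel.models.llms import OpenAILLM

    llm = OpenAILLM(model="gpt-4o-mini")  # placeholder model name
    llm.load()

    inputs = [[{"role": "user", "content": "Say hi."}]]
    try:
        # First call: uploads the batch files, creates the jobs and raises.
        outputs = llm.offline_batch_generate(inputs=inputs)
    except DistilabelOfflineBatchGenerationNotFinishedException as e:
        # `jobs_ids` is now stored on the LLM; calling the method again later
        # (without `inputs`) polls the Batch API and returns the results.
        print(f"Batch jobs still running: {e.jobs_ids}")
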
+ def _check_and_get_batch_results(self) -> List["GenerateOutput"]:
+ """Checks the status of the batch jobs and retrieves the results from the OpenAI
+ Batch API.
+
+ Returns:
+ A list of lists of strings containing the generated responses for each input.
+
+ Raises:
+ ValueError: if no job IDs were found to retrieve the results from.
+ DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
+ is not finished yet.
+ RuntimeError: if the only batch job found failed.
+ """
+ if not self.jobs_ids:
+ raise ValueError("No job IDs were found to retrieve the results from.")
+
+ outputs = []
+ for batch_id in self.jobs_ids:
+ batch = self._get_openai_batch(batch_id)
+
+ if batch.status in ("validating", "in_progress", "finalizing"):
+ raise DistilabelOfflineBatchGenerationNotFinishedException(
+ jobs_ids=self.jobs_ids
+ )
+
+ if batch.status in ("failed", "expired", "cancelled", "cancelling"):
+ self._logger.error( # type: ignore
+ f"OpenAI API batch with ID '{batch_id}' failed with status '{batch.status}'."
+ )
+ if len(self.jobs_ids) == 1:
+ self.jobs_ids = None
+ raise RuntimeError(
+ f"The only OpenAI API Batch that was created with ID '{batch_id}'"
+ f" failed with status '{batch.status}'."
+ )
+
+ continue
- def _parse_output(self, output: Dict[str, Any]) -> "GenerateOutput":
- """Parses the output from the OpenAI Batch API into a list of strings.
-
- Args:
- output: the output to parse.
+ outputs.extend(self._retrieve_batch_results(batch))
+
+ # sort by `custom_id` to return the results in the same order as the inputs
+ outputs = sorted(outputs, key=lambda x: int(x["custom_id"]))
+ return [self._parse_output(output) for output in outputs]
- Returns:
- A list of strings containing the generated responses for the input.
- """
- from openai.types.chat import ChatCompletion as OpenAIChatCompletion
-
- if "response" not in output:
- return []
-
- if output["response"]["status_code"] != 200:
- return []
+ def _parse_output(self, output: Dict[str, Any]) -> "GenerateOutput":
+ """Parses the output from the OpenAI Batch API into a list of strings.
+
+ Args:
+ output: the output to parse.
+
+ Returns:
+ A list of strings containing the generated responses for the input.
+ """
+ from openai.types.chat import ChatCompletion as OpenAIChatCompletion
- return self._generations_from_openai_completion(
- OpenAIChatCompletion(**output["response"]["body"])
- )
-
- def _get_openai_batch(self, batch_id: str) -> "OpenAIBatch":
- """Gets a batch from the OpenAI Batch API.
-
- Args:
- batch_id: the ID of the batch to retrieve.
+ if "response" not in output:
+ return []
+
+ if output["response"]["status_code"] != 200:
+ return []
+
+ return self._generations_from_openai_completion(
+ OpenAIChatCompletion(**output["response"]["body"])
+ )
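
For reference, a hypothetical parsed line from the batch output file, shaped to match the checks above and OpenAI's documented batch output format (all IDs and values are made up):

    output = {
        "id": "batch_req_123",
        "custom_id": "0",
        "response": {
            "status_code": 200,
            "body": {  # a regular ChatCompletion payload
                "id": "chatcmpl-abc",
                "object": "chat.completion",
                "created": 1700000000,
                "model": "gpt-4o-mini",
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": "stop",
                        "message": {"role": "assistant", "content": "Hi!"},
                    }
                ],
                "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
            },
        },
    }
    # `_parse_output(output)` would build an `OpenAIChatCompletion` from
    # `output["response"]["body"]` and extract the generations from it.
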
- Returns:
- The batch retrieved from the OpenAI Batch API.
+ def _get_openai_batch(self, batch_id: str) -> "OpenAIBatch":
+ """Gets a batch from the OpenAI Batch API.
- Raises:
- openai.OpenAIError: if there was an error while retrieving the batch from the
- OpenAI Batch API.
- """
- import openai
+ Args:
+ batch_id: the ID of the batch to retrieve.
+
+ Returns:
+ The batch retrieved from the OpenAI Batch API.
- try:
- return self._client.batches.retrieve(batch_id)
- except openai.OpenAIError as e:
- self._logger.error( # type: ignore
- f"Error while retrieving batch '{batch_id}' from OpenAI: {e}"
- )
- raise e
-
- def _retrieve_batch_results(self, batch: "OpenAIBatch") -> List[Dict[str, Any]]:
- """Retrieves the results of a batch from its output file, parsing the JSONL content
- into a list of dictionaries.
-
- Args:
- batch: the batch to retrieve the results from.
-
- Returns:
- A list of dictionaries containing the results of the batch.
+ Raises:
+ openai.OpenAIError: if there was an error while retrieving the batch from the
+ OpenAI Batch API.
+ """
+ import openai
+
+ try:
+ return self._client.batches.retrieve(batch_id)
+ except openai.OpenAIError as e:
+ self._logger.error( # type: ignore
+ f"Error while retrieving batch '{batch_id}' from OpenAI: {e}"
+ )
+ raise e
+
+ def _retrieve_batch_results(self, batch: "OpenAIBatch") -> List[Dict[str, Any]]:
+ """Retrieves the results of a batch from its output file, parsing the JSONL content
+ into a list of dictionaries.
- Raises:
- AssertionError: if no output file ID was found in the batch.
- """
- import openai
-
- assert batch.output_file_id, "No output file ID was found in the batch."
-
- try:
- file_response = self._client.files.content(batch.output_file_id)
- return [orjson.loads(line) for line in file_response.text.splitlines()]
- except openai.OpenAIError as e:
- self._logger.error( # type: ignore
- f"Error while retrieving batch results from file '{batch.output_file_id}': {e}"
- )
- return []
-
- def _create_jobs(
- self, inputs: List["FormattedInput"], **kwargs: Any
- ) -> Tuple[str, ...]:
- """Creates jobs in the OpenAI Batch API to generate responses for the given inputs.
-
- Args:
- inputs: a list of inputs in chat format to generate responses for.
- kwargs: the keyword arguments to use for the generation.
-
- Returns:
- A list of job IDs created in the OpenAI Batch API.
- """
- batch_input_files = self._create_batch_files(inputs=inputs, **kwargs)
- jobs = []
- for batch_input_file in batch_input_files:
- if batch := self._create_batch_api_job(batch_input_file):
- jobs.append(batch.id)
- return tuple(jobs)
-
- def _create_batch_api_job(
- self, batch_input_file: "OpenAIFileObject"
- ) -> Union["OpenAIBatch", None]:
- """Creates a job in the OpenAI Batch API to generate responses for the given input
- file.
+ Args:
+ batch: the batch to retrieve the results from.
+
+ Returns:
+ A list of dictionaries containing the results of the batch.
+
+ Raises:
+ AssertionError: if no output file ID was found in the batch.
+ """
+ import openai
+
+ assert batch.output_file_id, "No output file ID was found in the batch."
+
+ try:
+ file_response = self._client.files.content(batch.output_file_id)
+ return [orjson.loads(line) for line in file_response.text.splitlines()]
+ except openai.OpenAIError as e:
+ self._logger.error( # type: ignore
+ f"Error while retrieving batch results from file '{batch.output_file_id}': {e}"
+ )
+ return []
+
+ def _create_jobs(
+ self, inputs: List["FormattedInput"], **kwargs: Any
+ ) -> Tuple[str, ...]:
+ """Creates jobs in the OpenAI Batch API to generate responses for the given inputs.
+
+ Args:
+ inputs: a list of inputs in chat format to generate responses for.
+ kwargs: the keyword arguments to use for the generation.
+
+ Returns:
+ A list of job IDs created in the OpenAI Batch API.
+ """
+ batch_input_files = self._create_batch_files(inputs=inputs, **kwargs)
+ jobs = []
+ for batch_input_file in batch_input_files:
+ if batch := self._create_batch_api_job(batch_input_file):
+ jobs.append(batch.id)
+ return tuple(jobs)
- Args:
- batch_input_file: the input file to generate responses for.
-
- Returns:
- The batch job created in the OpenAI Batch API.
- """
- import openai
-
- metadata = {"description": "distilabel"}
-
- if distilabel_pipeline_name := envs.DISTILABEL_PIPELINE_NAME:
- metadata["distilabel_pipeline_name"] = distilabel_pipeline_name
-
- if distilabel_pipeline_cache_id := envs.DISTILABEL_PIPELINE_CACHE_ID:
- metadata["distilabel_pipeline_cache_id"] = distilabel_pipeline_cache_id
+ def _create_batch_api_job(
+ self, batch_input_file: "OpenAIFileObject"
+ ) -> Union["OpenAIBatch", None]:
+ """Creates a job in the OpenAI Batch API to generate responses for the given input
+ file.
+
+ Args:
+ batch_input_file: the input file to generate responses for.
+
+ Returns:
+ The batch job created in the OpenAI Batch API.
+ """
+ import openai
+
+ metadata = {"description": "distilabel"}
- batch = None
- try:
- batch = self._client.batches.create(
- completion_window="24h",
- endpoint="/v1/chat/completions",
- input_file_id=batch_input_file.id,
- metadata=metadata,
- )
- except openai.OpenAIError as e:
- self._logger.error( # type: ignore
- f"Error while creating OpenAI Batch API job for file with ID"
- f" '{batch_input_file.id}': {e}."
- )
- raise e
- return batch
-
- def _create_batch_files(
- self, inputs: List["FormattedInput"], **kwargs: Any
- ) -> List["OpenAIFileObject"]:
- """Creates the necessary input files for the batch API to generate responses. The
- maximum size of each file so the OpenAI Batch API can process it is 100MB, so we
- need to split the inputs into multiple files if necessary.
-
- More information: https://platform.openai.com/docs/api-reference/files/create
-
- Args:
- inputs: a list of inputs in chat format to generate responses for, optionally
- including structured output.
- kwargs: the keyword arguments to use for the generation.
-
- Returns:
- The list of file objects created for the OpenAI Batch API.
-
- Raises:
- openai.OpenAIError: if there was an error while creating the batch input file
- in the OpenAI Batch API.
- """
- import openai
+ if distilabel_pipeline_name := envs.DISTILABEL_PIPELINE_NAME:
+ metadata["distilabel_pipeline_name"] = distilabel_pipeline_name
+
+ if distilabel_pipeline_cache_id := envs.DISTILABEL_PIPELINE_CACHE_ID:
+ metadata["distilabel_pipeline_cache_id"] = distilabel_pipeline_cache_id
+
+ batch = None
+ try:
+ batch = self._client.batches.create(
+ completion_window="24h",
+ endpoint="/v1/chat/completions",
+ input_file_id=batch_input_file.id,
+ metadata=metadata,
+ )
+ except openai.OpenAIError as e:
+ self._logger.error( # type: ignore
+ f"Error while creating OpenAI Batch API job for file with ID"
+ f" '{batch_input_file.id}': {e}."
+ )
+ raise e
+ return batch
+
+ def _create_batch_files(
+ self, inputs: List["FormattedInput"], **kwargs: Any
+ ) -> List["OpenAIFileObject"]:
+ """Creates the necessary input files for the batch API to generate responses. The
+        maximum file size the OpenAI Batch API can process is 100MB, so the inputs
+        are split into multiple files if necessary.
+
+ More information: https://platform.openai.com/docs/api-reference/files/create
+
+ Args:
+ inputs: a list of inputs in chat format to generate responses for, optionally
+ including structured output.
+ kwargs: the keyword arguments to use for the generation.
+
+ Returns:
+ The list of file objects created for the OpenAI Batch API.
- files = []
- for file_no, buffer in enumerate(
- self._create_jsonl_buffers(inputs=inputs, **kwargs)
- ):
- try:
- # TODO: add distilabel pipeline name and id
- batch_input_file = self._client.files.create(
- file=(self._name_for_openai_files(file_no), buffer),
- purpose="batch",
- )
- files.append(batch_input_file)
- except openai.OpenAIError as e:
- self._logger.error( # type: ignore
- f"Error while creating OpenAI batch input file: {e}"
- )
- raise e
- return files
-
- def _create_jsonl_buffers(
- self, inputs: List["FormattedInput"], **kwargs: Any
- ) -> Generator[io.BytesIO, None, None]:
- """Creates a generator of buffers containing the JSONL formatted inputs to be
- used by the OpenAI Batch API. The buffers created are of size 100MB or less.
+ Raises:
+ openai.OpenAIError: if there was an error while creating the batch input file
+ in the OpenAI Batch API.
+ """
+ import openai
+
+ files = []
+ for file_no, buffer in enumerate(
+ self._create_jsonl_buffers(inputs=inputs, **kwargs)
+ ):
+ try:
+ # TODO: add distilabel pipeline name and id
+ batch_input_file = self._client.files.create(
+ file=(self._name_for_openai_files(file_no), buffer),
+ purpose="batch",
+ )
+ files.append(batch_input_file)
+ except openai.OpenAIError as e:
+ self._logger.error( # type: ignore
+ f"Error while creating OpenAI batch input file: {e}"
+ )
+ raise e
+ return files
- Args:
- inputs: a list of inputs in chat format to generate responses for, optionally
- including structured output.
- kwargs: the keyword arguments to use for the generation.
-
- Yields:
- A buffer containing the JSONL formatted inputs to be used by the OpenAI Batch
- API.
- """
- buffer = io.BytesIO()
- buffer_current_size = 0
- for i, input in enumerate(inputs):
- # We create the smallest `custom_id` so we don't increase the size of the file
- # to much, but we can still sort the results with the order of the inputs.
- row = self._create_jsonl_row(input=input, custom_id=str(i), **kwargs)
- row_size = len(row)
- if row_size + buffer_current_size > _OPENAI_BATCH_API_MAX_FILE_SIZE:
- buffer.seek(0)
- yield buffer
- buffer = io.BytesIO()
- buffer_current_size = 0
- buffer.write(row)
- buffer_current_size += row_size
-
- if buffer_current_size > 0:
- buffer.seek(0)
- yield buffer
-
- def _create_jsonl_row(
- self, input: "FormattedInput", custom_id: str, **kwargs: Any
- ) -> bytes:
- """Creates a JSONL formatted row to be used by the OpenAI Batch API.
-
- Args:
- input: a list of inputs in chat format to generate responses for, optionally
- including structured output.
- custom_id: a custom ID to use for the row.
- kwargs: the keyword arguments to use for the generation.
+ def _create_jsonl_buffers(
+ self, inputs: List["FormattedInput"], **kwargs: Any
+ ) -> Generator[io.BytesIO, None, None]:
+ """Creates a generator of buffers containing the JSONL formatted inputs to be
+ used by the OpenAI Batch API. The buffers created are of size 100MB or less.
+
+ Args:
+ inputs: a list of inputs in chat format to generate responses for, optionally
+ including structured output.
+ kwargs: the keyword arguments to use for the generation.
+
+ Yields:
+ A buffer containing the JSONL formatted inputs to be used by the OpenAI Batch
+ API.
+ """
+ buffer = io.BytesIO()
+ buffer_current_size = 0
+ for i, input in enumerate(inputs):
+            # Use the smallest possible `custom_id` so the file doesn't grow too much,
+            # while still allowing the results to be sorted in input order.
+ row = self._create_jsonl_row(input=input, custom_id=str(i), **kwargs)
+ row_size = len(row)
+ if row_size + buffer_current_size > _OPENAI_BATCH_API_MAX_FILE_SIZE:
+ buffer.seek(0)
+ yield buffer
+ buffer = io.BytesIO()
+ buffer_current_size = 0
+ buffer.write(row)
+ buffer_current_size += row_size
+
+ if buffer_current_size > 0:
+ buffer.seek(0)
+ yield buffer
+
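
The same size-capped buffering idea in a self-contained sketch, with a tiny limit so the split is visible (the real limit is `_OPENAI_BATCH_API_MAX_FILE_SIZE`, i.e. 100MB):

    import io
    from typing import Generator, List


    def split_rows(
        rows: List[bytes], max_size: int = 64
    ) -> Generator[io.BytesIO, None, None]:
        """Yield rewound buffers whose content stays under `max_size` bytes."""
        buffer, size = io.BytesIO(), 0
        for row in rows:
            if size + len(row) > max_size and size > 0:
                buffer.seek(0)
                yield buffer
                buffer, size = io.BytesIO(), 0
            buffer.write(row)
            size += len(row)
        if size > 0:
            buffer.seek(0)
            yield buffer


    rows = [b'{"custom_id": "%d"}\n' % i for i in range(10)]
    buffers = list(split_rows(rows))  # several buffers, each under 64 bytes
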
+ def _create_jsonl_row(
+ self, input: "FormattedInput", custom_id: str, **kwargs: Any
+ ) -> bytes:
+ """Creates a JSONL formatted row to be used by the OpenAI Batch API.
- Returns:
- A JSONL formatted row to be used by the OpenAI Batch API.
- """
- # TODO: depending on the format of the input, add `response_format` to the kwargs
- row = {
- "custom_id": custom_id,
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {"messages": input, **kwargs},
- }
- json_row = orjson.dumps(row)
- return json_row + b"\n"
-
- def _name_for_openai_files(self, file_no: int) -> str:
- if (
- envs.DISTILABEL_PIPELINE_NAME is None
- or envs.DISTILABEL_PIPELINE_CACHE_ID is None
- ):
- return f"distilabel-pipeline-fileno-{file_no}.jsonl"
-
- return f"distilabel-pipeline-{envs.DISTILABEL_PIPELINE_NAME}-{envs.DISTILABEL_PIPELINE_CACHE_ID}-fileno-{file_no}.jsonl"
-
- @staticmethod
- def _get_llm_statistics(completion: "OpenAIChatCompletion") -> "LLMStatistics":
- return {
- "input_tokens": [completion.usage.prompt_tokens if completion else 0],
- "output_tokens": [completion.usage.completion_tokens if completion else 0],
- }
+ Args:
+            input: an input in chat format to generate a response for, optionally
+                including structured output.
+ custom_id: a custom ID to use for the row.
+ kwargs: the keyword arguments to use for the generation.
+
+ Returns:
+ A JSONL formatted row to be used by the OpenAI Batch API.
+ """
+ # TODO: depending on the format of the input, add `response_format` to the kwargs
+ row = {
+ "custom_id": custom_id,
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {"messages": input, **kwargs},
+ }
+ json_row = orjson.dumps(row)
+ return json_row + b"\n"
+
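
For illustration, one serialized line of the batch input file would look like this (the body kwargs are placeholders forwarded from `offline_batch_generate`):

    {"custom_id": "0", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "user", "content": "Say hi."}], "model": "gpt-4o-mini", "max_tokens": 128, "n": 1}}
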
+ def _name_for_openai_files(self, file_no: int) -> str:
+ if (
+ envs.DISTILABEL_PIPELINE_NAME is None
+ or envs.DISTILABEL_PIPELINE_CACHE_ID is None
+ ):
+ return f"distilabel-pipeline-fileno-{file_no}.jsonl"
+
+ return f"distilabel-pipeline-{envs.DISTILABEL_PIPELINE_NAME}-{envs.DISTILABEL_PIPELINE_CACHE_ID}-fileno-{file_no}.jsonl"
+
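
For example, with the pipeline environment variables unset this yields names like `distilabel-pipeline-fileno-0.jsonl`; with them set, something like `distilabel-pipeline-<name>-<cache-id>-fileno-0.jsonl` (name and cache ID are placeholders).
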
+ @staticmethod
+ def _get_llm_statistics(completion: "OpenAIChatCompletion") -> "LLMStatistics":
+ return {
+ "input_tokens": [completion.usage.prompt_tokens if completion else 0],
+ "output_tokens": [completion.usage.completion_tokens if completion else 0],
+ }
@validate_call
async def agenerate( # type: ignore
self,
input: FormattedInput,
@@ -23508,21 +23568,27 @@
"top_p": top_p,
"stop": stop,
}
-
- if response_format is not None:
- kwargs["response_format"] = response_format
-
- if structured_output:
- kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore
+        # Check if it's a vision generation task; in that case the "stop" parameter
+        # cannot be used, as it makes the API raise an error.
+ if isinstance(
+ [row for row in input if row["role"] == "user"][0]["content"], list
+ ):
+ kwargs.pop("stop")
- completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore
- if structured_output:
- return prepare_output(
- [completion.model_dump_json()],
- **self._get_llm_statistics(completion._raw_response),
- )
-
- return self._generations_from_openai_completion(completion)
+ if response_format is not None:
+ kwargs["response_format"] = response_format
+
+ if structured_output:
+ kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore
+
+ completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore
+ if structured_output:
+ return prepare_output(
+ [completion.model_dump_json()],
+ **self._get_llm_statistics(completion._raw_response),
+ )
+
+ return self._generations_from_openai_completion(completion)
src/distilabel/models/llms/openai.py