Create columns with LLM returned extra keys (#1078)
gabrielmbmb authored Dec 12, 2024
1 parent 63c75c5 commit a8588fd
Showing 14 changed files with 606 additions and 261 deletions.
@@ -21,7 +21,7 @@ The [`LLM`][distilabel.models.llms.LLM] has an argument named `structured_output`
We will start with a JSON example, where we initially define a `pydantic.BaseModel` schema to guide the generation of the structured output.

!!! NOTE
Take a look at [`StructuredOutputType`][distilabel.steps.tasks.structured_outputs.outlines.StructuredOutputType] to see the expected format
Take a look at [`StructuredOutputType`][distilabel.steps.tasks.typing.StructuredOutputType] to see the expected format
of the `structured_output` dict variable.

```python
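# Illustrative sketch only (the docs' own example is collapsed in this diff view):
# the import path, model id, schema fields and the `{"format": ..., "schema": ...}`
# layout below are assumptions based on `StructuredOutputType`, not the original
# collapsed example. Requires distilabel installed and an `HF_TOKEN` in the env.
from pydantic import BaseModel

from distilabel.models.llms import InferenceEndpointsLLM  # import path assumed


class User(BaseModel):
    name: str
    last_name: str
    id: int


llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",  # any served model works
    # Guide the generation so the output is a JSON object matching `User`.
    structured_output={"format": "json", "schema": User},
)
llm.load()

output = llm.generate(
    inputs=[[{"role": "user", "content": "Create a user profile for John Doe."}]]
)
print(output[0]["generations"][0])  # a JSON string that validates against `User`
```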
13 changes: 12 additions & 1 deletion src/distilabel/distiset.py
@@ -184,9 +184,20 @@ def _get_card(
"""
sample_records = {}
for name, dataset in self.items():
sample_records[name] = (
record = (
dataset[0] if not isinstance(dataset, dict) else dataset["train"][0]
)
for key, value in record.items():
# If list is too big, the `README.md` generated will be huge so we truncate it
if isinstance(value, list):
length = len(value)
if length < 10:
continue
record[key] = value[:10]
record[key].append(
f"... (truncated - showing 10 of {length} elements)"
)
sample_records[name] = record

readme_metadata = {}
if repo_id and token:
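As a standalone illustration of the truncation above (plain Python on a toy record, no distilabel imports): lists with 10 or more elements are cut to their first 10 and a marker is appended so the generated `README.md` card stays small.

```python
# Mimics the sample-record truncation performed when building the dataset card.
record = {"instruction": "Count to 25.", "numbers": list(range(25))}

for key, value in record.items():
    if isinstance(value, list):
        length = len(value)
        if length < 10:
            continue
        record[key] = value[:10]
        record[key].append(f"... (truncated - showing 10 of {length} elements)")

print(record["numbers"])
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, '... (truncated - showing 10 of 25 elements)']
```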
36 changes: 22 additions & 14 deletions src/distilabel/models/llms/base.py
@@ -334,7 +334,7 @@ def get_last_hidden_states(
)

def _prepare_structured_output(
self, structured_output: Optional["StructuredOutputType"] = None
self, structured_output: "StructuredOutputType"
) -> Union[Any, None]:
"""Method in charge of preparing the structured output generator.
@@ -431,7 +431,7 @@ def event_loop(self) -> "asyncio.AbstractEventLoop":
@abstractmethod
async def agenerate(
self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any
) -> List[Union[str, None]]:
) -> "GenerateOutput":
"""Method to generate a `num_generations` responses for a given input asynchronously,
and executed concurrently in `generate` method.
"""
@@ -591,8 +591,8 @@ def _prepare_kwargs(


def merge_responses(
responses: List[Dict[str, Any]], n: int = 1
) -> List[Dict[str, Any]]:
responses: List["GenerateOutput"], n: int = 1
) -> List["GenerateOutput"]:
"""Helper function to group the responses from `LLM.agenerate` method according
to the number of generations requested.
@@ -612,19 +612,27 @@ def chunks(lst, n):
for i in range(0, len(lst), n):
yield list(islice(lst, i, i + n))

# Split responses into groups of size n
grouped_responses = list(chunks(responses, n))
extra_keys = [
key for key in responses[0].keys() if key not in ("generations", "statistics")
]

result = []
for group in grouped_responses:
first = group[0]
for group in chunks(responses, n):
merged = {
"generations": sum((r["generations"] for r in group), []),
"statistics": {
key: sum((r["statistics"][key] for r in group), [])
for key in first["statistics"]
},
"generations": [],
"statistics": {"input_tokens": [], "output_tokens": []},
}
for response in group:
merged["generations"].append(response["generations"][0])
# Merge statistics
for key in response["statistics"]:
if key not in merged["statistics"]:
merged["statistics"][key] = []
merged["statistics"][key].append(response["statistics"][key][0])
# Merge extra keys returned by the `LLM`
for extra_key in extra_keys:
if extra_key not in merged:
merged[extra_key] = []
merged[extra_key].append(response[extra_key][0])
result.append(merged)

return result
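A quick sketch of what the reworked `merge_responses` returns, using toy dicts that mimic the `GenerateOutput` shape; it assumes a distilabel install that includes this change and that the function keeps the module path shown in this diff:

```python
from distilabel.models.llms.base import merge_responses

# Two single-generation responses; besides `generations` and `statistics`, the
# `LLM` may now return extra keys (here `logprobs`) that are merged as well.
responses = [
    {
        "generations": ["Paris"],
        "statistics": {"input_tokens": [12], "output_tokens": [3]},
        "logprobs": [[[{"token": "Paris", "logprob": -0.1}]]],
    },
    {
        "generations": ["London"],
        "statistics": {"input_tokens": [12], "output_tokens": [4]},
        "logprobs": [[[{"token": "London", "logprob": -0.4}]]],
    },
]

# Group both responses into a single merged response (`n=2`).
print(merge_responses(responses, n=2))
# [{'generations': ['Paris', 'London'],
#   'statistics': {'input_tokens': [12, 12], 'output_tokens': [3, 4]},
#   'logprobs': [[[{'token': 'Paris', 'logprob': -0.1}]],
#                [[{'token': 'London', 'logprob': -0.4}]]]}]
```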
119 changes: 94 additions & 25 deletions src/distilabel/models/llms/huggingface/inference_endpoints.py
@@ -16,10 +16,20 @@
import random
import sys
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
)

from pydantic import (
Field,
PositiveInt,
PrivateAttr,
SecretStr,
ValidationError,
@@ -31,7 +41,7 @@

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.typing import GenerateOutput
from distilabel.models.llms.typing import GenerateOutput, Logprob
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.typing import (
@@ -45,12 +55,15 @@
from huggingface_hub import AsyncInferenceClient
from huggingface_hub.inference._generated.types.chat_completion import (
ChatCompletionOutput,
ChatCompletionOutputComplete,
)
from huggingface_hub.inference._generated.types.text_generation import (
TextGenerationOutput,
)
from transformers import PreTrainedTokenizer

from distilabel.models.llms.typing import Logprob


class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin):
"""InferenceEndpoints LLM implementation running the async API client.
@@ -338,15 +351,15 @@ def prepare_input(self, input: "StandardInput") -> str:

def _get_structured_output(
self, input: FormattedInput
) -> Union[Dict[str, Any], None]:
) -> Tuple["StandardInput", Union[Dict[str, Any], None]]:
"""Gets the structured output (if any) for the given input.
Args:
input: a single input in chat format to generate responses for.
Returns:
The structured output that will be passed as `grammer` to the inference endpoint
or `None` if not required.
The input and the structured output that will be passed as `grammar` to the
inference endpoint or `None` if not required.
"""
structured_output = None

@@ -377,7 +390,7 @@ def _get_structured_output(
"value"
].model_json_schema()

return structured_output
return input, structured_output

async def _generate_with_text_generation(
self,
@@ -387,26 +400,28 @@ async def _generate_with_text_generation(
frequency_penalty: Optional[float] = None,
temperature: float = 1.0,
do_sample: bool = False,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
typical_p: Optional[float] = None,
stop_sequences: Union[List[str], None] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
watermark: bool = False,
) -> GenerateOutput:
structured_output = self._get_structured_output(input)

completion = None
input, structured_output = self._get_structured_output(input)
prompt = self.prepare_input(input)
generation: Union["TextGenerationOutput", None] = None
try:
completion: "TextGenerationOutput" = await self._aclient.text_generation( # type: ignore
prompt=self.prepare_input(input), # type: ignore
generation = await self._aclient.text_generation( # type: ignore
prompt=prompt,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
typical_p=typical_p,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
temperature=temperature,
top_n_tokens=top_n_tokens,
top_p=top_p,
top_k=top_k,
stop_sequences=stop_sequences,
@@ -423,41 +438,62 @@ async def _generate_with_text_generation(
f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
f" Finish reason was: {e}"
)

return prepare_output(
[completion.generated_text],
input_tokens=[
compute_tokens(self.prepare_input(input), self._tokenizer.encode)
if self._tokenizer
else 0
],
generations=[generation.generated_text] if generation else [None],
input_tokens=[compute_tokens(prompt, self._tokenizer.encode)], # type: ignore
output_tokens=[
completion.details.generated_tokens if completion.details else 0
generation.details.generated_tokens
if generation and generation.details
else 0
],
logprobs=self._get_logprobs_from_text_generation(generation)
if generation
else None, # type: ignore
)

def _get_logprobs_from_text_generation(
self, generation: "TextGenerationOutput"
) -> Union[List[List[List["Logprob"]]], None]:
if generation.details is None or generation.details.top_tokens is None:
return None

return [
[
[
{"token": top_logprob["text"], "logprob": top_logprob["logprob"]}
for top_logprob in token_logprobs
]
for token_logprobs in generation.details.top_tokens
]
]

async def _generate_with_chat_completion(
self,
input: "StandardInput",
max_new_tokens: int = 128,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
logprobs: bool = False,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: float = 1.0,
tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
tool_prompt: Optional[str] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[PositiveInt] = None,
top_p: Optional[float] = None,
) -> GenerateOutput:
message = None
completion: Union["ChatCompletionOutput", None] = None
output_logprobs = None
try:
completion: "ChatCompletionOutput" = await self._aclient.chat_completion( # type: ignore
completion = await self._aclient.chat_completion( # type: ignore
messages=input, # type: ignore
max_tokens=max_new_tokens,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
presence_penalty=presence_penalty,
# NOTE: here to ensure that the cache is not used and a different response is
# generated every time
@@ -467,25 +503,43 @@ async def _generate_with_chat_completion(
tool_choice=tool_choice, # type: ignore
tool_prompt=tool_prompt,
tools=tools, # type: ignore
top_logprobs=top_logprobs,
top_p=top_p,
)
choice = completion.choices[0]
choice = completion.choices[0] # type: ignore
if (message := choice.message.content) is None:
self._logger.warning( # type: ignore
f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
f" Finish reason was: {choice.finish_reason}"
)
if choice_logprobs := self._get_logprobs_from_choice(choice):
output_logprobs = [choice_logprobs]
except Exception as e:
self._logger.warning( # type: ignore
f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
f" Finish reason was: {e}"
)
return prepare_output(
[message],
input_tokens=[completion.usage.prompt_tokens],
output_tokens=[completion.usage.completion_tokens],
generations=[message],
input_tokens=[completion.usage.prompt_tokens] if completion else None,
output_tokens=[completion.usage.completion_tokens] if completion else None,
logprobs=output_logprobs,
)

def _get_logprobs_from_choice(
self, choice: "ChatCompletionOutputComplete"
) -> Union[List[List["Logprob"]], None]:
if choice.logprobs is None:
return None

return [
[
{"token": top_logprob.token, "logprob": top_logprob.logprob}
for top_logprob in token_logprobs.top_logprobs
]
for token_logprobs in choice.logprobs.content
]

def _check_stop_sequences(
self,
stop_sequences: Optional[Union[str, List[str]]] = None,
@@ -517,13 +571,16 @@ async def agenerate( # type: ignore
max_new_tokens: int = 128,
frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
logit_bias: Optional[List[float]] = None,
logprobs: bool = False,
presence_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: float = 1.0,
tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
tool_prompt: Optional[str] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[PositiveInt] = None,
top_n_tokens: Optional[PositiveInt] = None,
top_p: Optional[float] = None,
do_sample: bool = False,
repetition_penalty: Optional[float] = None,
Expand All @@ -549,6 +606,9 @@ async def agenerate( # type: ignore
This argument is exclusive to the `chat_completion` method and will be used
only if `tokenizer_id` is `None`.
Defaults to `None`.
logprobs: whether to return the log probabilities or not. This argument is exclusive
to the `chat_completion` method and will be used only if `tokenizer_id`
is `None`. Defaults to `False`.
presence_penalty: a value between `-2.0` and `2.0`. Positive values penalize
new tokens based on whether they appear in the text so far, increasing the
model likelihood to talk about new topics. This argument is exclusive to
@@ -569,6 +629,12 @@ async def agenerate( # type: ignore
tools: a list of tools definitions that the LLM can use.
This argument is exclusive to the `chat_completion` method and will be used
only if `tokenizer_id` is `None`. Defaults to `None`.
top_logprobs: the number of top log probabilities to return per output token
generated. This argument is exclusive to the `chat_completion` method and
will be used only if `tokenizer_id` is `None`. Defaults to `None`.
top_n_tokens: the number of top log probabilities to return per output token
generated. This argument is exclusive of the `text_generation` method and
will be only used if `tokenizer_id` is not `None`. Defaults to `None`.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
do_sample: whether to use sampling for the generation. This argument is exclusive
of the `text_generation` method and will be only used if `tokenizer_id` is not
@@ -602,13 +668,15 @@ async def agenerate( # type: ignore
max_new_tokens=max_new_tokens,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
presence_penalty=presence_penalty,
seed=seed,
stop_sequences=stop_sequences,
temperature=temperature,
tool_choice=tool_choice,
tool_prompt=tool_prompt,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
)

@@ -620,6 +688,7 @@ async def agenerate( # type: ignore
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
temperature=temperature,
top_n_tokens=top_n_tokens,
top_p=top_p,
top_k=top_k,
stop_sequences=stop_sequences,
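A rough usage sketch of the new log-probability options on `InferenceEndpointsLLM` — assuming an `HF_TOKEN` in the environment, that this import path is current, and that keyword arguments passed to `generate` are forwarded to `agenerate` as in previous releases:

```python
from distilabel.models.llms import InferenceEndpointsLLM  # import path assumed

# Without `tokenizer_id` the client uses `chat_completion`, so the new
# `logprobs` / `top_logprobs` arguments apply; with a `tokenizer_id` set,
# `text_generation` is used instead and `top_n_tokens` is the relevant knob.
llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",  # reads HF_TOKEN from the env
)
llm.load()

outputs = llm.generate(
    inputs=[[{"role": "user", "content": "What is the capital of France?"}]],
    logprobs=True,
    top_logprobs=2,
)

output = outputs[0]
print(output["generations"][0])
# When log probabilities were requested, the output also carries a `logprobs`
# key shaped [generation][token][top-k], as built by `_get_logprobs_from_choice`.
print(output["logprobs"][0][0])  # top-2 candidates for the first generated token
```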
