Skip to content

Commit

Permalink
Fix parsing LLM generation kwargs (#537)
Browse files Browse the repository at this point in the history
* Add `validate_call` to coerce `generation_kwargs`

* Fix `validate_call` needs imports

* Fix `Ollama.generate` unit test

* Include note from where it was copied

* Update docs/sections/learn/llms/index.md

Co-authored-by: Alvaro Bartolome <[email protected]>

* Add note about `validate_call`

---------

Co-authored-by: Alvaro Bartolome <[email protected]>
  • Loading branch information
gabrielmbmb and alvarobartt authored Apr 16, 2024
1 parent be11145 commit ef254e4
Show file tree
Hide file tree
Showing 14 changed files with 140 additions and 80 deletions.
11 changes: 10 additions & 1 deletion docs/sections/learn/llms/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ Once those methods have been implemented, then the custom LLM will be ready to b
```python
from typing import Any

from pydantic import validate_call

from distilabel.llms import AsyncLLM, LLM
from distilabel.llms.typing import GenerateOutput, HiddenState
from distilabel.steps.tasks.typing import ChatType
Expand All @@ -107,7 +109,8 @@ class CustomLLM(LLM):
def model_name(self) -> str:
return "my-model"

def generate(self, inputs: List[ChatType], num_generations: int = 1) -> List[GenerateOutput]:
@validate_call
def generate(self, inputs: List[ChatType], num_generations: int = 1, **kwargs: Any) -> List[GenerateOutput]:
for _ in range(num_generations):
...

Expand All @@ -120,6 +123,7 @@ class CustomAsyncLLM(AsyncLLM):
def model_name(self) -> str:
return "my-model"

@validate_call
async def agenerate(self, input: ChatType, num_generations: int = 1, **kwargs: Any) -> GenerateOutput:
for _ in range(num_generations):
...
Expand All @@ -128,6 +132,11 @@ class CustomAsyncLLM(AsyncLLM):
...
```

The keyword arguments of `generate` and `agenerate` (except `inputs`/`input` and `num_generations`) are treated as `RuntimeParameter`s, so values can be passed to them via the `parameters` argument of the `Pipeline.run` method.

!!! NOTE
The `validate_call` decorator is applied to `generate` and `agenerate` so that their arguments are automatically coerced to the annotated types, and an error is raised if a value cannot be coerced. This is especially useful when a value for one of these arguments is provided from the CLI, since the CLI always passes arguments as strings.

## Available LLMs

Here's a list with the available LLMs that can be used within the `distilabel` library:
Expand Down
15 changes: 8 additions & 7 deletions src/distilabel/llms/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@
)

from httpx import AsyncClient
from pydantic import Field, PrivateAttr, SecretStr
from pydantic import Field, PrivateAttr, SecretStr, validate_call
from typing_extensions import override

from distilabel.llms.base import AsyncLLM
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import ChatType
from distilabel.utils.itertools import grouper

if TYPE_CHECKING:
from anthropic import AsyncAnthropic

from distilabel.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import ChatType

_ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY"

Expand Down Expand Up @@ -149,15 +149,16 @@ def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model

@validate_call
async def agenerate( # type: ignore
self,
input: "ChatType",
input: ChatType,
max_tokens: int = 128,
stop_sequences: Union[List[str], None] = None,
temperature: float = 1.0,
top_p: Union[float, None] = None,
top_k: Union[int, None] = None,
) -> "GenerateOutput":
) -> GenerateOutput:
"""Generates a response asynchronously, using the [Anthropic Async API definition](https://github.com/anthropics/anthropic-sdk-python).
Args:
Expand All @@ -173,14 +174,14 @@ async def agenerate( # type: ignore
"""
from anthropic._types import NOT_GIVEN

completion = await self._aclient.messages.create(
completion = await self._aclient.messages.create( # type: ignore
model=self.model,
system=(
input.pop(0)["content"]
if input and input[0]["role"] == "system"
else NOT_GIVEN
),
messages=input,
messages=input, # type: ignore
max_tokens=max_tokens,
stream=False,
stop_sequences=NOT_GIVEN if stop_sequences is None else stop_sequences,
Expand Down
8 changes: 4 additions & 4 deletions src/distilabel/llms/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@
Union,
)

from pydantic import Field, PrivateAttr, SecretStr
from pydantic import Field, PrivateAttr, SecretStr, validate_call
from typing_extensions import override

from distilabel.llms.base import AsyncLLM
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import ChatType
from distilabel.utils.itertools import grouper

if TYPE_CHECKING:
from cohere import AsyncClient, ChatMessage

from distilabel.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import ChatType

_COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY"

Expand Down Expand Up @@ -153,10 +153,10 @@ def _format_chat_to_cohere(

return system, chat_history, message

@override
@validate_call
async def agenerate( # type: ignore
self,
input: "ChatType",
input: ChatType,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
k: Optional[int] = None,
Expand Down
18 changes: 13 additions & 5 deletions src/distilabel/llms/huggingface/inference_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,27 @@
import os
from typing import TYPE_CHECKING, Any, List, Optional, Union

from pydantic import Field, PrivateAttr, SecretStr, ValidationError, model_validator
from pydantic import (
Field,
PrivateAttr,
SecretStr,
ValidationError,
model_validator,
validate_call,
)
from typing_extensions import override

from distilabel.llms.base import AsyncLLM
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import ChatType
from distilabel.utils.itertools import grouper

if TYPE_CHECKING:
from huggingface_hub import AsyncInferenceClient
from openai import AsyncOpenAI
from transformers import PreTrainedTokenizer

from distilabel.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import ChatType

_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME = "HF_TOKEN"

Expand Down Expand Up @@ -220,7 +227,7 @@ async def _openai_agenerate(
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: Optional[float] = None,
) -> "GenerateOutput":
) -> GenerateOutput:
"""Generates completions for the given input using the OpenAI async client."""
completion = await self._aclient.chat.completions.create( # type: ignore
messages=input, # type: ignore
Expand All @@ -241,9 +248,10 @@ async def _openai_agenerate(
return [completion.choices[0].message.content]

# TODO: add `num_generations` parameter once either TGI or `AsyncInferenceClient` allows `n` parameter
@validate_call
async def agenerate( # type: ignore
self,
input: "ChatType",
input: ChatType,
max_new_tokens: int = 128,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
Expand Down
12 changes: 7 additions & 5 deletions src/distilabel/llms/huggingface/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,20 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from pydantic import PrivateAttr
from pydantic import PrivateAttr, validate_call

from distilabel.llms.base import LLM
from distilabel.llms.chat_templates import CHATML_TEMPLATE
from distilabel.llms.mixins import CudaDevicePlacementMixin
from distilabel.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import ChatType

if TYPE_CHECKING:
from transformers import Pipeline
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer

from distilabel.llms.typing import GenerateOutput, HiddenState
from distilabel.steps.tasks.typing import ChatType
from distilabel.llms.typing import HiddenState


class TransformersLLM(LLM, CudaDevicePlacementMixin):
Expand Down Expand Up @@ -129,17 +130,18 @@ def prepare_input(self, input: "ChatType") -> str:
add_generation_prompt=True,
)

@validate_call
def generate( # type: ignore
self,
inputs: List["ChatType"],
inputs: List[ChatType],
num_generations: int = 1,
max_new_tokens: int = 128,
temperature: float = 0.1,
repetition_penalty: float = 1.1,
top_p: float = 1.0,
top_k: int = 0,
do_sample: bool = True,
) -> List["GenerateOutput"]:
) -> List[GenerateOutput]:
"""Generates `num_generations` responses for each input using the text generation
pipeline.
Expand Down
12 changes: 6 additions & 6 deletions src/distilabel/llms/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
import logging
from typing import TYPE_CHECKING, Callable, List, Optional, Union

from pydantic import Field, PrivateAttr
from pydantic import Field, PrivateAttr, validate_call

from distilabel.llms.base import AsyncLLM
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import ChatType

if TYPE_CHECKING:
from litellm import Choices

from distilabel.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import ChatType


class LiteLLM(AsyncLLM):
"""LiteLLM implementation running the async API client.
Expand Down Expand Up @@ -74,9 +73,10 @@ def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model

@validate_call
async def agenerate( # type: ignore
self,
input: "ChatType",
input: ChatType,
num_generations: int = 1,
functions: Optional[List] = None,
function_call: Optional[str] = None,
Expand All @@ -96,7 +96,7 @@ async def agenerate( # type: ignore
mock_response: Optional[str] = None,
force_timeout: Optional[int] = 600,
custom_llm_provider: Optional[str] = None,
) -> "GenerateOutput":
) -> GenerateOutput:
"""Generates `num_generations` responses for the given input using the [LiteLLM async client](https://github.com/BerriAI/litellm).
Args:
Expand Down
12 changes: 6 additions & 6 deletions src/distilabel/llms/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@

from typing import TYPE_CHECKING, List, Optional

from pydantic import Field, FilePath, PrivateAttr
from pydantic import Field, FilePath, PrivateAttr, validate_call

from distilabel.llms.base import LLM
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import ChatType

if TYPE_CHECKING:
from llama_cpp import CreateChatCompletionResponse, Llama

from distilabel.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import ChatType


class LlamaCppLLM(LLM):
"""llama.cpp LLM implementation running the Python bindings for the C++ code.
Expand Down Expand Up @@ -84,16 +83,17 @@ def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self._model.model_path # type: ignore

@validate_call
def generate( # type: ignore
self,
inputs: List["ChatType"],
inputs: List[ChatType],
num_generations: int = 1,
max_new_tokens: int = 128,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
) -> List["GenerateOutput"]:
) -> List[GenerateOutput]:
"""Generates `num_generations` responses for the given input using the Llama model.
Args:
Expand Down
11 changes: 6 additions & 5 deletions src/distilabel/llms/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
import os
from typing import TYPE_CHECKING, Any, List, Optional

from pydantic import Field, PrivateAttr, SecretStr
from pydantic import Field, PrivateAttr, SecretStr, validate_call
from typing_extensions import override

from distilabel.llms.base import AsyncLLM
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import ChatType
from distilabel.utils.itertools import grouper

if TYPE_CHECKING:
from mistralai.async_client import MistralAsyncClient

from distilabel.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import ChatType

_MISTRALAI_API_KEY_ENV_VAR_NAME = "MISTRAL_API_KEY"

Expand Down Expand Up @@ -113,13 +113,14 @@ def model_name(self) -> str:
return self.model

# TODO: add `num_generations` parameter once Mistral client allows `n` parameter
@validate_call
async def agenerate( # type: ignore
self,
input: "ChatType",
input: ChatType,
max_new_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> "GenerateOutput":
) -> GenerateOutput:
"""Generates `num_generations` responses for the given input using the MistralAI async
client.
Expand Down
Loading

0 comments on commit ef254e4

Please sign in to comment.