diff --git a/src/distilabel/models/llms/ollama.py b/src/distilabel/models/llms/ollama.py
index f70462748..782ce9310 100644
--- a/src/distilabel/models/llms/ollama.py
+++ b/src/distilabel/models/llms/ollama.py
@@ -21,12 +22,17 @@
 from distilabel.models.llms.base import AsyncLLM
 from distilabel.models.llms.typing import GenerateOutput
 from distilabel.models.llms.utils import prepare_output
+from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
 from distilabel.steps.tasks.typing import InstructorStructuredOutputType, StandardInput
 
 if TYPE_CHECKING:
     from ollama import AsyncClient
 
-    from distilabel.llms.typing import LLMStatistics
+    from distilabel.models.llms.typing import LLMStatistics
+    from distilabel.steps.tasks.typing import (
+        StandardInput,
+    )
 
 
 # Copied from `ollama._types.Options`
@@ -69,13 +75,25 @@ class Options(TypedDict, total=False):
     stop: Sequence[str]
 
 
-class OllamaLLM(AsyncLLM):
+class OllamaLLM(AsyncLLM, MagpieChatTemplateMixin):
     """Ollama LLM implementation running the Async API client.
 
     Attributes:
         model: the model name to use for the LLM e.g. "notus".
         host: the Ollama server host.
         timeout: the timeout for the LLM. Defaults to `120`.
+        follow_redirects: whether to follow redirects. Defaults to `True`.
+        structured_output: a dictionary containing the structured output configuration or, if more
+            fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to `None`.
+        tokenizer_id: the tokenizer Hugging Face Hub repo id or a path to a directory containing
+            the tokenizer config files. If not provided, the one associated to the `model`
+            will be used. Defaults to `None`.
+        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
+            template. Defaults to `False`.
+        magpie_pre_query_template: the pre-query template to be applied to the prompt or
+            sent to the LLM to generate an instruction or a follow-up user message. Valid
+            values are "llama3", "qwen2" or another pre-query template provided. Defaults
+            to `None`.
         _aclient: the `AsyncClient` to use for the Ollama API. It is meant to be used internally.
             Set in the `load` method.
@@ -112,9 +130,22 @@ class OllamaLLM(AsyncLLM):
             description="The structured output format to use across all the generations.",
         )
     )
-
+    tokenizer_id: Optional[RuntimeParameter[str]] = Field(
+        default=None,
+        description="The tokenizer Hugging Face Hub repo id or a path to a directory containing"
+        " the tokenizer config files. If not provided, the one associated to the `model`"
+        " will be used.",
+    )
+    use_magpie_template: RuntimeParameter[bool] = Field(
+        default=False,
+        description="Whether to use the Magpie pre-query template or not.",
+    )
+    magpie_pre_query_template: Optional[RuntimeParameter[str]] = Field(
+        default=None,
+        description="The pre-query template to use for the model. Valid values are "
+        "`llama3`, `qwen2` or another pre-query template provided.",
+    )
     _num_generations_param_supported = False
-
     _aclient: Optional["AsyncClient"] = PrivateAttr(...)
 
     def load(self) -> None:
@@ -135,13 +166,83 @@ def load(self) -> None:
                 " `pip install ollama`."
             ) from e
 
+        if self.use_magpie_template or self.magpie_pre_query_template:
+            if not self.tokenizer_id:
+                raise ValueError(
+                    "The Hugging Face Hub repo id or a path to a directory containing"
+                    " the tokenizer config files is required when using the `use_magpie_template`"
+                    " or `magpie_pre_query_template` runtime parameters."
+                )
+
+        if self.tokenizer_id:
+            try:
+                from transformers import AutoTokenizer
+            except ImportError as ie:
+                raise ImportError(
+                    "Transformers is not installed. Please install it using `pip install transformers`."
+                ) from ie
+            self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
+
     @property
     def model_name(self) -> str:
         """Returns the model name used for the LLM."""
         return self.model
 
+    async def _generate_chat_completion(
+        self,
+        input: "StandardInput",
+        format: Optional[Literal["", "json"]] = None,
+        options: Union[Options, None] = None,
+        keep_alive: Union[bool, None] = None,
+    ) -> Dict[str, Any]:
+        """Calls the Ollama chat endpoint with the provided conversation."""
+        return await self._aclient.chat(
+            model=self.model,
+            messages=input,
+            stream=False,
+            format=format,
+            options=options,
+            keep_alive=keep_alive,
+        )
+
+    def prepare_input(self, input: "StandardInput") -> str:
+        """Prepares the input (applying the chat template and tokenization) for the provided
+        input.
+
+        Args:
+            input: the input list containing chat items.
+
+        Returns:
+            The prompt to send to the LLM.
+        """
+        prompt: str = (
+            self._tokenizer.apply_chat_template(
+                conversation=input,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            if input
+            else ""
+        )
+        return super().apply_magpie_pre_query_template(prompt, input)
+
+    async def _generate_with_text_generation(
+        self,
+        input: "StandardInput",
+        format: Optional[Literal["", "json"]] = None,
+        options: Union[Options, None] = None,
+        keep_alive: Union[bool, None] = None,
+    ) -> Dict[str, Any]:
+        """Builds the prompt with the tokenizer's chat template and calls the Ollama
+        generate endpoint."""
+        prompt = self.prepare_input(input)
+        return await self._aclient.generate(
+            model=self.model,
+            prompt=prompt,
+            format=format,
+            options=options,
+            keep_alive=keep_alive,
+        )
+
     @validate_call
-    async def agenerate(  # type: ignore
+    async def agenerate(
         self,
         input: StandardInput,
         format: Literal["", "json"] = "",
@@ -163,15 +264,21 @@ async def agenerate(  # type: ignore
         """
         text = None
         try:
-            completion: Dict[str, Any] = await self._aclient.chat(  # type: ignore
-                model=self.model,
-                messages=input,  # type: ignore
-                stream=False,
-                format=format,
-                options=options,
-                keep_alive=keep_alive,
-            )
-            text = completion["message"]["content"]
+            if not format:
+                format = None
+            if self.tokenizer_id is None:
+                completion = await self._generate_chat_completion(
+                    input, format, options, keep_alive
+                )
+                text = completion["message"]["content"]
+            else:
+                completion = await self._generate_with_text_generation(
+                    input, format, options, keep_alive
+                )
+                text = completion["response"]
         except Exception as e:
             self._logger.warning(  # type: ignore
                 f"⚠️ Received no response using Ollama client (model: '{self.model_name}')."
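
Usage sketch (not part of the diff): a minimal example of how the runtime parameters introduced above could be exercised. The import path `distilabel.models.llms.ollama`, the model tag and the tokenizer repo id are illustrative assumptions, and a local Ollama server serving that model is assumed to be running.

# Minimal sketch: exercising the Magpie-related parameters added to `OllamaLLM`.
# Assumptions: a local Ollama server serves the placeholder model tag below and the
# placeholder Hugging Face tokenizer repo is accessible.
import asyncio

from distilabel.models.llms.ollama import OllamaLLM

llm = OllamaLLM(
    model="llama3.1",  # placeholder: any model tag served by the local Ollama instance
    tokenizer_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder: tokenizer used by `prepare_input`
    use_magpie_template=True,  # enable the Magpie pre-query template handling
    magpie_pre_query_template="llama3",  # one of the valid values listed in the docstring
)
llm.load()  # validates `tokenizer_id` and loads the `AutoTokenizer`

# With `tokenizer_id` set, `agenerate` builds the prompt via `prepare_input` and
# calls the Ollama generate endpoint instead of the chat endpoint.
result = asyncio.run(
    llm.agenerate(input=[{"role": "user", "content": "Write a haiku about the sea."}])
)
print(result)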