add ollama support magpie
davidberenstein1957 committed Dec 19, 2024
1 parent ea6af59 commit c6708c9
Showing 1 changed file with 121 additions and 14 deletions.
135 changes: 121 additions & 14 deletions src/distilabel/models/llms/ollama.py
@@ -14,19 +14,25 @@

from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Union

from llama_cpp.llama_types import CreateChatCompletionResponse
from pydantic import Field, PrivateAttr, validate_call
from typing_extensions import TypedDict

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.typing import GenerateOutput
from distilabel.models.llms.utils import prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.typing import InstructorStructuredOutputType, StandardInput

if TYPE_CHECKING:
from llama_cpp import CreateChatCompletionResponse
from ollama import AsyncClient

from distilabel.models.llms.typing import LLMStatistics
from distilabel.steps.tasks.typing import (
StandardInput,
)


# Copied from `ollama._types.Options`
@@ -69,13 +75,25 @@ class Options(TypedDict, total=False):
stop: Sequence[str]


class OllamaLLM(AsyncLLM):
class OllamaLLM(AsyncLLM, MagpieChatTemplateMixin):
"""Ollama LLM implementation running the Async API client.
Attributes:
model: the model name to use for the LLM e.g. "notus".
host: the Ollama server host.
timeout: the timeout for the LLM. Defaults to `120`.
follow_redirects: whether to follow redirects. Defaults to `True`.
structured_output: a dictionary containing the structured output configuration or if more
fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
tokenizer_id: the tokenizer Hugging Face Hub repo id or a path to a directory containing
the tokenizer config files. If not provided, the one associated to the `model`
will be used. Defaults to `None`.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
sent to the LLM to generate an instruction or a follow up user message. Valid
values are "llama3", "qwen2" or another pre-query template provided. Defaults
to `None`.
_aclient: the `AsyncClient` to use for the Ollama API. It is meant to be used internally.
Set in the `load` method.
Expand Down Expand Up @@ -112,9 +130,22 @@ class OllamaLLM(AsyncLLM):
description="The structured output format to use across all the generations.",
)
)

tokenizer_id: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The tokenizer Hugging Face Hub repo id or a path to a directory containing"
" the tokenizer config files. If not provided, the one associated to the `model`"
" will be used.",
)
use_magpie_template: RuntimeParameter[bool] = Field(
default=False,
description="Whether to use the Magpie pre-query template or not.",
)
magpie_pre_query_template: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The pre-query template to use for the model. Valid values are "
"`llama3`, `qwen2` or another pre-query template provided.",
)
_num_generations_param_supported = False

_aclient: Optional["AsyncClient"] = PrivateAttr(...)

def load(self) -> None:
@@ -135,13 +166,83 @@ def load(self) -> None:
" `pip install ollama`."
) from e

if self.use_magpie_template or self.magpie_pre_query_template:
if not self.tokenizer_id:
raise ValueError(
"The Hugging Face Hub repo id or a path to a directory containing"
" the tokenizer config files is required when using the `use_magpie_template`"
" or `magpie_pre_query_template` runtime parameters."
)

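        # The tokenizer is only needed so that the chat template (and, when enabled, the
        # Magpie pre-query template) can be rendered client-side before the prompt is sent
        # to the Ollama server.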
if self.tokenizer_id:
try:
from transformers import AutoTokenizer
except ImportError as ie:
raise ImportError(
"Transformers is not installed. Please install it using `pip install transformers`."
) from ie
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)

@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model

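    # Default path when no `tokenizer_id` is configured: a thin wrapper around
    # `AsyncClient.chat`, leaving prompt templating to the Ollama server.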
async def _generate_chat_completion(
self,
input: "StandardInput",
format: Literal["", "json"] = "",
options: Union[Options, None] = None,
keep_alive: Union[bool, None] = None,
) -> "CreateChatCompletionResponse":
return await self._aclient.chat(
model=self.model,
messages=input,
stream=False,
format=format,
options=options,
keep_alive=keep_alive,
)

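    # Renders the conversation with the Hugging Face tokenizer's chat template
    # (tokenize=False, add_generation_prompt=True) and then passes the string through
    # `apply_magpie_pre_query_template`, which appends the configured pre-query template
    # when Magpie is enabled so the model continues by generating a user instruction
    # rather than an assistant reply.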
def prepare_input(self, input: "StandardInput") -> str:
"""Prepares the input (applying the chat template and tokenization) for the provided
input.
Args:
input: the input list containing chat items.
Returns:
The prompt to send to the LLM.
"""
prompt: str = (
self._tokenizer.apply_chat_template(
conversation=input,
tokenize=False,
add_generation_prompt=True,
)
if input
else ""
)
return super().apply_magpie_pre_query_template(prompt, input)

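    # Magpie path: the prompt prepared by `prepare_input` is sent as-is through
    # `AsyncClient.generate` instead of the chat API.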
async def _generate_with_text_generation(
self,
input: "StandardInput",
        format: Optional[Literal["", "json"]] = None,
options: Union[Options, None] = None,
keep_alive: Union[bool, None] = None,
) -> "CreateChatCompletionResponse":
input = self.prepare_input(input)
return await self._aclient.generate(
model=self.model,
prompt=input,
format=format,
options=options,
keep_alive=keep_alive,
)

@validate_call
async def agenerate( # type: ignore
async def agenerate(
self,
input: StandardInput,
format: Literal["", "json"] = "",
@@ -163,15 +264,21 @@ async def agenerate( # type: ignore
"""
text = None
try:
completion: Dict[str, Any] = await self._aclient.chat( # type: ignore
model=self.model,
messages=input, # type: ignore
stream=False,
format=format,
options=options,
keep_alive=keep_alive,
)
text = completion["message"]["content"]
if not format:
format = None
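            # Without a `tokenizer_id` the chat endpoint is used as before; with one, the
            # prompt is templated locally (Magpie-aware) and sent through `generate`.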
if self.tokenizer_id is None:
completion = await self._generate_chat_completion(
input, format, options, keep_alive
)
text = completion["message"]["content"]
else:
completion: CreateChatCompletionResponse = (
await self._generate_with_text_generation(
input, format, options, keep_alive
)
)

text = completion.response
except Exception as e:
self._logger.warning( # type: ignore
f"⚠️ Received no response using Ollama client (model: '{self.model_name}')."

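Taken together, the change lets `OllamaLLM` drive Magpie-style instruction generation. A minimal usage sketch, not part of the commit, assuming a local Ollama server, a Llama-3-style model, and `transformers` installed; the model name and tokenizer repo id are placeholders, and the import path is taken from the file location in this diff:

import asyncio

from distilabel.models.llms.ollama import OllamaLLM  # module path from this diff; the public re-export may differ

llm = OllamaLLM(
    model="llama3.1",  # placeholder: any chat model already pulled into the local Ollama server
    tokenizer_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder HF repo id, required for the Magpie path
    use_magpie_template=True,
    magpie_pre_query_template="llama3",
)
llm.load()

# With the Magpie pre-query template enabled, a system-only conversation is enough: the
# appended pre-query template nudges the model into generating the user instruction itself.
result = asyncio.run(
    llm.agenerate(input=[{"role": "system", "content": "You are a creative writing assistant."}])
)
print(result)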