Add Magpie support for llama.cpp and Ollama #1084

Closed
2 changes: 1 addition & 1 deletion src/distilabel/models/llms/anthropic.py
@@ -42,7 +42,7 @@
from anthropic import AsyncAnthropic
from anthropic.types import Message

-from distilabel.llms.typing import LLMStatistics
+from distilabel.models.llms.typing import LLMStatistics


_ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY"
2 changes: 1 addition & 1 deletion src/distilabel/models/llms/cohere.py
@@ -40,7 +40,7 @@
from pydantic import BaseModel
from tokenizers import Tokenizer

-from distilabel.llms.typing import LLMStatistics
+from distilabel.models.llms.typing import LLMStatistics


_COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY"
2 changes: 1 addition & 1 deletion src/distilabel/models/llms/groq.py
@@ -30,7 +30,7 @@
from groq import AsyncGroq
from groq.types.chat.chat_completion import ChatCompletion

-from distilabel.llms.typing import LLMStatistics
+from distilabel.models.llms.typing import LLMStatistics


_GROQ_API_BASE_URL_ENV_VAR_NAME = "GROQ_BASE_URL"
172 changes: 155 additions & 17 deletions src/distilabel/models/llms/llamacpp.py
@@ -14,19 +14,25 @@

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

-from pydantic import Field, FilePath, PrivateAttr, validate_call
+from pydantic import Field, FilePath, PrivateAttr, model_validator, validate_call

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import LLM
from distilabel.models.llms.typing import GenerateOutput
from distilabel.models.llms.utils import prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType

if TYPE_CHECKING:
    from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList

    from distilabel.steps.tasks.typing import (
        FormattedInput,
        StandardInput,
    )


-class LlamaCppLLM(LLM):
+class LlamaCppLLM(LLM, MagpieChatTemplateMixin):
"""llama.cpp LLM implementation running the Python bindings for the C++ code.

Attributes:
Expand All @@ -44,6 +50,15 @@ class LlamaCppLLM(LLM):
            fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
        extra_kwargs: additional dictionary of keyword arguments that will be passed to the
            `Llama` class of `llama_cpp` library. Defaults to `{}`.
        tokenizer_id: the tokenizer Hugging Face Hub repo id or a path to a directory containing
            the tokenizer config files. If not provided, the one associated to the `model`
            will be used. Defaults to `None`.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow up user message. Valid
            values are "llama3", "qwen2" or another pre-query template provided. Defaults
            to `None`.
        _model: the Llama model instance. This attribute is meant to be used internally and
            should not be accessed directly. It will be set in the `load` method.
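
For intuition, here is a rough sketch of what the Magpie pre-query template described above amounts to: the rendered chat template is cut off right after the user-turn header, so a raw completion call makes the model write the user instruction itself instead of answering one. The template strings below are illustrative llama3-style values, not taken from the mixin:

# Illustrative only: the real pre-query strings are provided by `MagpieChatTemplateMixin`.
pre_query = "<|start_header_id|>user<|end_header_id|>\n\n"  # assumed llama3-style user header
rendered_system = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful assistant.<|eot_id|>"
)
# With `use_magpie_template=True`, the prompt ends mid-conversation, so completing it
# yields a synthetic user query instead of an assistant answer.
prompt = rendered_system + pre_query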

@@ -140,10 +155,36 @@ class User(BaseModel):
        default=None,
        description="The structured output format to use across all the generations.",
    )

    tokenizer_id: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The tokenizer Hugging Face Hub repo id or a path to a directory containing"
        " the tokenizer config files. If not provided, the one associated to the `model`"
        " will be used.",
    )
    use_magpie_template: RuntimeParameter[bool] = Field(
        default=False,
        description="Whether to use the Magpie pre-query template or not.",
    )
    magpie_pre_query_template: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The pre-query template to use for the model. Valid values are "
        "`llama3`, `qwen2` or another pre-query template provided.",
    )
    _logits_processor: Optional["LogitsProcessorList"] = PrivateAttr(default=None)
    _model: Optional["Llama"] = PrivateAttr(...)

    @model_validator(mode="after")  # type: ignore
    def validate_magpie_usage(
        self,
    ) -> "LlamaCppLLM":
        """Validates that magpie usage is valid."""

        if self.use_magpie_template and self.tokenizer_id is None:
            raise ValueError(
                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
                " set a `tokenizer_id` and try again."
            )
        return self

    def load(self) -> None:
        """Loads the `Llama` model from the `model_path`."""
        try:
@@ -169,6 +210,27 @@ def load(self) -> None:
                self.structured_output
            )

        if self.use_magpie_template or self.magpie_pre_query_template:
            if not self.tokenizer_id:
                raise ValueError(
                    "The Hugging Face Hub repo id or a path to a directory containing"
                    " the tokenizer config files is required when using the `use_magpie_template`"
                    " or `magpie_pre_query_template` runtime parameters."
                )

        if self.tokenizer_id:
            try:
                from transformers import AutoTokenizer
            except ImportError as ie:
                raise ImportError(
                    "Transformers is not installed. Please install it using `pip install transformers`."
                ) from ie
            self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
            if self._tokenizer.chat_template is None:
                raise ValueError(
                    "The tokenizer does not have a chat template. Please use a tokenizer with a chat template."
                )

        # NOTE: Here because of the custom `logging` interface used, since it will create the logging name
        # out of the model name, which won't be available until the `Llama` instance is created.
        super().load()
@@ -178,6 +240,70 @@ def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self._model.model_path  # type: ignore

    def _generate_chat_completion(
        self,
        input: FormattedInput,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        extra_generation_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "CreateChatCompletionResponse":
        return self._model.create_chat_completion(  # type: ignore
            messages=input,  # type: ignore
            max_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            logits_processor=self._logits_processor,
            **(extra_generation_kwargs or {}),
        )

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        prompt: str = (
            self._tokenizer.apply_chat_template(  # type: ignore
                conversation=input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    def _generate_with_text_generation(
        self,
        input: FormattedInput,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        extra_generation_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "CreateChatCompletionResponse":
        prompt = self.prepare_input(input)
        return self._model.create_completion(
            prompt=prompt,
            max_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            logits_processor=self._logits_processor,
            **(extra_generation_kwargs or {}),
        )

    @validate_call
    def generate(  # type: ignore
        self,
@@ -230,24 +356,36 @@ def generate( # type: ignore
                    self._logits_processor = self._prepare_structured_output(
                        structured_output
                    )
-                chat_completions: "CreateChatCompletionResponse" = (
-                    self._model.create_chat_completion(  # type: ignore
-                        messages=input,  # type: ignore
-                        max_tokens=max_new_tokens,
-                        frequency_penalty=frequency_penalty,
-                        presence_penalty=presence_penalty,
-                        temperature=temperature,
-                        top_p=top_p,
-                        logits_processor=self._logits_processor,
-                        **(extra_generation_kwargs or {}),
-                    )
-                )
-                outputs.append(chat_completions["choices"][0]["message"]["content"])
-                output_tokens.append(chat_completions["usage"]["completion_tokens"])
+                if self.tokenizer_id is None:
+                    completion = self._generate_chat_completion(
+                        input,
+                        max_new_tokens,
+                        frequency_penalty,
+                        presence_penalty,
+                        temperature,
+                        top_p,
+                        extra_generation_kwargs,
+                    )
+                    outputs.append(completion["choices"][0]["message"]["content"])
+                    output_tokens.append(completion["usage"]["completion_tokens"])
+                else:
+                    completion: "CreateChatCompletionResponse" = (
+                        self._generate_with_text_generation(  # type: ignore
+                            input,
+                            max_new_tokens,
+                            frequency_penalty,
+                            presence_penalty,
+                            temperature,
+                            top_p,
+                            extra_generation_kwargs,
+                        )
+                    )
+                    outputs.append(completion["choices"][0]["text"])
+                    output_tokens.append(completion["usage"]["completion_tokens"])
             batch_outputs.append(
                 prepare_output(
                     outputs,
-                    input_tokens=[chat_completions["usage"]["prompt_tokens"]]
+                    input_tokens=[completion["usage"]["prompt_tokens"]]
                     * num_generations,
                     output_tokens=output_tokens,
                 )
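
For reference, a minimal usage sketch of the options this diff adds to `LlamaCppLLM` (the GGUF path and tokenizer repo id are placeholders, not values from this PR):

from distilabel.models.llms.llamacpp import LlamaCppLLM

llm = LlamaCppLLM(
    model_path="./models/llama-3-8b-instruct.Q4_K_M.gguf",  # placeholder local GGUF file
    tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",  # required whenever Magpie is enabled
    use_magpie_template=True,
    magpie_pre_query_template="llama3",
)
llm.load()

# Because `tokenizer_id` is set, `generate` takes the `_generate_with_text_generation` path:
# the Hugging Face tokenizer renders the chat template, the pre-query template is appended,
# and `create_completion` is used instead of `create_chat_completion`.
result = llm.generate(
    inputs=[[{"role": "system", "content": "You are a helpful assistant."}]],
    max_new_tokens=128,
)

Setting `tokenizer_id` without the Magpie flags also works: generation still goes through the tokenizer's chat template and the raw completion endpoint, just without the pre-query cut-off.
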
2 changes: 1 addition & 1 deletion src/distilabel/models/llms/mistral.py
@@ -30,7 +30,7 @@
from mistralai import Mistral
from mistralai.models.chatcompletionresponse import ChatCompletionResponse

-from distilabel.llms.typing import LLMStatistics
+from distilabel.models.llms.typing import LLMStatistics


_MISTRALAI_API_KEY_ENV_VAR_NAME = "MISTRAL_API_KEY"