Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update to outlines010 #1092

Merged
merged 42 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
329b645
add outlines 0.1.0 support
davidberenstein1957 Jan 9, 2025
9dd4be9
update tests
davidberenstein1957 Jan 9, 2025
3ce1ff3
fix passing tokenizer to regex processor as well
davidberenstein1957 Jan 9, 2025
d8d7b35
fix test by specifically passing None as token to transformersllm
davidberenstein1957 Jan 9, 2025
2e0b42c
fix tests by increasing the temperature to avoid exploding beam sear…
davidberenstein1957 Jan 9, 2025
5ee7dce
fix logit processor assignment during generation
davidberenstein1957 Jan 9, 2025
0d26a1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
47e38dc
add support transformers
davidberenstein1957 Jan 9, 2025
66ac934
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
61c3538
remove duplicate import
davidberenstein1957 Jan 9, 2025
a3b4f9c
Merge branch 'develop' into feat/1081-feature-update-to-outlines010
davidberenstein1957 Jan 9, 2025
0738b27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
8e6613b
remove duplicate
davidberenstein1957 Jan 9, 2025
7db1b0b
Merge branch 'feat/1081-feature-update-to-outlines010' of https://git…
davidberenstein1957 Jan 9, 2025
cb4c2ce
remove duplicate import
davidberenstein1957 Jan 9, 2025
7f20d9f
return content when no chat template is present
davidberenstein1957 Jan 9, 2025
61aa597
refactor clean code
davidberenstein1957 Jan 9, 2025
b994f06
chore refactor
davidberenstein1957 Jan 9, 2025
a47963d
refactor logic if else statement
davidberenstein1957 Jan 9, 2025
a0f8acd
fix import when outlines is not present
davidberenstein1957 Jan 9, 2025
b41d6f0
chore pin transformers version
davidberenstein1957 Jan 9, 2025
d2fdd4c
chore add context w.r.t. logit processor
davidberenstein1957 Jan 9, 2025
2b8f634
chore bump version
davidberenstein1957 Jan 9, 2025
ed5f00f
add simplification of transformers implementation
davidberenstein1957 Jan 9, 2025
473de03
Update .gitignore to exclude .DS_Store files and remove vllm subproje…
davidberenstein1957 Jan 10, 2025
995e4d4
Refactor outlines version check and logits processor handling
davidberenstein1957 Jan 10, 2025
5960441
Refactor logits processor handling in LlamaCppLLM
davidberenstein1957 Jan 10, 2025
cfac574
Refactor outlines import and logits processor handling in Transformer…
davidberenstein1957 Jan 10, 2025
3378769
Refactor outlines version check and update function naming
davidberenstein1957 Jan 10, 2025
d56b6bc
Refactor processor handling in LlamaCppLLM and TransformersLLM based …
davidberenstein1957 Jan 10, 2025
110ecaf
Merge branch 'develop' into feat/1081-feature-update-to-outlines010
davidberenstein1957 Jan 10, 2025
4056f08
Refactor structured output return types in LlamaCppLLM, MlxLLM, and T…
davidberenstein1957 Jan 10, 2025
11a7957
Enhance MlxLLM integration and expand framework support
davidberenstein1957 Jan 10, 2025
e9fefc4
Refactor structured output handling in LlamaCppLLM and MlxLLM
davidberenstein1957 Jan 10, 2025
df24685
Refactor MlxLLM structured output handling and remove unused components
davidberenstein1957 Jan 10, 2025
65272bd
Refactor logits processor handling in TransformersLLM
davidberenstein1957 Jan 10, 2025
7fc1762
Merge branch 'develop' into feat/1081-feature-update-to-outlines010
davidberenstein1957 Jan 10, 2025
d2eda4e
Refactor type hints in outlines.py for improved clarity
davidberenstein1957 Jan 10, 2025
85494c4
Refactor type hint imports in outlines.py for improved clarity
davidberenstein1957 Jan 10, 2025
f6a50f0
Merge branch 'develop' into feat/1081-feature-update-to-outlines010
davidberenstein1957 Jan 10, 2025
01ea5f1
Refactor regex processor handling in prepare_guided_output function
davidberenstein1957 Jan 10, 2025
399154e
Update transformer dependency constraints in pyproject.toml
davidberenstein1957 Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,4 @@ venv.bak/
# Other
*.log
*.swp
.DS_Store
17 changes: 12 additions & 5 deletions src/distilabel/models/llms/huggingface/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.structured_outputs.outlines import (
_is_outlines_version_below_0_1_0,
)
from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput
from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR

Expand Down Expand Up @@ -111,6 +114,7 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):

_pipeline: Optional["Pipeline"] = PrivateAttr(...)
_prefix_allowed_tokens_fn: Union[Callable, None] = PrivateAttr(default=None)
_logits_processor: Union[Callable, None] = PrivateAttr(default=None)

def load(self) -> None:
"""Loads the model and tokenizer and creates the text generation pipeline. In addition,
Expand Down Expand Up @@ -149,9 +153,11 @@ def load(self) -> None:
self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token # type: ignore

if self.structured_output:
self._prefix_allowed_tokens_fn = self._prepare_structured_output(
self.structured_output
)
processor = self._prepare_structured_output(self.structured_output)
if _is_outlines_version_below_0_1_0():
self._prefix_allowed_tokens_fn = processor
else:
self._logits_processor = [processor]

super().load()

Expand Down Expand Up @@ -232,7 +238,8 @@ def generate( # type: ignore
do_sample=do_sample,
num_return_sequences=num_generations,
prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
pad_token_id=self._pipeline.tokenizer.eos_token_id, # type: ignore
pad_token_id=self._pipeline.tokenizer.eos_token_id,
logits_processor=self._logits_processor,
)
llm_output = [
[generation["generated_text"] for generation in output]
Expand Down Expand Up @@ -292,7 +299,7 @@ def get_last_hidden_states(

def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union[Callable, None]:
) -> Union[Callable, List[Callable]]:
"""Creates the appropriate function to filter tokens to generate structured outputs.

Args:
Expand Down
11 changes: 8 additions & 3 deletions src/distilabel/models/llms/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType

if TYPE_CHECKING:
from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList
from llama_cpp import (
CreateChatCompletionResponse,
Llama,
LogitsProcessor,
LogitsProcessorList,
)

from distilabel.steps.tasks.typing import FormattedInput, StandardInput

Expand Down Expand Up @@ -383,7 +388,7 @@ def generate( # type: ignore

def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union["LogitsProcessorList", None]:
) -> Union["LogitsProcessorList", "LogitsProcessor"]:
"""Creates the appropriate function to filter tokens to generate structured outputs.

Args:
Expand All @@ -399,4 +404,4 @@ def _prepare_structured_output(
result = prepare_guided_output(structured_output, "llamacpp", self._model)
if (schema := result.get("schema")) and self.structured_output:
self.structured_output["schema"] = schema
return result["processor"]
return [result["processor"]]
52 changes: 2 additions & 50 deletions src/distilabel/models/llms/mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,18 @@
Dict,
List,
Optional,
Union,
)

from pydantic import (
Field,
PrivateAttr,
validate_call,
)

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import LLM
from distilabel.models.llms.typing import GenerateOutput
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.typing import (
OutlinesStructuredOutputType,
StandardInput,
)

Expand All @@ -51,8 +47,6 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
tokenizer_config: the tokenizer configuration.
model_config: the model configuration.
adapter_path: the path to the adapter.
structured_output: a dictionary containing the structured output configuration or if more
fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
Expand Down Expand Up @@ -82,17 +76,10 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
tokenizer_config: Dict[str, Any] = {}
model_config: Dict[str, Any] = {}
adapter_path: Optional[str] = None
structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
default=None,
description="The structured output format to use across all the generations.",
)

_mlx_generate: Optional[Callable] = PrivateAttr(default=None)
_model: Optional["nn.Module"] = PrivateAttr(...)
_tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(...)
_structured_output_logits_processor: Union[Callable, None] = PrivateAttr(
default=None
)

def load(self) -> None:
"""Loads the model and tokenizer and creates the text generation pipeline. In addition,
Expand All @@ -112,11 +99,6 @@ def load(self) -> None:
adapter_path=self.adapter_path,
)

if self.structured_output:
self._structured_output_logits_processor = self._prepare_structured_output(
self.structured_output
)

if self._tokenizer.pad_token is None:
self._tokenizer.pad_token = self._tokenizer.eos_token

Expand Down Expand Up @@ -207,10 +189,6 @@ def generate(
Returns:
A list of lists of strings containing the generated responses for each input.
"""
logits_processors = []
if self._structured_output_logits_processor:
logits_processors.append(self._structured_output_logits_processor)

structured_output = None
result = []
for input in inputs:
Expand All @@ -219,13 +197,9 @@ def generate(

output: List[str] = []
for _ in range(num_generations):
if structured_output:
additional_logits_processors = self._prepare_structured_output(
structured_output
)
logits_processors.append(additional_logits_processors)
if structured_output: # will raise a NotImplementedError
self._prepare_structured_output(structured_output)
prompt = self.prepare_input(input)

generation = self._mlx_generate(
prompt=prompt,
model=self._model,
Expand Down Expand Up @@ -264,25 +238,3 @@ def generate(
)
)
return result

def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union[Callable, None]:
"""Creates the appropriate function to filter tokens to generate structured outputs.

Args:
structured_output: the configuration dict to prepare the structured output.

Returns:
The callable that will be used to guide the generation of the model.
"""
from distilabel.steps.tasks.structured_outputs.outlines import (
prepare_guided_output,
)

result = prepare_guided_output(
structured_output, "transformers", self._pipeline
)
if schema := result.get("schema"):
self.structured_output["schema"] = schema
return result["processor"]
119 changes: 91 additions & 28 deletions src/distilabel/steps/tasks/structured_outputs/outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,38 @@
Literal,
Tuple,
Type,
Union,
get_args,
)

import pkg_resources
from pydantic import BaseModel

from distilabel.errors import DistilabelUserError
from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict

if TYPE_CHECKING:
from distilabel.steps.tasks.typing import OutlinesStructuredOutputType
if TYPE_CHECKING: # noqa
from llama_cpp import Llama # noqa
from transformers import Pipeline # noqa
from vllm import LLM as _vLLM # noqa

from distilabel.steps.tasks.typing import OutlinesStructuredOutputType # noqa

Frameworks = Literal["transformers", "llamacpp", "vllm"]
"""Available frameworks for the structured output configuration. """


def _is_outlines_version_below_0_1_0() -> bool:
    """Helper function to check outlines availability and version.

    Returns:
        bool: `True` if the installed `outlines` version is below `0.1.0`,
            `False` otherwise.

    Raises:
        ImportError: if `outlines` is not installed at all.
    """
    # Fail fast with an actionable message when the optional dependency is missing;
    # callers rely on this instead of checking availability themselves.
    if not importlib.util.find_spec("outlines"):
        raise ImportError(
            "Outlines is not installed. Please install it using `pip install outlines`."
        )
    version = pkg_resources.get_distribution("outlines").version
    # parse_version gives a proper semantic comparison (e.g. "0.0.46" < "0.1.0"),
    # which a plain string comparison would get wrong.
    return pkg_resources.parse_version(version) < pkg_resources.parse_version("0.1.0")


def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]:
Expand All @@ -45,38 +64,77 @@ def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]:


def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:
    """Helper function to return the appropriate logits processors for the given framework.

    Args:
        framework: the framework to resolve the JSON and regex processor classes for.

    Returns:
        A tuple `(json_processor_cls, regex_processor_cls)` imported lazily from the
        `outlines` module that matches the installed version.

    Raises:
        DistilabelUserError: if `framework` is not one of the supported frameworks.
    """
    if _is_outlines_version_below_0_1_0():
        # Legacy outlines (<0.1.0) ships one integration module per framework,
        # and transformers uses prefix-allowed-tokens instead of logits processors.
        processors = {
            "transformers": (
                "outlines.integrations.transformers",
                "JSONPrefixAllowedTokens",
                "RegexPrefixAllowedTokens",
            ),
            "llamacpp": (
                "outlines.integrations.llamacpp",
                "JSONLogitsProcessor",
                "RegexLogitsProcessor",
            ),
            "vllm": (
                "outlines.integrations.vllm",
                "JSONLogitsProcessor",
                "RegexLogitsProcessor",
            ),
        }
    else:
        # outlines >= 0.1.0 exposes framework-agnostic processors from a single
        # module, so every supported framework maps to the same entry.
        unified = (
            "outlines.processors",
            "JSONLogitsProcessor",
            "RegexLogitsProcessor",
        )
        processors = {fw: unified for fw in ("transformers", "llamacpp", "vllm")}

    if framework not in processors:
        raise DistilabelUserError(
            f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}",
            page="sections/how_to_guides/advanced/structured_generation/",
        )

    # Import lazily so merely loading this module never requires outlines.
    module_path, json_cls, regex_cls = processors[framework]
    module = importlib.import_module(module_path)
    return getattr(module, json_cls), getattr(module, regex_cls)


def _get_tokenizer_from_model(
    llm: Union["_vLLM", "Pipeline", "Llama"],
    framework: Frameworks,
) -> Callable:
    """Builds the outlines tokenizer adapter for the given model and framework.

    Args:
        llm: the model instance (`vllm.LLM`, `transformers.Pipeline` or
            `llama_cpp.Llama`) to extract/adapt the tokenizer from.
        framework: the framework the model belongs to.

    Returns:
        The outlines-compatible tokenizer wrapping the model's tokenizer.

    Raises:
        DistilabelUserError: if `framework` is not one of the supported frameworks.
    """
    if framework == "llamacpp":
        from outlines.models.llamacpp import LlamaCppTokenizer

        return LlamaCppTokenizer(llm)
    if framework == "transformers":
        from outlines.models.transformers import TransformerTokenizer

        return TransformerTokenizer(llm.tokenizer)
    if framework == "vllm":
        from outlines.models.vllm import adapt_tokenizer

        return adapt_tokenizer(llm.get_tokenizer())
    # Defensive fall-through: fail loudly on an unsupported framework instead of
    # implicitly returning None, mirroring the validation in `_get_logits_processor`.
    raise DistilabelUserError(
        f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}",
        page="sections/how_to_guides/advanced/structured_generation/",
    )


def prepare_guided_output(
structured_output: "OutlinesStructuredOutputType",
framework: Frameworks,
llm: Any,
llm: Union["_vLLM", "Pipeline", "Llama"],
) -> Dict[str, Any]:
"""Prepares the `LLM` to generate guided output using `outlines`.

Expand All @@ -97,10 +155,6 @@ def prepare_guided_output(
case of "json" will also include the schema as a dict, to simplify serialization
and deserialization.
"""
if not importlib.util.find_spec("outlines"):
raise ImportError(
"Outlines is not installed. Please install it using `pip install 'distilabel[outlines]'`."
)

json_processor, regex_processor = _get_logits_processor(framework)

Expand All @@ -116,18 +170,27 @@ def prepare_guided_output(
elif isinstance(schema, str):
format = "regex"

if _is_outlines_version_below_0_1_0():
# use the llm for processor initialization
model = llm
tokenizer = None
else:
# use the tokenizer for processor initialization
model = None
tokenizer = _get_tokenizer_from_model(llm, framework)

if format == "json":
return {
"processor": json_processor(
schema,
llm,
model or tokenizer,
whitespace_pattern=structured_output.get("whitespace_pattern"),
),
"schema": schema_as_dict(schema),
}

if format == "regex":
return {"processor": regex_processor(schema, llm)}
return {"processor": regex_processor(schema, model or tokenizer)}

raise DistilabelUserError(
f"Invalid format '{format}'. Must be either 'json' or 'regex'.",
Expand Down
Loading
Loading