Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update to outlines010 #1092

Merged
merged 42 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
329b645
add outlines 0.1.0 support
davidberenstein1957 Jan 9, 2025
9dd4be9
update tests
davidberenstein1957 Jan 9, 2025
3ce1ff3
fix passing tokenizer to regex processor as well
davidberenstein1957 Jan 9, 2025
d8d7b35
fix test by specifically passing None as token to transformersllm
davidberenstein1957 Jan 9, 2025
2e0b42c
fix tests by increasing the temperature to avoid exploding beam sear…
davidberenstein1957 Jan 9, 2025
5ee7dce
fix logit processor assignment during generation
davidberenstein1957 Jan 9, 2025
0d26a1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
47e38dc
add support transformers
davidberenstein1957 Jan 9, 2025
66ac934
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
61c3538
remove duplicate import
davidberenstein1957 Jan 9, 2025
a3b4f9c
Merge branch 'develop' into feat/1081-feature-update-to-outlines010
davidberenstein1957 Jan 9, 2025
0738b27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
8e6613b
remove duplicate
davidberenstein1957 Jan 9, 2025
7db1b0b
Merge branch 'feat/1081-feature-update-to-outlines010' of https://git…
davidberenstein1957 Jan 9, 2025
cb4c2ce
remove duplicate import
davidberenstein1957 Jan 9, 2025
7f20d9f
return content when no chat template is present
davidberenstein1957 Jan 9, 2025
61aa597
refactor clean code
davidberenstein1957 Jan 9, 2025
b994f06
chore refactor
davidberenstein1957 Jan 9, 2025
a47963d
refactor logic if else statement
davidberenstein1957 Jan 9, 2025
a0f8acd
fix import when outlines is not present
davidberenstein1957 Jan 9, 2025
b41d6f0
chore pin transformers version
davidberenstein1957 Jan 9, 2025
d2fdd4c
chore add context w.r.t. logit processor
davidberenstein1957 Jan 9, 2025
2b8f634
chore bump version
davidberenstein1957 Jan 9, 2025
ed5f00f
add simplification of transformers implementation
davidberenstein1957 Jan 9, 2025
473de03
Update .gitignore to exclude .DS_Store files and remove vllm subproje…
davidberenstein1957 Jan 10, 2025
995e4d4
Refactor outlines version check and logits processor handling
davidberenstein1957 Jan 10, 2025
5960441
Refactor logits processor handling in LlamaCppLLM
davidberenstein1957 Jan 10, 2025
cfac574
Refactor outlines import and logits processor handling in Transformer…
davidberenstein1957 Jan 10, 2025
3378769
Refactor outlines version check and update function naming
davidberenstein1957 Jan 10, 2025
d56b6bc
Refactor processor handling in LlamaCppLLM and TransformersLLM based …
davidberenstein1957 Jan 10, 2025
110ecaf
Merge branch 'develop' into feat/1081-feature-update-to-outlines010
davidberenstein1957 Jan 10, 2025
4056f08
Refactor structured output return types in LlamaCppLLM, MlxLLM, and T…
davidberenstein1957 Jan 10, 2025
11a7957
Enhance MlxLLM integration and expand framework support
davidberenstein1957 Jan 10, 2025
e9fefc4
Refactor structured output handling in LlamaCppLLM and MlxLLM
davidberenstein1957 Jan 10, 2025
df24685
Refactor MlxLLM structured output handling and remove unused components
davidberenstein1957 Jan 10, 2025
65272bd
Refactor logits processor handling in TransformersLLM
davidberenstein1957 Jan 10, 2025
7fc1762
Merge branch 'develop' into feat/1081-feature-update-to-outlines010
davidberenstein1957 Jan 10, 2025
d2eda4e
Refactor type hints in outlines.py for improved clarity
davidberenstein1957 Jan 10, 2025
85494c4
Refactor type hint imports in outlines.py for improved clarity
davidberenstein1957 Jan 10, 2025
f6a50f0
Merge branch 'develop' into feat/1081-feature-update-to-outlines010
davidberenstein1957 Jan 10, 2025
01ea5f1
Refactor regex processor handling in prepare_guided_output function
davidberenstein1957 Jan 10, 2025
399154e
Update transformer dependency constraints in pyproject.toml
davidberenstein1957 Jan 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,4 @@ venv.bak/
# Other
*.log
*.swp
.DS_Store
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ argilla = ["argilla >= 2.0.0", "ipython"]
cohere = ["cohere >= 5.2.0"]
groq = ["groq >= 0.4.1"]
hf-inference-endpoints = ["huggingface_hub >= 0.22.0"]
hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"]
# logit processor breaks in transformers 4.47.0
hf-transformers = ["transformers >= 4.34.1, < 4.47.0", "torch >= 2.0.0"]
instructor = ["instructor >= 1.2.3"]
litellm = ["litellm >= 1.30.0"]
llama-cpp = ["llama-cpp-python >= 0.2.0"]
Expand Down
23 changes: 17 additions & 6 deletions src/distilabel/models/llms/huggingface/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.structured_outputs.outlines import (
_is_outlines_version_below_0_1_0,
)
from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput
from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR

Expand Down Expand Up @@ -111,6 +114,7 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):

_pipeline: Optional["Pipeline"] = PrivateAttr(...)
_prefix_allowed_tokens_fn: Union[Callable, None] = PrivateAttr(default=None)
_logits_processor: Union[Callable, None] = PrivateAttr(default=None)

def load(self) -> None:
"""Loads the model and tokenizer and creates the text generation pipeline. In addition,
Expand Down Expand Up @@ -149,9 +153,11 @@ def load(self) -> None:
self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token # type: ignore

if self.structured_output:
self._prefix_allowed_tokens_fn = self._prepare_structured_output(
self.structured_output
)
processor = self._prepare_structured_output(self.structured_output)
if _is_outlines_version_below_0_1_0():
self._prefix_allowed_tokens_fn = processor
else:
self._logits_processor = processor

super().load()

Expand Down Expand Up @@ -232,7 +238,8 @@ def generate( # type: ignore
do_sample=do_sample,
num_return_sequences=num_generations,
prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
pad_token_id=self._pipeline.tokenizer.eos_token_id, # type: ignore
pad_token_id=self._pipeline.tokenizer.eos_token_id,
logits_processor=self._logits_processor,
)
llm_output = [
[generation["generated_text"] for generation in output]
Expand Down Expand Up @@ -292,7 +299,7 @@ def get_last_hidden_states(

def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union[Callable, None]:
) -> Union[Callable, List[Callable]]:
"""Creates the appropriate function to filter tokens to generate structured outputs.

Args:
Expand All @@ -302,6 +309,7 @@ def _prepare_structured_output(
The callable that will be used to guide the generation of the model.
"""
from distilabel.steps.tasks.structured_outputs.outlines import (
_is_outlines_version_below_0_1_0,
prepare_guided_output,
)

Expand All @@ -310,4 +318,7 @@ def _prepare_structured_output(
)
if schema := result.get("schema"):
self.structured_output["schema"] = schema
return result["processor"]
if _is_outlines_version_below_0_1_0():
return result["processor"]
else:
return [result["processor"]]
15 changes: 12 additions & 3 deletions src/distilabel/models/llms/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType

if TYPE_CHECKING:
from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList
from llama_cpp import (
CreateChatCompletionResponse,
Llama,
LogitsProcessor,
LogitsProcessorList,
)

from distilabel.steps.tasks.typing import FormattedInput, StandardInput

Expand Down Expand Up @@ -383,7 +388,7 @@ def generate( # type: ignore

def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union["LogitsProcessorList", None]:
) -> Union["LogitsProcessorList", "LogitsProcessor"]:
"""Creates the appropriate function to filter tokens to generate structured outputs.

Args:
Expand All @@ -393,10 +398,14 @@ def _prepare_structured_output(
The callable that will be used to guide the generation of the model.
"""
from distilabel.steps.tasks.structured_outputs.outlines import (
_is_outlines_version_below_0_1_0,
prepare_guided_output,
)

result = prepare_guided_output(structured_output, "llamacpp", self._model)
if (schema := result.get("schema")) and self.structured_output:
self.structured_output["schema"] = schema
return result["processor"]
if _is_outlines_version_below_0_1_0():
return result["processor"]
else:
return [result["processor"]]
8 changes: 5 additions & 3 deletions src/distilabel/models/llms/mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def generate(

def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union[Callable, None]:
) -> Union[List[Callable], Callable]:
"""Creates the appropriate function to filter tokens to generate structured outputs.

Args:
Expand All @@ -276,13 +276,15 @@ def _prepare_structured_output(
Returns:
The callable that will be used to guide the generation of the model.
"""
from outlines.models.mlxlm import TransformerTokenizer

from distilabel.steps.tasks.structured_outputs.outlines import (
prepare_guided_output,
)

result = prepare_guided_output(
structured_output, "transformers", self._pipeline
structured_output, "mlx", TransformerTokenizer(self._tokenizer._tokenizer)
)
if schema := result.get("schema"):
self.structured_output["schema"] = schema
return result["processor"]
return [result["processor"]]
128 changes: 102 additions & 26 deletions src/distilabel/steps/tasks/structured_outputs/outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,39 @@
Literal,
Tuple,
Type,
Union,
get_args,
)

import pkg_resources
from pydantic import BaseModel

from distilabel.errors import DistilabelUserError
from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict

if TYPE_CHECKING:
from llama_cpp import Llama
from outlines.models.mlxlm import TransformerTokenizer
from transformers import Pipeline
from vllm import LLM

from distilabel.steps.tasks.typing import OutlinesStructuredOutputType

Frameworks = Literal["transformers", "llamacpp", "vllm"]
"""Available frameworks for the structured output configuration. """
Frameworks = Literal["transformers", "llamacpp", "vllm", "mlx"]


def _is_outlines_version_below_0_1_0() -> bool:
"""Helper function to check outlines availability and version.

Returns:
bool: True if outlines is not installed or version is below 0.1.0
"""
if not importlib.util.find_spec("outlines"):
raise ImportError(
"Outlines is not installed. Please install it using `pip install outlines`."
)
version = pkg_resources.get_distribution("outlines").version
return pkg_resources.parse_version(version) < pkg_resources.parse_version("0.1.0")


def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]:
Expand All @@ -45,38 +65,89 @@ def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]:


def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:
    """Return the `(json, regex)` logits-processor classes for `framework`.

    The import location depends on the installed `outlines` version: before
    0.1.0 each framework had a dedicated integration module (and the
    `transformers` integration used prefix-allowed-tokens functions rather
    than logits processors), while from 0.1.0 onwards every framework shares
    the unified `outlines.processors` module.

    Args:
        framework: the framework the processors will be used with.

    Raises:
        DistilabelUserError: if `framework` is not one of `Frameworks`.

    Returns:
        A tuple with the JSON and the regex processor classes, in that order.
    """
    if _is_outlines_version_below_0_1_0():
        # Legacy layout: one integration module per framework.
        processors = {
            "transformers": (
                "outlines.integrations.transformers",
                "JSONPrefixAllowedTokens",
                "RegexPrefixAllowedTokens",
            ),
            "llamacpp": (
                "outlines.integrations.llamacpp",
                "JSONLogitsProcessor",
                "RegexLogitsProcessor",
            ),
            "vllm": (
                "outlines.integrations.vllm",
                "JSONLogitsProcessor",
                "RegexLogitsProcessor",
            ),
            "mlx": (
                "outlines.processors.mlxlm",
                "JSONLogitsProcessor",
                "RegexLogitsProcessor",
            ),
        }
    else:
        # From `outlines>=0.1.0` all frameworks use the same pair of classes,
        # so build the mapping once instead of repeating identical entries.
        processors = {
            name: (
                "outlines.processors",
                "JSONLogitsProcessor",
                "RegexLogitsProcessor",
            )
            for name in get_args(Frameworks)
        }

    if framework not in processors:
        raise DistilabelUserError(
            f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}",
            page="sections/how_to_guides/advanced/structured_generation/",
        )

    module_path, json_cls, regex_cls = processors[framework]
    module = importlib.import_module(module_path)
    return getattr(module, json_cls), getattr(module, regex_cls)


def _get_tokenizer_from_model(
    llm: Union["LLM", "Pipeline", "Llama", "TransformerTokenizer"],
    framework: Frameworks,
) -> Callable:
    """Build the `outlines`-compatible tokenizer wrapper for the given model.

    Used with `outlines>=0.1.0`, whose processors are initialized with a
    tokenizer instead of the model itself.

    Args:
        llm: the model the processor will guide (or, for `mlx`, the
            already-adapted `TransformerTokenizer` passed by the caller).
        framework: the framework the tokenizer must be adapted for.

    Raises:
        DistilabelUserError: if `framework` is not one of `Frameworks`.

    Returns:
        The tokenizer wrapper expected by the `outlines` processors.
    """
    if framework == "llamacpp":
        from outlines.models.llamacpp import LlamaCppTokenizer

        return LlamaCppTokenizer(llm)
    if framework == "transformers":
        from outlines.models.transformers import TransformerTokenizer

        return TransformerTokenizer(llm.tokenizer)
    if framework == "vllm":
        from outlines.models.vllm import adapt_tokenizer

        return adapt_tokenizer(llm.get_tokenizer())
    if framework == "mlx":
        # The MLX caller already passes an adapted tokenizer; use it as-is.
        return llm
    # Previously an unknown framework silently fell through and returned
    # `None`; fail loudly instead, mirroring `_get_logits_processor`.
    raise DistilabelUserError(
        f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}",
        page="sections/how_to_guides/advanced/structured_generation/",
    )


def prepare_guided_output(
structured_output: "OutlinesStructuredOutputType",
framework: Frameworks,
llm: Any,
llm: Union["LLM", "Pipeline", "Llama", "TransformerTokenizer"],
) -> Dict[str, Any]:
"""Prepares the `LLM` to generate guided output using `outlines`.

Expand All @@ -97,10 +168,6 @@ def prepare_guided_output(
case of "json" will also include the schema as a dict, to simplify serialization
and deserialization.
"""
if not importlib.util.find_spec("outlines"):
raise ImportError(
"Outlines is not installed. Please install it using `pip install outlines`."
)

json_processor, regex_processor = _get_logits_processor(framework)

Expand All @@ -116,11 +183,20 @@ def prepare_guided_output(
elif isinstance(schema, str):
format = "regex"

if _is_outlines_version_below_0_1_0():
# use the llm for processor initialization
model = llm
tokenizer = None
else:
# use the tokenizer for processor initialization
model = None
tokenizer = _get_tokenizer_from_model(llm, framework)

if format == "json":
return {
"processor": json_processor(
schema,
llm,
model or tokenizer,
whitespace_pattern=structured_output.get("whitespace_pattern"),
),
"schema": schema_as_dict(schema),
Expand Down
14 changes: 9 additions & 5 deletions tests/unit/steps/tasks/structured_outputs/test_outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from distilabel.models.llms.huggingface.transformers import TransformersLLM
from distilabel.steps.tasks.structured_outputs.outlines import (
_is_outlines_version_below_0_1_0,
model_to_schema,
)
from distilabel.steps.tasks.typing import OutlinesStructuredOutputType
Expand Down Expand Up @@ -100,9 +101,6 @@ class DummyUserTest(BaseModel):
}


@pytest.mark.skip(
reason="won't work until we update our code to work with `outlines>0.1.0`"
)
class TestOutlinesIntegration:
@pytest.mark.parametrize(
"format, schema, prompt",
Expand Down Expand Up @@ -138,7 +136,7 @@ def test_generation(
prompt = [
[{"role": "system", "content": ""}, {"role": "user", "content": prompt}]
]
result = llm.generate(prompt, max_new_tokens=30)
result = llm.generate(prompt, max_new_tokens=30, temperature=0.7)
assert isinstance(result, list)
assert isinstance(result[0], dict)
assert "generations" in result[0] and "statistics" in result[0]
Expand Down Expand Up @@ -174,6 +172,7 @@ def test_serialization(
structured_output=OutlinesStructuredOutputType(
format=format, schema=schema
),
token=None,
)
llm.load()
assert llm.dump() == dump
Expand All @@ -182,4 +181,9 @@ def test_load_from_dict(self) -> None:
llm = TransformersLLM.from_dict(DUMP_JSON)
assert isinstance(llm, TransformersLLM)
llm.load()
assert llm._prefix_allowed_tokens_fn is not None
if _is_outlines_version_below_0_1_0():
assert llm._prefix_allowed_tokens_fn is not None
assert llm._logits_processor is None
else:
assert llm._prefix_allowed_tokens_fn is None
assert llm._logits_processor is not None
Loading