diff --git a/.gitignore b/.gitignore
index d8337200a..1aab313fb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -77,3 +77,4 @@ venv.bak/
 # Other
 *.log
 *.swp
+.DS_Store
diff --git a/src/distilabel/models/llms/huggingface/transformers.py b/src/distilabel/models/llms/huggingface/transformers.py
index 69f3d02a2..19dc32dd2 100644
--- a/src/distilabel/models/llms/huggingface/transformers.py
+++ b/src/distilabel/models/llms/huggingface/transformers.py
@@ -23,6 +23,9 @@
 from distilabel.models.llms.utils import compute_tokens, prepare_output
 from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
 from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
+from distilabel.steps.tasks.structured_outputs.outlines import (
+    _is_outlines_version_below_0_1_0,
+)
 from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput
 from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR

@@ -111,6 +114,7 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):

     _pipeline: Optional["Pipeline"] = PrivateAttr(...)
     _prefix_allowed_tokens_fn: Union[Callable, None] = PrivateAttr(default=None)
+    _logits_processor: Union[Callable, None] = PrivateAttr(default=None)

     def load(self) -> None:
         """Loads the model and tokenizer and creates the text generation pipeline. In addition,
@@ -149,9 +153,11 @@ def load(self) -> None:
             self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token  # type: ignore

         if self.structured_output:
-            self._prefix_allowed_tokens_fn = self._prepare_structured_output(
-                self.structured_output
-            )
+            processor = self._prepare_structured_output(self.structured_output)
+            if _is_outlines_version_below_0_1_0():
+                self._prefix_allowed_tokens_fn = processor
+            else:
+                self._logits_processor = [processor]

         super().load()

@@ -232,7 +238,8 @@ def generate(  # type: ignore
             do_sample=do_sample,
             num_return_sequences=num_generations,
             prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
-            pad_token_id=self._pipeline.tokenizer.eos_token_id,  # type: ignore
+            pad_token_id=self._pipeline.tokenizer.eos_token_id,
+            logits_processor=self._logits_processor,
         )
         llm_output = [
             [generation["generated_text"] for generation in output]
@@ -292,7 +299,7 @@ def get_last_hidden_states(

     def _prepare_structured_output(
         self, structured_output: Optional[OutlinesStructuredOutputType] = None
-    ) -> Union[Callable, None]:
+    ) -> Union[Callable, List[Callable]]:
         """Creates the appropriate function to filter tokens to generate structured outputs.

         Args:
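Note on the `TransformersLLM` hunks above: outlines releases below 0.1.0 integrate with transformers via prefix-allowed-tokens functions, while 0.1.0 and later ship logits processors that the transformers pipeline expects inside a list. A minimal standalone sketch of that dispatch, with a hypothetical helper name and return shape (only the if/else routing mirrors the diff):

    from typing import Any, Callable, Dict


    def route_structured_output(
        processor: Callable, outlines_below_0_1_0: bool
    ) -> Dict[str, Any]:
        """Pick the pipeline kwarg supported by the installed outlines version."""
        if outlines_below_0_1_0:
            # outlines < 0.1.0: transformers integration is a prefix-allowed-tokens fn.
            return {"prefix_allowed_tokens_fn": processor, "logits_processor": None}
        # outlines >= 0.1.0: a logits processor, passed to the pipeline in a list.
        return {"prefix_allowed_tokens_fn": None, "logits_processor": [processor]}

Both keyword arguments are then forwarded unconditionally in `generate`, which is safe because the unused one is always `None`.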
diff --git a/src/distilabel/models/llms/llamacpp.py b/src/distilabel/models/llms/llamacpp.py
index 822e5cea7..a754f6b84 100644
--- a/src/distilabel/models/llms/llamacpp.py
+++ b/src/distilabel/models/llms/llamacpp.py
@@ -24,7 +24,12 @@
 from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType

 if TYPE_CHECKING:
-    from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList
+    from llama_cpp import (
+        CreateChatCompletionResponse,
+        Llama,
+        LogitsProcessor,
+        LogitsProcessorList,
+    )

     from distilabel.steps.tasks.typing import FormattedInput, StandardInput

@@ -383,7 +388,7 @@ def generate(  # type: ignore

     def _prepare_structured_output(
         self, structured_output: Optional[OutlinesStructuredOutputType] = None
-    ) -> Union["LogitsProcessorList", None]:
+    ) -> Union["LogitsProcessorList", "LogitsProcessor"]:
         """Creates the appropriate function to filter tokens to generate structured outputs.

         Args:
@@ -399,4 +404,4 @@ def _prepare_structured_output(
         result = prepare_guided_output(structured_output, "llamacpp", self._model)
         if (schema := result.get("schema")) and self.structured_output:
             self.structured_output["schema"] = schema
-        return result["processor"]
+        return [result["processor"]]
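Note on the llama.cpp return change: wrapping the single processor in a list works because `llama_cpp.LogitsProcessorList` is a thin `list` subclass, so a plain Python list is accepted wherever one is expected. A small sketch of the processor shape involved (assumes `llama-cpp-python` and `numpy` are installed; the no-op processor is illustrative):

    import numpy as np
    from llama_cpp import LogitsProcessorList


    def noop_processor(input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        # llama.cpp calls each processor per decoding step with (input_ids, scores)
        # and expects the (possibly modified) scores array back.
        return scores


    processors = LogitsProcessorList([noop_processor])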
""" - logits_processors = [] - if self._structured_output_logits_processor: - logits_processors.append(self._structured_output_logits_processor) - structured_output = None result = [] for input in inputs: @@ -219,13 +197,9 @@ def generate( output: List[str] = [] for _ in range(num_generations): - if structured_output: - additional_logits_processors = self._prepare_structured_output( - structured_output - ) - logits_processors.append(additional_logits_processors) + if structured_output: # will raise a NotImplementedError + self._prepare_structured_output(structured_output) prompt = self.prepare_input(input) - generation = self._mlx_generate( prompt=prompt, model=self._model, @@ -264,25 +238,3 @@ def generate( ) ) return result - - def _prepare_structured_output( - self, structured_output: Optional[OutlinesStructuredOutputType] = None - ) -> Union[Callable, None]: - """Creates the appropriate function to filter tokens to generate structured outputs. - - Args: - structured_output: the configuration dict to prepare the structured output. - - Returns: - The callable that will be used to guide the generation of the model. - """ - from distilabel.steps.tasks.structured_outputs.outlines import ( - prepare_guided_output, - ) - - result = prepare_guided_output( - structured_output, "transformers", self._pipeline - ) - if schema := result.get("schema"): - self.structured_output["schema"] = schema - return result["processor"] diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py index b8ac03641..a5aceacb3 100644 --- a/src/distilabel/steps/tasks/structured_outputs/outlines.py +++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py @@ -24,19 +24,38 @@ Literal, Tuple, Type, + Union, get_args, ) +import pkg_resources from pydantic import BaseModel from distilabel.errors import DistilabelUserError from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict -if TYPE_CHECKING: - from distilabel.steps.tasks.typing import OutlinesStructuredOutputType +if TYPE_CHECKING: # noqa + from llama_cpp import Llama # noqa + from transformers import Pipeline # noqa + from vllm import LLM as _vLLM # noqa + + from distilabel.steps.tasks.typing import OutlinesStructuredOutputType # noqa Frameworks = Literal["transformers", "llamacpp", "vllm"] -"""Available frameworks for the structured output configuration. """ + + +def _is_outlines_version_below_0_1_0() -> bool: + """Helper function to check outlines availability and version. + + Returns: + bool: True if outlines is not installed or version is below 0.1.0 + """ + if not importlib.util.find_spec("outlines"): + raise ImportError( + "Outlines is not installed. Please install it using `pip install outlines`." 
diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py
index b8ac03641..a5aceacb3 100644
--- a/src/distilabel/steps/tasks/structured_outputs/outlines.py
+++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py
@@ -24,19 +24,38 @@
     Literal,
     Tuple,
     Type,
+    Union,
     get_args,
 )

+import pkg_resources
 from pydantic import BaseModel

 from distilabel.errors import DistilabelUserError
 from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict

-if TYPE_CHECKING:
-    from distilabel.steps.tasks.typing import OutlinesStructuredOutputType
+if TYPE_CHECKING:  # noqa
+    from llama_cpp import Llama  # noqa
+    from transformers import Pipeline  # noqa
+    from vllm import LLM as _vLLM  # noqa
+
+    from distilabel.steps.tasks.typing import OutlinesStructuredOutputType  # noqa

 Frameworks = Literal["transformers", "llamacpp", "vllm"]
-"""Available frameworks for the structured output configuration. """
+
+
+def _is_outlines_version_below_0_1_0() -> bool:
+    """Helper function to check outlines availability and version.
+
+    Returns:
+        bool: `True` if the installed `outlines` version is below 0.1.0 (raises `ImportError` if it is not installed).
+    """
+    if not importlib.util.find_spec("outlines"):
+        raise ImportError(
+            "Outlines is not installed. Please install it using `pip install outlines`."
+        )
+    version = pkg_resources.get_distribution("outlines").version
+    return pkg_resources.parse_version(version) < pkg_resources.parse_version("0.1.0")


 def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]:
@@ -45,38 +64,77 @@ def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]:

 def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:
-    """Helper function to return the appropriate logits processor for the given framework."""
-    if framework == "transformers":
-        from outlines.integrations.transformers import (
-            JSONPrefixAllowedTokens,
-            RegexPrefixAllowedTokens,
+    """Helper function to return the appropriate logits processors for the given framework."""
+    if _is_outlines_version_below_0_1_0():
+        processors = {
+            "transformers": (
+                "outlines.integrations.transformers",
+                "JSONPrefixAllowedTokens",
+                "RegexPrefixAllowedTokens",
+            ),
+            "llamacpp": (
+                "outlines.integrations.llamacpp",
+                "JSONLogitsProcessor",
+                "RegexLogitsProcessor",
+            ),
+            "vllm": (
+                "outlines.integrations.vllm",
+                "JSONLogitsProcessor",
+                "RegexLogitsProcessor",
+            ),
+        }
+    else:
+        processors = {
+            "transformers": (
+                "outlines.processors",
+                "JSONLogitsProcessor",
+                "RegexLogitsProcessor",
+            ),
+            "llamacpp": (
+                "outlines.processors",
+                "JSONLogitsProcessor",
+                "RegexLogitsProcessor",
+            ),
+            "vllm": (
+                "outlines.processors",
+                "JSONLogitsProcessor",
+                "RegexLogitsProcessor",
+            ),
+        }
+
+    if framework not in processors:
+        raise DistilabelUserError(
+            f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}",
+            page="sections/how_to_guides/advanced/structured_generation/",
         )
-        return JSONPrefixAllowedTokens, RegexPrefixAllowedTokens
+
+    module_path, json_cls, regex_cls = processors[framework]
+    module = importlib.import_module(module_path)
+    return getattr(module, json_cls), getattr(module, regex_cls)
+

+def _get_tokenizer_from_model(
+    llm: Union["_vLLM", "Pipeline", "Llama"],
+    framework: Frameworks,
+) -> Callable:
     if framework == "llamacpp":
-        from outlines.integrations.llamacpp import (
-            JSONLogitsProcessor,
-            RegexLogitsProcessor,
-        )
+        from outlines.models.llamacpp import LlamaCppTokenizer

-        return JSONLogitsProcessor, RegexLogitsProcessor
+        return LlamaCppTokenizer(llm)
+    if framework == "transformers":
+        from outlines.models.transformers import TransformerTokenizer

+        return TransformerTokenizer(llm.tokenizer)
     if framework == "vllm":
-        from outlines.integrations.vllm import JSONLogitsProcessor, RegexLogitsProcessor
+        from outlines.models.vllm import adapt_tokenizer

-        return JSONLogitsProcessor, RegexLogitsProcessor
-
-    raise DistilabelUserError(
-        f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}",
-        page="sections/how_to_guides/advanced/structured_generation/",
-    )
+        return adapt_tokenizer(llm.get_tokenizer())


 def prepare_guided_output(
     structured_output: "OutlinesStructuredOutputType",
     framework: Frameworks,
-    llm: Any,
+    llm: Union["_vLLM", "Pipeline", "Llama"],
 ) -> Dict[str, Any]:
     """Prepares the `LLM` to generate guided output using `outlines`.
@@ -97,10 +155,6 @@ def prepare_guided_output(
             case of "json" will also include the schema as a dict, to simplify serialization
             and deserialization.
     """
-    if not importlib.util.find_spec("outlines"):
-        raise ImportError(
-            "Outlines is not installed. Please install it using `pip install 'distilabel[outlines]'`."
-        )

     json_processor, regex_processor = _get_logits_processor(framework)
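Note on the new version gate: `pkg_resources` is deprecated in recent setuptools, so an equivalent check could use the standard library instead. A sketch under that assumption (an alternative, not what this diff ships):

    from importlib.metadata import PackageNotFoundError, version


    def outlines_below_0_1_0() -> bool:
        try:
            installed = version("outlines")
        except PackageNotFoundError:
            raise ImportError(
                "Outlines is not installed. Please install it using `pip install outlines`."
            )
        # Naive numeric comparison; adequate for plain "X.Y.Z" version strings.
        return tuple(int(p) for p in installed.split(".")[:3]) < (0, 1, 0)

Either way, the gate is only evaluated inside `_get_logits_processor` and `prepare_guided_output`, so importing this module does not itself require outlines to be installed.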
@@ -116,18 +170,27 @@ def prepare_guided_output(
     elif isinstance(schema, str):
         format = "regex"

+    if _is_outlines_version_below_0_1_0():
+        # use the llm for processor initialization
+        model = llm
+        tokenizer = None
+    else:
+        # use the tokenizer for processor initialization
+        model = None
+        tokenizer = _get_tokenizer_from_model(llm, framework)
+
     if format == "json":
         return {
             "processor": json_processor(
                 schema,
-                llm,
+                model or tokenizer,
                 whitespace_pattern=structured_output.get("whitespace_pattern"),
             ),
             "schema": schema_as_dict(schema),
         }

     if format == "regex":
-        return {"processor": regex_processor(schema, llm)}
+        return {"processor": regex_processor(schema, model or tokenizer)}

     raise DistilabelUserError(
         f"Invalid format '{format}'. Must be either 'json' or 'regex'.",
diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py
index e4eb2025c..2812c2e48 100644
--- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py
+++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py
@@ -19,6 +19,7 @@
 from distilabel.models.llms.huggingface.transformers import TransformersLLM
 from distilabel.steps.tasks.structured_outputs.outlines import (
+    _is_outlines_version_below_0_1_0,
     model_to_schema,
 )
 from distilabel.steps.tasks.typing import OutlinesStructuredOutputType
@@ -100,9 +101,6 @@ class DummyUserTest(BaseModel):
 }


-@pytest.mark.skip(
-    reason="won't work until we update our code to work with `outlines>0.1.0`"
-)
 class TestOutlinesIntegration:
     @pytest.mark.parametrize(
         "format, schema, prompt",
@@ -138,7 +136,7 @@ def test_generation(
         prompt = [
             [{"role": "system", "content": ""}, {"role": "user", "content": prompt}]
         ]
-        result = llm.generate(prompt, max_new_tokens=30)
+        result = llm.generate(prompt, max_new_tokens=30, temperature=0.7)
         assert isinstance(result, list)
         assert isinstance(result[0], dict)
         assert "generations" in result[0] and "statistics" in result[0]
@@ -174,6 +172,7 @@ def test_serialization(
             structured_output=OutlinesStructuredOutputType(
                 format=format, schema=schema
             ),
+            token=None,
         )
         llm.load()
         assert llm.dump() == dump
@@ -182,4 +181,9 @@ def test_load_from_dict(self) -> None:
         llm = TransformersLLM.from_dict(DUMP_JSON)
         assert isinstance(llm, TransformersLLM)
         llm.load()
-        assert llm._prefix_allowed_tokens_fn is not None
+        if _is_outlines_version_below_0_1_0():
+            assert llm._prefix_allowed_tokens_fn is not None
+            assert llm._logits_processor is None
+        else:
+            assert llm._prefix_allowed_tokens_fn is None
+            assert llm._logits_processor is not None
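With the skip marker gone, `TestOutlinesIntegration` now runs against whichever outlines version is installed, and `test_load_from_dict` asserts that exactly one of the two private attributes is populated. A quick local smoke test along the same lines (the checkpoint name is illustrative; any small transformers model works, and outlines must be installed):

    from distilabel.models.llms.huggingface.transformers import TransformersLLM
    from distilabel.steps.tasks.structured_outputs.outlines import (
        _is_outlines_version_below_0_1_0,
    )

    llm = TransformersLLM(
        model="HuggingFaceTB/SmolLM2-135M-Instruct",  # illustrative checkpoint
        structured_output={"format": "regex", "schema": r"\d{4}"},
    )
    llm.load()

    # Exactly one attribute is set, depending on the installed outlines version.
    if _is_outlines_version_below_0_1_0():
        assert llm._prefix_allowed_tokens_fn is not None
    else:
        assert llm._logits_processor is not None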