diff --git a/outlines/generate/api.py b/outlines/generate/api.py index dea3108d5..ad01377c0 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -1,6 +1,6 @@ import datetime from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union from outlines.generate.generator import sequence_generator from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler @@ -534,9 +534,8 @@ def stream( class VisionSequenceGeneratorAdapter(SequenceGeneratorAdapter): def __call__( # type: ignore self, - prompts: Union[str, List[str], List[dict], List[List[dict]]], + prompts: Union[str, List[str]], media: Union[str, Any], - apply_chat_template: bool = False, max_tokens: Optional[int] = None, stop_at: Optional[Union[str, List[str]]] = None, seed: Optional[int] = None, @@ -547,17 +546,6 @@ def __call__( # type: ignore Media: A URI to construct media or media object itself. Used as AutoProcessor argument. """ - if apply_chat_template: - # Transform the huggingface conversation object into the string that this - # model expects. - # https://huggingface.co/docs/transformers/main/en/chat_templating - prompts = ( - [self.model.processor.apply_chat_template(p) for p in prompts] - if isinstance(prompts[0], list) - else self.model.processor.apply_chat_template(prompts) - ) - prompts = cast(Union[str, List[str]], prompts) - prompts, media = self._validate_prompt_media_types(prompts, media) generation_params = self.prepare_generation_parameters( @@ -577,26 +565,14 @@ def __call__( # type: ignore def stream( # type: ignore self, - prompts: Union[str, List[str], List[dict], List[List[dict]]], + prompts: Union[str, List[str]], media: List[Union[str, Any, List[Union[str, Any]]]], - apply_chat_template: bool = False, max_tokens: Optional[int] = None, stop_at: Optional[Union[str, List[str]]] = None, seed: Optional[int] = None, **model_specific_params, ): """Return a text generator from a prompt or a list of prompts.""" - if apply_chat_template: - # Transform the huggingface conversation object into the string that this - # model expects. - # https://huggingface.co/docs/transformers/main/en/chat_templating - prompts = ( - [self.model.processor.apply_chat_template(p) for p in prompts] - if isinstance(prompts[0], list) - else self.model.processor.apply_chat_template(prompts) - ) - prompts = cast(Union[str, List[str]], prompts) - prompts, media = self._validate_prompt_media_types(prompts, media) generation_params = self.prepare_generation_parameters( max_tokens, stop_at, seed