Skip to content

Commit

Permalink
add image processor to pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Oct 16, 2023
1 parent 05de3e3 commit 7123af4
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions optimum/pipelines/pipelines_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
AudioClassificationPipeline,
AutoConfig,
AutoFeatureExtractor,
AutoImageProcessor,
AutomaticSpeechRecognitionPipeline,
AutoTokenizer,
FeatureExtractionPipeline,
Expand All @@ -41,7 +42,8 @@
ZeroShotClassificationPipeline,
)
from transformers import pipeline as transformers_pipeline
from transformers.feature_extraction_utils import PreTrainedFeatureExtractor
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.image_processing_utils import ImageProcessingMixin
from transformers.pipelines import SUPPORTED_TASKS as TRANSFORMERS_SUPPORTED_TASKS
from transformers.pipelines import infer_framework_load_model

Expand Down Expand Up @@ -291,7 +293,8 @@ def pipeline(
task: str = None,
model: Optional[Any] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
image_processor: Optional[Union[str, ImageProcessingMixin]] = None,
feature_extractor: Optional[Union[str, FeatureExtractionMixin]] = None,
use_fast: bool = True,
token: Optional[Union[str, bool]] = None,
accelerator: Optional[str] = "ort",
Expand Down Expand Up @@ -328,16 +331,20 @@ def pipeline(

supported_tasks = ORT_SUPPORTED_TASKS if accelerator == "ort" else TRANSFORMERS_SUPPORTED_TASKS

no_feature_extractor_tasks = set()
no_tokenizer_tasks = set()
no_image_processor = set()
no_feature_extractor_tasks = set()

for _task, values in supported_tasks.items():
if values["type"] == "text":
no_image_processor.add(_task)
no_feature_extractor_tasks.add(_task)
elif values["type"] in {"image", "video"}:
no_tokenizer_tasks.add(_task)
elif values["type"] in {"audio"}:
no_tokenizer_tasks.add(_task)
elif values["type"] not in ["multimodal", "audio", "video"]:
no_image_processor.add(_task)
elif values["type"] not in ["multimodal", "image", "audio", "video"]:
raise ValueError(f"SUPPORTED_TASK {_task} contains invalid type {values['type']}")

# copied from transformers.pipelines.__init__.py l.609
Expand All @@ -350,6 +357,11 @@ def pipeline(
else:
load_tokenizer = True

if targeted_task in no_image_processor:
load_image_processor = False
else:
load_image_processor = True

if targeted_task in no_feature_extractor_tasks:
load_feature_extractor = False
else:
Expand All @@ -360,6 +372,8 @@ def pipeline(
targeted_task,
load_tokenizer,
tokenizer,
load_image_processor,
image_processor,
feature_extractor,
load_feature_extractor,
SUPPORTED_TASKS=supported_tasks,
Expand All @@ -374,11 +388,14 @@ def pipeline(
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, **kwargs)
if feature_extractor is None and load_feature_extractor:
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, **kwargs)
if image_processor is None and load_image_processor:
image_processor = AutoImageProcessor.from_pretrained(model_id, **kwargs)

return transformers_pipeline(
task,
model=model,
tokenizer=tokenizer,
image_processor=image_processor,
feature_extractor=feature_extractor,
use_fast=use_fast,
**kwargs,
Expand Down

0 comments on commit 7123af4

Please sign in to comment.