Skip to content

Commit

Permalink
add image processor to all pipelines
Browse files (browse the repository at this point in the history)
  • IlyasMoutawwakil committed Oct 16, 2023
1 parent 7123af4 commit 2088580
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions optimum/pipelines/pipelines_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def load_bettertransformer(
tokenizer=None,
feature_extractor=None,
load_feature_extractor=None,
load_image_processor=None,
image_processor=None,
SUPPORTED_TASKS=None,
subfolder: str = "",
token: Optional[Union[bool, str]] = None,
Expand Down Expand Up @@ -209,7 +211,7 @@ def load_bettertransformer(

model = BetterTransformer.transform(model, **kwargs)

return model, model_id, tokenizer, feature_extractor
return model, model_id, tokenizer, feature_extractor, image_processor


def load_ort_pipeline(
Expand All @@ -219,6 +221,8 @@ def load_ort_pipeline(
tokenizer,
feature_extractor,
load_feature_extractor,
load_image_processor,
image_processor,
SUPPORTED_TASKS,
subfolder: str = "",
token: Optional[Union[bool, str]] = None,
Expand Down Expand Up @@ -280,7 +284,7 @@ def load_ort_pipeline(
f"""Model {model} is not supported. Please provide a valid model either as string or ORTModel.
You can also provide non model then a default one will be used"""
)
return model, model_id, tokenizer, feature_extractor
return model, model_id, tokenizer, feature_extractor, image_processor


MAPPING_LOADING_FUNC = {
Expand All @@ -293,8 +297,8 @@ def pipeline(
task: str = None,
model: Optional[Any] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
image_processor: Optional[Union[str, ImageProcessingMixin]] = None,
feature_extractor: Optional[Union[str, FeatureExtractionMixin]] = None,
image_processor: Optional[Union[str, ImageProcessingMixin]] = None,
use_fast: bool = True,
token: Optional[Union[str, bool]] = None,
accelerator: Optional[str] = "ort",
Expand Down Expand Up @@ -332,8 +336,8 @@ def pipeline(
supported_tasks = ORT_SUPPORTED_TASKS if accelerator == "ort" else TRANSFORMERS_SUPPORTED_TASKS

no_tokenizer_tasks = set()
no_image_processor = set()
no_feature_extractor_tasks = set()
no_image_processor = set()

for _task, values in supported_tasks.items():
if values["type"] == "text":
Expand All @@ -357,25 +361,25 @@ def pipeline(
else:
load_tokenizer = True

if targeted_task in no_image_processor:
load_image_processor = False
else:
load_image_processor = True

if targeted_task in no_feature_extractor_tasks:
load_feature_extractor = False
else:
load_feature_extractor = True

model, model_id, tokenizer, feature_extractor = MAPPING_LOADING_FUNC[accelerator](
if targeted_task in no_image_processor:
load_image_processor = False
else:
load_image_processor = True

model, model_id, tokenizer, feature_extractor, image_processor = MAPPING_LOADING_FUNC[accelerator](
model,
targeted_task,
load_tokenizer,
tokenizer,
load_image_processor,
image_processor,
feature_extractor,
load_feature_extractor,
load_image_processor,
image_processor,
SUPPORTED_TASKS=supported_tasks,
config=config,
hub_kwargs=hub_kwargs,
Expand All @@ -395,8 +399,8 @@ def pipeline(
task,
model=model,
tokenizer=tokenizer,
image_processor=image_processor,
feature_extractor=feature_extractor,
image_processor=image_processor,
use_fast=use_fast,
**kwargs,
)

0 comments on commit 2088580

Please sign in to comment.