diff --git a/optimum/commands/__init__.py b/optimum/commands/__init__.py index 8a2a276d1c5..a31344ed133 100644 --- a/optimum/commands/__init__.py +++ b/optimum/commands/__init__.py @@ -14,5 +14,5 @@ from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand from .env import EnvironmentCommand -from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand +from .export import ExecuTorchExportCommand, ExportCommand, ONNXExportCommand, TFLiteExportCommand from .optimum_cli import optimum_cli_subcommand diff --git a/optimum/commands/export/__init__.py b/optimum/commands/export/__init__.py index 19da68a60d2..b72cd5dbc8d 100644 --- a/optimum/commands/export/__init__.py +++ b/optimum/commands/export/__init__.py @@ -14,5 +14,6 @@ from .base import ExportCommand +from .executorch import ExecuTorchExportCommand from .onnx import ONNXExportCommand from .tflite import TFLiteExportCommand diff --git a/optimum/commands/export/base.py b/optimum/commands/export/base.py index 07737cb8eaf..e5ed4c90ff5 100644 --- a/optimum/commands/export/base.py +++ b/optimum/commands/export/base.py @@ -15,6 +15,7 @@ """optimum.exporters command-line interface base classes.""" from .. import BaseOptimumCLICommand, CommandInfo +from .executorch import ExecuTorchExportCommand from .onnx import ONNXExportCommand from .tflite import TFLiteExportCommand @@ -25,6 +26,11 @@ class ExportCommand(BaseOptimumCLICommand): help="Export PyTorch and TensorFlow models to several format.", ) SUBCOMMANDS = ( + CommandInfo( + name="executorch", + help="Export PyTorch model to ExecuTorch.", + subcommand_class=ExecuTorchExportCommand, + ), CommandInfo( name="onnx", help="Export PyTorch and TensorFlow to ONNX.", diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py new file mode 100644 index 00000000000..adb5deae3e6 --- /dev/null +++ b/optimum/commands/export/executorch.py @@ -0,0 +1,53 @@ +"""Defines the command line for the export with ExecuTorch.""" + +from pathlib import Path +from typing import TYPE_CHECKING + +from ...exporters import TasksManager +from ..base import BaseOptimumCLICommand + + +if TYPE_CHECKING: + from argparse import ArgumentParser + + +def parse_args_executorch(parser): + required_group = parser.add_argument_group("Required arguments") + required_group.add_argument( + "-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from." + ) + required_group.add_argument( + "--output_dir", type=Path, help="Path indicating the directory where to store the generated ExecuTorch model." + ) + + optional_group = parser.add_argument_group("Optional arguments") + optional_group.add_argument( + "--task", + default="auto", + help=( + "The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among:" + f" {str(TasksManager.get_all_tasks())}. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder." + ), + ) + optional_group.add_argument( + "--recipe", + type=str, + default="xnnpack", + help='Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".', + ) + + +class ExecuTorchExportCommand(BaseOptimumCLICommand): + @staticmethod + def parse_args(parser: "ArgumentParser"): + return parse_args_executorch(parser) + + def run(self): + from ...exporters.executorch import main_export + + main_export( + model_name_or_path=self.args.model, + task=self.args.task, + recipe=self.args.recipe, + output_dir=self.args.output_dir, + ) diff --git a/optimum/executorchruntime/__init__.py b/optimum/executorchruntime/__init__.py new file mode 100644 index 00000000000..3b9eec7fed5 --- /dev/null +++ b/optimum/executorchruntime/__init__.py @@ -0,0 +1,16 @@ +from typing import TYPE_CHECKING +from transformers.utils import _LazyModule + + +_import_structure = { + "modeling_executorch": [ + "ExecuTorchModelForCausalLM", + ], +} + +if TYPE_CHECKING: + from .modeling_executorch import ExecuTorchModelForCausalLM +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/optimum/executorchruntime/modeling_executorch.py b/optimum/executorchruntime/modeling_executorch.py new file mode 100644 index 00000000000..9c2001e91e7 --- /dev/null +++ b/optimum/executorchruntime/modeling_executorch.py @@ -0,0 +1,234 @@ +"""ExecuTorchModelForXXX classes, allowing to run ExecuTorch Models with ExecuTorch Runtime using the same API as Transformers.""" + +import logging +import os +import warnings +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union + +import torch +from executorch.extension.pybindings.portable_lib import _load_for_executorch +from huggingface_hub import hf_hub_download +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from huggingface_hub.utils import EntryNotFoundError +from transformers import ( + AutoConfig, + AutoModel, + GenerationMixin, + AutoModelForCausalLM, + GenerationConfig, +) +from transformers.integrations.executorch import TorchExportableModuleWithStaticCache +from transformers.modeling_outputs import ( + BaseModelOutput, + CausalLMOutput, + CausalLMOutputWithPast, + ModelOutput, +) + +from ..exporters import TasksManager +from ..exporters.executorch import main_export +from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + +logger = logging.getLogger(__name__) + + +class ExecuTorchModelForCausalLM(OptimizedModel): + """ + ExecuTorch model with a causal language modeling head for ExecuTorch Runtime inference. + """ + + auto_model_class = AutoModelForCausalLM + + def __init__( + self, + model: "ExecuTorchModule", + config: "PretrainedConfig", + ): + super().__init__(model, config) + self.et_model = model + print(f"DEBUG all static methods: {self.et_model.method_names()}") + self.use_kv_cache = self.et_model.run_method("use_kv_cache")[0] + self.max_seq_len = self.et_model.run_method("get_max_seq_len")[0] + self.max_batch_size = self.et_model.run_method("get_max_batch_size")[0] + self.dtype = self.et_model.run_method("get_dtype")[0] + self.bos_token_id = self.et_model.run_method("get_bos_id")[0] + self.eos_token_id = self.et_model.run_method("get_eos_id")[0] + self.vocab_size = self.et_model.run_method("get_vocab_size")[0] + + def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor: + return self.et_model.forward((input_ids, cache_position))[0] + + @classmethod + def from_pretrained( + cls, + model_dir_path: Union[str, Path], + task: str, + recipe: str, + config: "PretrainedConfig" = None, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + subfolder: str = "", + local_files_only: bool = False, + ) -> "ExecuTorchModelForCausalLM": + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + full_path = os.path.join(f"{model_dir_path}", "model.pte") + model = _load_for_executorch(full_path) + logging.debug(f"{model.method_meta('forward')}") + return cls( + model=model, + config=config, + ) + + def _save_pretrained(self, save_directory): + """ + Saves a model weights into a directory, so that it can be re-loaded using the + [`from_pretrained`] class method. + """ + raise NotImplementedError + + @classmethod + def _export( + cls, + model_id: str, + task: str, + recipe: str, + config: "PretrainedConfig", + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + subfolder: str = "", + local_files_only: bool = False, + trust_remote_code: bool = False, + ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + + # Export to ExecuTorch and save the pte file to the temporary directory + main_export( + model_name_or_path=model_id, + output=save_dir_path, + task=task, + recipe=recipe, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + ) + + return cls._from_pretrained( + model_dir_path=save_dir_path, + task=task, + recipe=recipe, + config=config, + use_auth_token=use_auth_token, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + ) + + def generate( + self, + prompt_tokens: List[int], + echo: bool = False, + pos_base: int = 0, + ) -> List[int]: + + self.device = torch.device("cpu") + self.max_seq_len = 256 + generated_tokens = [] + + # prefill + for i, prompt_token in enumerate(prompt_tokens): + logits = self.forward( + input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0), + cache_position=torch.tensor([i], dtype=torch.long, device=self.device), + ) + + next_token = torch.argmax(logits, dim=-1).item() + generated_tokens = prompt_tokens + [next_token] + + while len(generated_tokens) < self.max_seq_len: + logits = self.forward( + input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0), + cache_position=torch.tensor( + [pos_base + len(generated_tokens) - 1], + dtype=torch.long, + device=self.device, + ), + ) + next_token = torch.argmax(logits, dim=-1).item() + generated_tokens.append(next_token) + if next_token == self.eos_token_id: + break + + return generated_tokens if echo else generated_tokens[len(prompt_tokens) :] + + def text_generation( + self, + tokenizer: "PreTrainedTokenizer", + prompt: str, + echo: bool = True, + ) -> List[int]: + """ + Perform text completion for a prompt using the language model. + + Args: + prompt (str): Text prompt for completion. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Generated list of tokens. + + Note: + This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness. + """ + self.tokenizer = tokenizer + if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id: + raise ValueError( + f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}." + ) + if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id != self.eos_token_id: + raise ValueError( + f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be the same as the model's eos_token_id={self.eos_token_id}." + ) + + prompt_tokens = self.tokenizer.encode(prompt) + generated_tokens = self.generate( + prompt_tokens=prompt_tokens, + echo=echo, + ) + return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) diff --git a/optimum/exporters/executorch/__init__.py b/optimum/exporters/executorch/__init__.py new file mode 100644 index 00000000000..90e228392eb --- /dev/null +++ b/optimum/exporters/executorch/__init__.py @@ -0,0 +1,24 @@ +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "convert": [ + "export_to_executorch", + ], + "__main__": ["main_export"], +} + +if TYPE_CHECKING: + from .__main__ import main_export + from .convert import export_to_executorch +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/optimum/exporters/executorch/__main__.py b/optimum/exporters/executorch/__main__.py new file mode 100644 index 00000000000..5c6f21258e0 --- /dev/null +++ b/optimum/exporters/executorch/__main__.py @@ -0,0 +1,138 @@ +"""Entry point to the optimum.exporters.executorch command line.""" + +import argparse +import os +import warnings +from pathlib import Path + +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from requests.exceptions import ConnectionError as RequestsConnectionError +from transformers import AutoConfig, AutoTokenizer +from transformers.utils import is_torch_available + +from ...commands.export.executorch import parse_args_executorch +from ...utils import logging +from ..tasks import TasksManager +from .causal_lm import * +from .convert import export_to_executorch +from .task_registry import task_registry + +if is_torch_available(): + import torch + +from typing import Optional, Union + + +logger = logging.get_logger() + + +def main_export( + model_name_or_path: str, + task: str, + recipe: str, + output_dir: Union[str, Path], + cache_dir: str = HUGGINGFACE_HUB_CACHE, + trust_remote_code: bool = False, + pad_token_id: Optional[int] = None, + subfolder: str = "", + revision: str = "main", + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + **kwargs, +): + """ + Full-suite ExecuTorch export function, exporting **from a model ID on Hugging Face Hub or a local model repository**. + + Args: + model_name_or_path (`str`): + Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="google/gemma-2b"` or `mode_name_or_path="/path/to/model_folder`. + task (`str`): + The task to export the model for, e.g. "text-generation". + recipe (`str`): + The recipe to use to do the export, e.g. "xnnpack". + output_dir (`Union[str, Path]`): + Path indicating the directory where to store the generated ExecuTorch model. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + trust_remote_code (`bool`, defaults to `False`): + Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories + you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the + model repository. + pad_token_id (`Optional[int]`, defaults to `None`): + This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + + Example usage: + ```python + >>> from optimum.exporters.executorch import main_export + + >>> main_export("gemma-2b", output="gemma-2b_onnx/") + ``` + """ + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + print(f"DEBUG {task_registry}") + model = task_registry.get(task)(model_name_or_path, **kwargs) + + if task == "text-generation": + from transformers.integrations.executorch import TorchExportableModuleWithStaticCache + + model = TorchExportableModuleWithStaticCache(model) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + return export_to_executorch( + model=model, + task=task, + recipe=recipe, + output_dir=output_dir, + **kwargs, + ) + + +def main(): + parser = argparse.ArgumentParser("Hugging Face Optimum ExecuTorch exporter") + + parse_args_executorch(parser) + + # Retrieve CLI arguments + args = parser.parse_args() + + main_export( + model_name_or_path=args.model, + output_dir=args.output_dir, + task=args.task, + recipe=args.recipe, + cache_dir=args.cache_dir, + trust_remote_code=args.trust_remote_code, + pad_token_id=args.pad_token_id, + ) + + +if __name__ == "__main__": + main() diff --git a/optimum/exporters/executorch/causal_lm.py b/optimum/exporters/executorch/causal_lm.py new file mode 100644 index 00000000000..174babbac41 --- /dev/null +++ b/optimum/exporters/executorch/causal_lm.py @@ -0,0 +1,32 @@ +from transformers import AutoModelForCausalLM, GenerationConfig + +from .task_registry import register_task + + +@register_task("text-generation") +def load_causal_lm_model(model_name_or_path: str, **kwargs): + device = "cpu" + batch_size = 1 + dtype = kwargs.get("dtype", "float32") + cache_implementation = kwargs.get("cache_implementation", "static") + attn_implementation = kwargs.get("attn_implementation", "sdpa") + max_length = kwargs.get("max_length", 256) + print( + f"DEBUG: dtype={dtype}, max_length={max_length}, attn_implementation={attn_implementation}, cache_implementation={cache_implementation}" + ) + + return AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_length, + }, + ), + ) diff --git a/optimum/exporters/executorch/convert.py b/optimum/exporters/executorch/convert.py new file mode 100644 index 00000000000..2d2f7907e06 --- /dev/null +++ b/optimum/exporters/executorch/convert.py @@ -0,0 +1,58 @@ +"""ExecuTorch model check and export functions.""" + +import os +from pathlib import Path +from typing import Union + +from transformers.utils import is_torch_available + +from ...utils import ( + TORCH_MINIMUM_VERSION, + check_if_transformers_greater, + is_diffusers_available, + logging, +) +from ..tasks import TasksManager +from .recipe_registry import recipe_registry +from .xnnpack import * + + +if is_torch_available(): + import torch + from transformers.modeling_utils import PreTrainedModel + + +def export_to_executorch( + model: Union["PreTrainedModel", "TorchExportableModuleWithStaticCache"], + task: str, + recipe: str, + output_dir: Union[str, Path], + **kwargs, +): + """ + Full-suite ExecuTorch export function, exporting **from a pre-trained PyTorch model**. This function is especially useful in case one needs to do modifications on the model, as overriding a forward call, before exporting to ExecuTorch. + + Args: + model (`Union["PreTrainedModel", "TorchExportableModuleWithStaticCache"]`): + PyTorch model to export to ExecuTorch. + task (`str`): + The task to export the model for, e.g. "text-generation". + recipe (`str`): + The recipe to use to do the export, e.g. "xnnpack". + output_dir (`Union[str, Path]`): + Path indicating the directory where to store the generated ExecuTorch model. + """ + try: + recipe_func = recipe_registry.get(recipe) + except KeyError as e: + raise RuntimeError(f"The recipe '{recipe}' isn't registered. Detailed error: {e}") + + executorch_prog = recipe_func(model, task, **kwargs) + # print(f"Exported program: {executorch_prog.exported_program().graph}") + + full_path = os.path.join(f"{output_dir}", "model.pte") + with open(full_path, "wb") as f: + executorch_prog.write_to_file(f) + print(f"Saved exported program to {full_path}") + + return executorch_prog diff --git a/optimum/exporters/executorch/recipe_registry.py b/optimum/exporters/executorch/recipe_registry.py new file mode 100644 index 00000000000..5aad6454267 --- /dev/null +++ b/optimum/exporters/executorch/recipe_registry.py @@ -0,0 +1,9 @@ +recipe_registry = {} + + +def register_recipe(recipe_name): + def decorator(func): + recipe_registry[recipe_name] = func + return func + + return decorator diff --git a/optimum/exporters/executorch/task_registry.py b/optimum/exporters/executorch/task_registry.py new file mode 100644 index 00000000000..8f547a573f8 --- /dev/null +++ b/optimum/exporters/executorch/task_registry.py @@ -0,0 +1,9 @@ +task_registry = {} + + +def register_task(recipe_name): + def decorator(func): + task_registry[recipe_name] = func + return func + + return decorator diff --git a/optimum/exporters/executorch/xnnpack.py b/optimum/exporters/executorch/xnnpack.py new file mode 100644 index 00000000000..32f56522dc3 --- /dev/null +++ b/optimum/exporters/executorch/xnnpack.py @@ -0,0 +1,69 @@ +from typing import Union + +import torch +import torch.export._trace +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from torch.nn.attention import SDPBackend +from transformers import PreTrainedModel, TorchExportableModuleWithStaticCache + +from .recipe_registry import register_recipe + + +@register_recipe("xnnpack") +def export_to_executorch_with_xnnpack( + model: Union[PreTrainedModel, TorchExportableModuleWithStaticCache], + task: str, + **kwargs, +): + print(f"DEBUG: model={model}, task={task}, kwargs={kwargs}") + metadata = {} + if task == "text-generation": + example_input_ids = torch.tensor([[1]], dtype=torch.long) + example_cache_position = torch.tensor([0], dtype=torch.long) + + def _get_constant_methods(model: PreTrainedModel): + metadata = { + "get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6, + "get_bos_id": model.config.bos_token_id, + "get_eos_id": model.config.eos_token_id, + "get_head_dim": model.config.hidden_size / model.config.num_attention_heads, + "get_max_batch_size": model.generation_config.cache_config.batch_size, + "get_max_seq_len": model.generation_config.cache_config.max_cache_len, + "get_n_kv_heads": model.config.num_key_value_heads, + "get_n_layers": model.config.num_hidden_layers, + "get_vocab_size": model.config.vocab_size, + "use_kv_cache": model.generation_config.use_cache, + } + return {k: v for k, v in metadata.items() if v is not None} + + metadata = _get_constant_methods(model if isinstance(model, PreTrainedModel) else model.model) + else: + # TODO: Prepare model inputs for other tasks + raise ValueError(f"Unsupported task '{task}'.") + + with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): + exported_program = torch.export._trace._export( + model, + args=(example_input_ids,), + kwargs={"cache_position": example_cache_position}, + pre_dispatch=False, + strict=True, + ) + + return to_edge_transform_and_lower( + exported_program, + partitioner=[XnnpackPartitioner()], + compile_config=EdgeCompileConfig( + _skip_dim_order=True, + ), + constant_methods=metadata, + ).to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + ), + ) diff --git a/optimum/pipelines/pipelines_base.py b/optimum/pipelines/pipelines_base.py index 7690143f13f..1e5ad6f2c3c 100644 --- a/optimum/pipelines/pipelines_base.py +++ b/optimum/pipelines/pipelines_base.py @@ -293,9 +293,28 @@ def load_ort_pipeline( return model, model_id, tokenizer, feature_extractor +def load_executorch_pipeline( + model, + targeted_task, + load_tokenizer, + tokenizer, + feature_extractor, + load_feature_extractor, + SUPPORTED_TASKS, + subfolder: str = "", + token: Optional[Union[bool, str]] = None, + revision: str = "main", + model_kwargs: Optional[Dict[str, Any]] = None, + config: AutoConfig = None, + **kwargs, +): + raise NotImplementedError("Executorch pipeline is not implemented yet.") + + MAPPING_LOADING_FUNC = { "ort": load_ort_pipeline, "bettertransformer": load_bettertransformer, + "executorch": load_executorch_pipeline, } diff --git a/setup.py b/setup.py index 82892bfcc8c..460a51816a2 100644 --- a/setup.py +++ b/setup.py @@ -88,6 +88,10 @@ "datasets<=2.16", "transformers>=4.26,<4.38", ], + "exporters-executorch": [ + "executorch>=0.4.0", + "transformers>=4.46", + ], "diffusers": ["diffusers"], "intel": "optimum-intel>=1.18.0", "openvino": "optimum-intel[openvino]>=1.18.0", diff --git a/test_executorch.py b/test_executorch.py new file mode 100644 index 00000000000..21555ceb2df --- /dev/null +++ b/test_executorch.py @@ -0,0 +1,42 @@ +import argparse +from optimum.executorchruntime import ExecuTorchModelForCausalLM +from transformers import AutoTokenizer, pipeline + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="meta-llama/Llama-3.2-1B", + help="model repo id on huggingface.co", + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./meta_llama3_2_1b", + help="output directory to store the generated ExecuTorch model", + ) + + args = parser.parse_args() + + model_id = args.model + ourput_dir = args.output_dir + # ExecuTorchModelForCausalLM.from_pretrained(model_id) + model = ExecuTorchModelForCausalLM.from_pretrained(ourput_dir, "text-generation", "xnnpack") + tokenizer = AutoTokenizer.from_pretrained(model_id) + prompt = "Hey, can you tell me any fun things to do in New York?" + + # Non-pipeline path + generated_text = model.text_generation(tokenizer=tokenizer, prompt=prompt) + print(f"{generated_text}") + + # # Pipeline path + # pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) + # chat = [ + # {"role": "user", "content": prompt}, + # ] + # response = pipe(chat, max_new_tokens=256) + # print(f"DEBUG: generated_text: {response[0]['generated_text'][-1]['content']}")