From afe520975fed8bda850df86129eb8937e1bc5e03 Mon Sep 17 00:00:00 2001
From: eplatero
Date: Tue, 12 Nov 2024 22:38:35 -0600
Subject: [PATCH] rebased to main and resolved all conflicts

Signed-off-by: eplatero
---
 QEfficient/base/modeling_qeff.py              |  13 +-
 .../transformers/models/modeling_auto.py      | 288 ++++--------------
 tests/spd/test_tlm_dlm_export_and_compile.py  |  50 +--
 3 files changed, 92 insertions(+), 259 deletions(-)

diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index 88c2c155b..e4327dd41 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -132,7 +132,12 @@ def _export(
         """
         export_dir = Path(export_dir or (QEFF_HOME / self.model_name))
         export_dir = export_dir.with_name(export_dir.name + "-" + self.model_hash)
-        onnx_path = export_dir / f"{self.model_name}.onnx"
+        if self.num_speculative_tokens:
+            model_name = f"{self.model_name}_{self.num_speculative_tokens + 1}nltk.onnx"
+        else:
+            model_name = f"{self.model_name}.onnx"
+        onnx_path = export_dir / model_name
+        # TODO: need to add hash to onnx
         if onnx_path.is_file():
             self.onnx_path = onnx_path
             return onnx_path
@@ -244,6 +249,12 @@ def _compile(
         if mdp_ts_num_devices > 1:
             compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))
 
+        if self.num_speculative_tokens:
+            compile_hash.update(to_hashable({"num_speculative_tokens": self.num_speculative_tokens}))
+
+        if self.is_dlm:
+            compile_hash.update(to_hashable({"is_dlm": self.is_dlm}))
+
         # Check if already compiled
         compile_hash = compile_hash.hexdigest()[:16]
         qpc_path = qpc_path.with_name(qpc_path.name + "-" + compile_hash)
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index e0b802375..10ae844f6 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -5,6 +5,7 @@
 #
 # ----------------------------------------------------------------------------
 
+import math
 import hashlib
 import logging
 import warnings
@@ -16,15 +17,13 @@
 from transformers import AutoModel, AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedTokenizerFast
 
 import QEfficient
-from QEfficient.base.modeling_qeff import QEFFBaseModel, Runtime
-from QEfficient.transformers.pytorch_transforms import CBTransform, CustomOpsTransform, KVCacheTransform, SpDTransform
+from QEfficient.base.modeling_qeff import QEFFBaseModel
+from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
+from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform
 from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers
 from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform
-from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig
-from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig
-from QEfficient.utils import get_qpc_dir_path, load_hf_tokenizer
-from QEfficient.utils.constants import QEFF_MODELS_DIR
-from QEfficient.utils.logging_utils import logger
+from QEfficient.utils import constants, get_padding_shape_from_config
+from QEfficient.utils.cache import to_hashable
 
 logger = logging.getLogger(__file__)
 
@@ -47,29 +46,6 @@ def __init__(self, model: nn.Module) -> None:
             raise AssertionError("Please use `from_pretrained` method to load quantized models")
 
         super().__init__(model)
-        self.model.config.use_cache = (
-            True  # Always pass use_cache = True, to get KV values as output during ONNX export
-        )
-        self.pretrained_model_name_or_path = pretrained_model_name_or_path
-
-        # Set model card name, which is used to decide ONNX, QPC files path during export and compile resp.
-        if model_card_name := kwargs.pop("model_card_name", None):
-            self.model_card_name = model_card_name
-        elif os.path.isdir(self.pretrained_model_name_or_path):
-            hash_object = hashlib.sha256()
-            hash_object.update(self.pretrained_model_name_or_path.encode("utf-8"))
-            self.model_card_name = hash_object.hexdigest()
-        else:
-            self.model_card_name = self.pretrained_model_name_or_path
-
-        self.full_batch_size = kwargs.get("full_batch_size", None)
-        self.num_speculative_tokens = kwargs.get("num_speculative_tokens", None)
-        self.is_dlm = kwargs.get("is_dlm", False)
-        self.kwargs = kwargs
-        self._tokenizer = None
-        self.is_transformed = False
-        if kwargs.get("transform", True):
-            self.transform(**kwargs)
 
     def __repr__(self) -> str:
         return self.__class__.__name__ + "\n" + self.model.__repr__()
@@ -135,7 +111,7 @@ class QEFFAutoModelForCausalLM(QEFFTransformersBase):
     _pytorch_transforms = [AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, CustomOpsTransform, KVCacheTransform]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
-    def __init__(self, model: nn.Module, continuous_batching: bool = False, **kwargs):
+    def __init__(self, model: nn.Module, continuous_batching: bool = False, num_speculative_tokens: Optional[int] = None, is_dlm: bool = False, **kwargs):
         if kwargs.pop("full_batch_size", None):
             continuous_batching = True
             warnings.warn(
@@ -148,6 +124,8 @@ def __init__(self, model: nn.Module, continuous_batching: bool = False, **kwargs
         self.model.config.use_cache = True
         self.num_layers = model.config.num_hidden_layers
         self.continuous_batching = continuous_batching
+        self.num_speculative_tokens = num_speculative_tokens
+        self.is_dlm = is_dlm
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: bool = False, *args, **kwargs):
@@ -174,36 +152,25 @@ def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: boo
         model.generate(prompts=["Hi there!!"])
         """
 
+        num_speculative_tokens = kwargs.pop("num_speculative_tokens", None)
+        is_dlm = kwargs.pop("is_dlm", False)
+        if num_speculative_tokens is not None:
+            if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2:
+                raise ValueError("`num_speculative_tokens` arg should be an integer greater than 1.")
+            if is_dlm:
+                raise ValueError("`num_speculative_tokens` arg and `is_dlm` flag are mutually exclusive.")
+            cls._pytorch_transforms.append(SpDTransform)
         if kwargs.pop("full_batch_size", None):
             continuous_batching = True
             warnings.warn(
                 "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
             )
 
-        num_speculative_tokens = kwargs.pop("num_speculative_tokens", None)
-        is_dlm = kwargs.pop("is_dlm", False)
-
-        attn_implementation = kwargs.get("attn_implementation", None)
-        if attn_implementation != "eager":
-            logger.warning(f"Updating attn_implementation to be 'eager', got {attn_implementation}")
-            kwargs.update({"attn_implementation": "eager"})
-
-        if low_cpu_mem_usage := kwargs.get("low_cpu_mem_usage", None):
-            logger.warning(f"Updating low_cpu_mem_usage to be 'False', got {low_cpu_mem_usage}")
-            kwargs.update({"low_cpu_mem_usage": False})
-
-        model = QEFFAutoModelToTransformersAutoModelMap[cls.__name__].from_pretrained(
-            pretrained_model_name_or_path, *args, **kwargs
-        )
-        return cls(
-            model,
-            pretrained_model_name_or_path=pretrained_model_name_or_path,
-            model_card_name=model_card_name,
-            full_batch_size=full_batch_size,
-            num_speculative_tokens=num_speculative_tokens,
-            is_dlm=is_dlm,
-            **kwargs,
-        )
+        self = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        self.continuous_batching = continuous_batching
+        self.num_speculative_tokens = num_speculative_tokens
+        self.is_dlm = is_dlm
+        return self
 
     @property
     def model_hash(self) -> str:
@@ -215,90 +182,7 @@ def model_hash(self) -> str:
         mhash = mhash.hexdigest()[:16]
         return mhash
 
-        Returns:
-            :Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: Tokenizer from ``transformers`` for the given model.
-        """
-        if self._tokenizer is None:
-            self._tokenizer = self.get_tokenizer()
-        return self._tokenizer
-
-    def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-        tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=self.pretrained_model_name_or_path, **self.kwargs)
-        return tokenizer
-
-
-class QEFFAutoModelForCausalLM(QEFFTransformersBase):
-    """
-    The QEFF class is designed for manipulating any causal language model from the HuggingFace hub.
-    Although it is possible to initialize the class directly, we highly recommend using the ``from_pretrained`` method for initialization.
-    Please note that the QEFF class is also a part of the ``QEfficient`` module.
-
-    ``Mandatory`` Args:
-        :model (nn.Module): PyTorch model
-        :pretrained_model_name_or_path (str): We recommend passing name of the model as input here, as you are not using `from_pretrained` method. This name will be used for deciding path of the ``ONNX/qpc`` files generated during ``export``, ``compilation`` stages.
-
-    .. code-block:: python
-
-        from QEfficient import QEFFAutoModelForCausalLM
-
-    """
-
-    _pytorch_transforms = [CustomOpsTransform, KVCacheTransform]
-
-    def transform(
-        self,
-        num_speculative_tokens: Optional[int] = None,
-        is_dlm: bool = False,
-        **kwargs):
-        """
-        This method applies all relevant optimization transforms on the model and toggles the ``self.is_transformed`` attribute to True. If the model is already transformed, the method will simply return.
-        Please note that this method does not require any input arguments."
-
-        ``Optional`` Args:
-            :num_speculative_tokens (int, optional): Number of speculative tokens, specified only for TLM SpD model.
-            :is_dlm (bool): True if this is a DLM SpD model.
-
-        Returns:
-            :obj: Same object with transformed ``self.model``
-        """
-        if self.is_transformed:
-            return
-        if self.full_batch_size is not None:
-            if KVCacheTransform in self._pytorch_transforms:
-                self._pytorch_transforms[self._pytorch_transforms.index(KVCacheTransform)] = CBTransform
-            if CBTransform not in self._pytorch_transforms:
-                raise RuntimeError("please don't update _pytorch_transforms variable")
-        else:
-            if CBTransform in self._pytorch_transforms:
-                self._pytorch_transforms[self._pytorch_transforms.index(CBTransform)] = KVCacheTransform
-            if KVCacheTransform not in self._pytorch_transforms:
-                raise RuntimeError("Please don't update _pytorch_transforms variable")
-
-        # Update list of pytorch transforms if the model falls in AWQ/GPTQ category
-        if hasattr(self.model.config, "quantization_config"):
-            if isinstance(self.model.config.quantization_config, QEffAwqConfig):
-                self._pytorch_transforms.insert(0, AwqToMatmulNbitsTransform)
-
-            if isinstance(self.model.config.quantization_config, QEffGPTQConfig):
-                self._pytorch_transforms.insert(0, GPTQToMatmulNbitsTransform)
-
-        if num_speculative_tokens is not None:
-            if not isinstance(num_speculative_tokens, int) or num_speculative_tokens<2:
-                ValueError("`num_speculative_tokens` arg should be an integer greater than 1.")
-            if is_dlm:
-                raise ValueError("`num_speculative_tokens` arg and `is_dlm` flag are mutually exclusive.")
-            self._pytorch_transforms.append(SpDTransform)
-        elif is_dlm:
-            setattr(self.model, "is_dlm", True)
-
-        for transform in self._pytorch_transforms:
-            transform.apply(self.model)
-        self.is_transformed = True
-
-    def execute(self, *args, **kwargs):  # type: ignore
-        raise NotImplementedError("Reached too far!!")
-
-    def export(self) -> str:
+    def export(self, export_dir: Optional[str] = None) -> str:
         """
         Exports the model to ``ONNX`` format using ``torch.onnx.export``.
         We currently don't support exporting non-transformed models. Please refer to the ``convert_to_cloud_bertstyle`` function in the **Low-Level API** for a legacy function that supports this."
@@ -309,18 +193,20 @@ def export(self) -> str:
         Returns:
             :str: Path of the generated ``ONNX`` graph.
         """
-        assert self.is_transformed, "Please first run transform on the QEFFAutoModelForCausalLM object"
-        # Export
-        _, onnx_model_path = QEfficient.export(
-            model_name=self.model_card_name,
-            model_kv=self,
-            tokenizer=self.tokenizer,
-            full_batch_size=self.full_batch_size,
-            num_speculative_tokens=self.num_speculative_tokens,
+        bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+        if self.num_speculative_tokens:
+            num_logits_to_keep = self.num_speculative_tokens + 1
+            setattr(self.model, "num_logits_to_keep", num_logits_to_keep)
+            if seq_len < num_logits_to_keep:
+                seq_len *= math.ceil(num_logits_to_keep / seq_len)
+        fbs = constants.ONNX_EXPORT_EXAMPLE_FBS
+        kv_cache_shape = get_padding_shape_from_config(
+            self.model.config, fbs if self.continuous_batching else bs, seq_len
         )
         example_inputs = {
             "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
-            "position_ids": torch.arange(seq_len, dtype=torch.int64).view(bs, seq_len),
+            "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
             "past_key_values": [[] for _ in range(self.num_layers)],
         }
         dynamic_axes = {
@@ -393,19 +279,28 @@ def compile(
             :str: Path of the compiled ``qpc`` package.
""" # Specializations + decode_seq_len = self.num_speculative_tokens+1 if self.num_speculative_tokens else 1 if self.continuous_batching: if full_batch_size is None: raise TypeError("missing required argument: 'full_batch_size'") specializations = [ {"full_batch_size": full_batch_size, "batch_size": 1, "seq_len": prefill_seq_len, "ctx_len": ctx_len}, - {"full_batch_size": full_batch_size, "batch_size": full_batch_size, "seq_len": 1, "ctx_len": ctx_len}, + {"full_batch_size": full_batch_size, "batch_size": full_batch_size, "seq_len": decode_seq_len, "ctx_len": ctx_len}, ] + if self.is_dlm: + specializations.append( + {"full_batch_size": full_batch_size, "batch_size": full_batch_size, "seq_len": 2, "ctx_len": ctx_len}, + ) else: specializations = [ {"batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len}, - {"batch_size": batch_size, "seq_len": 1, "ctx_len": ctx_len}, + {"batch_size": batch_size, "seq_len": decode_seq_len, "ctx_len": ctx_len}, ] + if self.is_dlm: + specializations.append( + {"batch_size": batch_size, "seq_len": 2, "ctx_len": ctx_len}, + ) # Custom IO custom_io = {} @@ -429,90 +324,15 @@ def compile( **compiler_options, ) - # Compile - QEfficient.compile( - onnx_path=self.onnx_path, - qpc_path=os.path.dirname(qpc_dir_path), - num_cores=num_cores, - device_group=device_group, - aic_enable_depth_first=aic_enable_depth_first, - mos=mos, - batch_size=batch_size, - prompt_len=prompt_len, - ctx_len=ctx_len, - mxfp6=mxfp6, - mxint8=mxint8, - full_batch_size=self.full_batch_size, - num_speculative_tokens=self.num_speculative_tokens, - is_dlm=self.is_dlm, - ) - self.qpc_path = qpc_dir_path - return self.qpc_path - - def export_and_compile( + # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate + def generate( self, - num_cores: int, - device_group: List[int], - batch_size: int = 1, - prompt_len: int = 32, - ctx_len: int = 128, - mxfp6: bool = True, - mxint8: bool = False, - mos: int = -1, - aic_enable_depth_first: bool = False, - qpc_dir_suffix: Optional[str] = None, - full_batch_size: Optional[int] = None, - ) -> str: - """ - This API is specific to Internal VLLM use-case and is not recommended to be used in your application unless your are using VLLM. 
- """ - _, transformed = CBTransform.apply(self.model) - if not transformed: - raise RuntimeError("Could not apply Continuous batch transform on the model") - if full_batch_size is not None: - self.full_batch_size = full_batch_size - - self.export() - - qpc_base_dir_name = get_qpc_dir_path( - model_card_name=self.model_card_name, - num_cores=num_cores, - mos=mos, - batch_size=batch_size, - prompt_len=prompt_len, - ctx_len=ctx_len, - mxfp6=mxfp6, - mxint8=mxint8, - device_group=device_group, - full_batch_size=self.full_batch_size, - ) - qpc_base_dir_name = ( - os.path.dirname(qpc_base_dir_name) + "_" + qpc_dir_suffix if qpc_dir_suffix else qpc_base_dir_name - ) - model_card_dir = os.path.join(QEFF_MODELS_DIR, str(self.model_card_name)) - os.makedirs(model_card_dir, exist_ok=True) - qpc_dir_path = os.path.join(model_card_dir, qpc_base_dir_name) - - # Compile - self.qpc_path = QEfficient.compile( - onnx_path=self.onnx_path, - qpc_path=qpc_dir_path, - num_cores=num_cores, - device_group=device_group, - aic_enable_depth_first=aic_enable_depth_first, - mos=mos, - batch_size=batch_size, - prompt_len=prompt_len, - ctx_len=ctx_len, - mxfp6=mxfp6, - mxint8=mxint8, - full_batch_size=full_batch_size, - num_speculative_tokens=self.num_speculative_tokens, - is_dlm=self.is_dlm, - ) - return self.qpc_path - - def generate(self, prompts: List[str], device_id: List[int] = None, runtime: str = "AI_100", **kwargs): + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer], + prompts: List[str], + device_id: List[int] = None, + runtime: str = "AI_100", + **kwargs, + ): """ This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. This is a sequential execution based on the ``batch_size`` of the compiled model and the number of prompts passed. 
@@ -547,4 +367,4 @@ def export(self):
         raise NotImplementedError("Reached too far!!")
 
     def compile(self, *args, **kwargs) -> Any:
-        raise NotImplementedError("Reached too far!!")
+        raise NotImplementedError("Reached too far!!")
\ No newline at end of file
diff --git a/tests/spd/test_tlm_dlm_export_and_compile.py b/tests/spd/test_tlm_dlm_export_and_compile.py
index 2a049365a..cec269ffa 100644
--- a/tests/spd/test_tlm_dlm_export_and_compile.py
+++ b/tests/spd/test_tlm_dlm_export_and_compile.py
@@ -18,52 +18,54 @@
     pytest.param(
         [0],  # device_group
         2,  # num_speculative_tokens
-        32,  # prompt_len
+        32,  # prefill_seq_len
         128,  # ctx_len
         1,  # prefill_bsz
         8,  # full_batch_size
         "JackFram/llama-68m",  # model_name
+        True,  # continuous_batching
         id="CB llama",
     ),
     pytest.param(
         [0],  # device_group
         2,  # num_speculative_tokens
-        32,  # prompt_len
+        32,  # prefill_seq_len
         128,  # ctx_len
         1,  # prefill_bsz
         None,  # full_batch_size
         "JackFram/llama-68m",  # model_name
+        False,  # continuous_batching
         id="non-CB llama",
     ),
 ]
 
 
 @pytest.mark.parametrize(
-    "device_group,num_speculative_tokens,prompt_len,ctx_len,prefill_bsz,full_batch_size,model_name", configs
+    "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", configs
 )
 def test_llama_tlm_logit_dims(
     device_group: List[int],
     num_speculative_tokens: int,
-    prompt_len: int,
+    prefill_seq_len: int,
     ctx_len: int,
     prefill_bsz: int,
     full_batch_size: Optional[int],
     model_name: str,
+    continuous_batching: bool,
 ):
     # get vocab size
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     vocab_size = len(tokenizer)
 
-    # export_and_compile tlm model
-    qeff_model = AutoModelForCausalLM.from_pretrained(model_name, num_speculative_tokens=num_speculative_tokens)
-    qpc_path: str = qeff_model.export_and_compile(
+    # export and compile tlm model
+    qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching, num_speculative_tokens=num_speculative_tokens)
+    qpc_path: str = qeff_model.compile(
+        num_devices=len(device_group),
         num_cores=16,
-        device_group=device_group,
         batch_size=prefill_bsz,
-        prompt_len=prompt_len,
+        prefill_seq_len=prefill_seq_len,
         ctx_len=ctx_len,
-        mxfp6=True,
-#        mxint8=True,
+        mxfp6_matmul=True,
         full_batch_size=full_batch_size,
     )
 
@@ -74,8 +76,8 @@
     session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")]))
     # prefill dummy inputs
     prefill_inputs = dict(
-        input_ids=np.zeros((prefill_bsz, prompt_len), dtype=np.int64),
-        position_ids=np.arange(prompt_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(),
+        input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64),
+        position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(),
     )
     # decode dummy inputs
     num_logits_to_keep = num_speculative_tokens + 1
@@ -102,31 +104,31 @@
 
 
 @pytest.mark.parametrize(
-    "device_group,num_speculative_tokens,prompt_len,ctx_len,prefill_bsz,full_batch_size,model_name", configs
+    "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", configs
 )
 def test_llama_dlm_logit_dims(
     device_group: List[int],
     num_speculative_tokens: int,
-    prompt_len: int,
+    prefill_seq_len: int,
     ctx_len: int,
     prefill_bsz: int,
     full_batch_size: Optional[int],
     model_name: str,
+    continuous_batching: bool,
 ):
     # get vocab size
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     vocab_size = len(tokenizer)
 
-    # export_and_compile tlm model
-    qeff_model = AutoModelForCausalLM.from_pretrained(model_name, is_dlm=True)
-    qpc_path: str = qeff_model.export_and_compile(
+    # export and compile dlm model
+    qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching, is_dlm=True)
+    qpc_path: str = qeff_model.compile(
+        num_devices=len(device_group),
         num_cores=16,
-        device_group=device_group,
         batch_size=prefill_bsz,
-        prompt_len=prompt_len,
+        prefill_seq_len=prefill_seq_len,
         ctx_len=ctx_len,
-        mxfp6=True,
-#        mxint8=True,
+        mxfp6_matmul=True,
         full_batch_size=full_batch_size,
     )
 
@@ -137,8 +139,8 @@
     session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")]))
    # prefill dummy inputs
     prefill_inputs = dict(
-        input_ids=np.zeros((prefill_bsz, prompt_len), dtype=np.int64),
-        position_ids=np.arange(prompt_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(),
+        input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64),
+        position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(),
         batch_index=np.arange(prefill_bsz, dtype=np.int64).reshape(-1, 1),
     )
     # decode-1 dummy inputs
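
Note (illustrative, not part of the diff above): a minimal sketch of how the new `num_speculative_tokens` and `is_dlm` options are expected to be driven end to end, based only on the API exercised by the tests in this patch. The `QEFFAutoModelForCausalLM` import path, the single-device setup, and the toy checkpoint name are assumptions; adjust them to your environment.

# Speculative-decode setup sketch (assumed public import path).
from QEfficient import QEFFAutoModelForCausalLM

model_name = "JackFram/llama-68m"  # toy checkpoint used by the tests above

# Target LM (TLM): keeps num_speculative_tokens + 1 logits per decode step,
# so its decode specialization is compiled with seq_len = 3 here.
tlm = QEFFAutoModelForCausalLM.from_pretrained(model_name, num_speculative_tokens=2)
tlm_qpc_path = tlm.compile(
    num_devices=1,
    num_cores=16,
    batch_size=1,
    prefill_seq_len=32,
    ctx_len=128,
    mxfp6_matmul=True,
)

# Draft LM (DLM): is_dlm=True only adds the extra seq_len=2 decode
# specialization introduced in compile() above.
dlm = QEFFAutoModelForCausalLM.from_pretrained(model_name, is_dlm=True)
dlm_qpc_path = dlm.compile(
    num_devices=1,
    num_cores=16,
    batch_size=1,
    prefill_seq_len=32,
    ctx_len=128,
    mxfp6_matmul=True,
)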