diff --git a/.github/workflows/test_ipex.yml b/.github/workflows/test_ipex.yml
index de933e379..ffd1507ab 100644
--- a/.github/workflows/test_ipex.yml
+++ b/.github/workflows/test_ipex.yml
@@ -19,7 +19,7 @@ jobs:
       fail-fast: false
       matrix:
         transformers-version: ["4.46.0", "4.46.3"]
-        torch-version: ["2.4.0", "2.5.*"]
+        torch-version: ["2.5.*"]
 
     runs-on: ubuntu-22.04
 
diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py
index dec1e8189..e1f6aa19b 100755
--- a/optimum/exporters/ipex/cache_utils.py
+++ b/optimum/exporters/ipex/cache_utils.py
@@ -44,7 +44,7 @@ def __init__(
         self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
-        self.block_size = 16
+        self.block_size = 64
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py
index 03937754a..8c5ef5030 100644
--- a/optimum/exporters/ipex/model_patcher.py
+++ b/optimum/exporters/ipex/model_patcher.py
@@ -14,7 +14,7 @@
 
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -27,6 +27,7 @@
 
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
+    _IPEXGPT2MLP,
     _falcon_model_forward,
     _gpt2_block_forward,
     _gpt2_model_forward,
@@ -111,6 +112,7 @@ def _patch_gpt2_model(model):
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
     convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
     convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
+    convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
     return model
 
 
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index ca51c47fb..aa558c437 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -19,6 +19,9 @@
 import torch
 from torch import nn
 from transformers.cache_utils import Cache
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions
 
 from optimum.intel.utils.import_utils import is_ipex_version
@@ -29,7 +32,7 @@
 
 logger = logging.getLogger(__name__)
 
-_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
+_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.5.0"
 
 
 if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
@@ -37,12 +40,13 @@
         f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model."
     )
 else:
-    from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding, varlen_attention
+    from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding
     from intel_extension_for_pytorch.llm.modules import (
         Linear2SiluMul,
         LinearAdd,
         LinearAddAdd,
         LinearGelu,
+        LinearNewGelu,
         PagedAttention,
     )
 
@@ -194,7 +198,10 @@ def _llama_model_forward(
     next_decoder_cache = () if use_cache else None
 
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
-    if past_key_values_length == 0:
+
+    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+
+    if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
         hidden_states_copy = hidden_states
         index = attention_mask.view(-1) != 0
@@ -207,7 +214,13 @@ def _llama_model_forward(
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
-    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    if past_key_values is None:
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+            attention_mask=attention_mask,
+            input_shape=(input_ids.shape[0], input_ids.shape[-1]),
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
 
     for idx, decoder_layer in enumerate(self.layers):
         if output_hidden_states:
@@ -309,7 +322,9 @@ def _falcon_model_forward(
     # create position embeddings to be shared across the decoder layers
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-    if past_key_values_length == 0:
+    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+
+    if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
         hidden_states_copy = hidden_states
         index = attention_mask.view(-1) != 0
@@ -321,7 +336,14 @@ def _falcon_model_forward(
         position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+
+    if past_key_values is None:
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+            attention_mask=attention_mask,
+            input_shape=(input_ids.shape[0], input_ids.shape[-1]),
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
 
     next_decoder_cache = None
     all_self_attentions = () if output_attentions else None
@@ -436,7 +458,9 @@ def _gpt2_model_forward(
 
     hidden_states = self.drop(hidden_states)
 
-    if past_length == 0:
+    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+
+    if past_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
         hidden_states_copy = hidden_states
         index = attention_mask.view(-1) != 0
@@ -444,7 +468,13 @@ def _gpt2_model_forward(
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
-    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    if past_key_values is None:
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+            attention_mask=attention_mask,
+            input_shape=(input_ids.shape[0], input_ids.shape[-1]),
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_length,
+        )
 
     presents = None
     all_self_attentions = () if output_attentions else None
@@ -528,7 +558,10 @@ def _gpt2_block_forward(
     attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
     outputs = attn_outputs[1:]
     # residual connection
-    hidden_states = attn_output + residual
+    if hasattr(self.attn, "linear_add"):
+        hidden_states = self.attn.linear_add(attn_output, residual)
+    else:
+        hidden_states = attn_output + residual
 
     if encoder_hidden_states is not None:
         # add one self-attention block for cross-attention
@@ -557,7 +590,10 @@ def _gpt2_block_forward(
     hidden_states = self.ln_2(hidden_states)
     feed_forward_hidden_states = self.mlp(hidden_states)
     # residual connection
-    hidden_states = residual + feed_forward_hidden_states
+    if hasattr(self.mlp, "linear_add"):
+        hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
+    else:
+        hidden_states = residual + feed_forward_hidden_states
 
     if use_cache:
         outputs = (hidden_states,) + outputs
@@ -577,6 +613,7 @@ def __init__(self, module, config) -> None:
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
         ).repeat_interleave(self.num_groups)
+        self.use_sdpa = False
 
     def qkv_gemm(self, hidden_states):
         raise NotImplementedError("Need to implement in specific model class")
@@ -585,9 +622,32 @@ def rope(self, *args, **kwargs):
         raise NotImplementedError("Need to implement in specific model class")
 
     def postprocess_attention_output(self, attn_output):
+        if self.use_sdpa:
+            attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
         return attn_output
 
+    def varlen_attn(self, query, key, value, past_key_value, input_lens):
+        # prefill, remove padding
+        attn_output = torch.empty_like(query)
+        seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+        PagedAttention.flash_attn_varlen_func(
+            attn_output,
+            query,
+            key,
+            value,
+            seq_len_tensor,
+            seq_len_tensor,
+            input_lens.max(),
+            input_lens.max(),
+            1.0 / math.sqrt(self.head_dim),
+            True,
+            past_key_value.block_tables,
+            None,
+        )
+
+        return attn_output
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -610,28 +670,28 @@ def forward(
         if past_key_value is not None:
             key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens)
 
-        attn_output = torch.empty_like(query)
         if past_len == 0:
-            # prefill, remove padding
-            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            varlen_attention(
-                query.contiguous() if query.device.type == "xpu" else query,
-                key.contiguous() if key.device.type == "xpu" else key,
-                value.contiguous() if value.device.type == "xpu" else value,
-                attn_output,
-                seq_len_tensor,
-                seq_len_tensor,
-                input_lens.max(),
-                input_lens.max(),
-                0.0,
-                1.0 / math.sqrt(self.head_dim),
-                False,
-                True,
-                False,
-                None,
-            )
+            # prefill
+            if past_key_value is None:
+                n_rep = query.shape[1] // key.shape[1]
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
+                    key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
+                    .transpose(1, 2)
+                    .repeat_interleave(n_rep, 1),
+                    value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
+                    .transpose(1, 2)
+                    .repeat_interleave(n_rep, 1),
+                    attn_mask=attention_mask,
+                    dropout_p=0.0,
+                    is_causal=True,
+                )
+                self.use_sdpa = True
+            else:
+                attn_output = self.varlen_attn(query, key_cache, value_cache, past_key_value, input_lens)
         else:
             # decode
+            attn_output = torch.empty_like(query)
             PagedAttention.single_query_cached_kv_attention(
                 attn_output,
                 query,
@@ -720,9 +780,23 @@ class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, config)
+        _setattr_from_module(self, module)
+        self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
+        self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
+        self.c_attn_linear.bias = self.c_attn.bias
+        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+        self.c_proj_linear.bias = self.c_proj.bias
+        if self.module_device.type == "cpu":
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = LinearAdd(self.c_proj_linear)
+
+        elif self.module_device.type == "xpu":
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = XPULinearAdd(self.c_proj_linear)
 
     def qkv_gemm(self, hidden_states):
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
+        query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
         query = query.view(-1, self.num_heads, self.head_dim)
         key = key.view(-1, self.num_heads, self.head_dim)
         value = value.view(-1, self.num_heads, self.head_dim)
@@ -732,9 +806,11 @@ def rope(self, query, key, *args, **kwargs):
         return query, key
 
     def postprocess_attention_output(self, attn_output):
+        if self.use_sdpa:
+            attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
-        attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
+        if not hasattr(self, "linear_add"):
+            attn_output = self.c_proj(attn_output)
         return attn_output
 
 
@@ -805,6 +881,40 @@ def forward(
         return output
 
 
+class _IPEXGPT2MLP(nn.Module):
+    def __init__(self, module, config) -> None:
+        super().__init__()
+        _setattr_from_module(self, module)
+        self.config = config
+        self.module_device = next(module.parameters()).device
+        self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
+        self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
+        self.c_fc_linear.bias = self.c_fc.bias
+        self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1])
+        self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t())
+        self.c_proj_linear.bias = self.c_proj.bias
+        if self.module_device.type == "cpu":
+            self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
+
+        if self.module_device.type == "cpu":
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = LinearAdd(self.c_proj_linear)
+
+        elif self.module_device.type == "xpu":
+            if self.c_proj_linear not in ["LinearAllreduce"]:
+                self.linear_add = XPULinearAdd(self.c_proj_linear)
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        if hasattr(self, "linear_new_gelu"):
+            hidden_states = self.linear_new_gelu(hidden_states)
+        else:
+            hidden_states = self.c_fc(hidden_states)
+            hidden_states = self.act(hidden_states)
+        if not hasattr(self, "linear_add"):
+            hidden_states = self.c_proj(hidden_states)
+        return hidden_states
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
 class _IPEXLlamaDecoderLayer(nn.Module):
     def __init__(self, module, config):
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index af36d06f4..3263e31db 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -316,9 +316,9 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
         return self.model.prepare_inputs_for_generation(*args, **kwargs)
 
     def generate(self, *args, **kwargs):
-        if is_ipex_version("<", "2.4.0") and self._add_patch and kwargs.get("assistant_model", None):
+        if self._add_patch and kwargs.get("assistant_model", None):
             raise ValueError(
-                f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
+                f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
             )
         # Patch functions to support ipex_paged cache
         if self._add_patch:
diff --git a/setup.py b/setup.py
index f78052b4b..bfcb3af7e 100644
--- a/setup.py
+++ b/setup.py
@@ -66,7 +66,7 @@
     "nncf": ["nncf>=2.14.0"],
    "openvino": ["nncf>=2.14.0", "openvino>=2024.5.0", "openvino-tokenizers>=2024.5.0"],
     "neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"],
-    "ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.45,<4.47", "accelerate"],
+    "ipex": ["intel-extension-for-pytorch>=2.5", "transformers>4.45,<4.47", "accelerate"],
     "diffusers": ["diffusers"],
     "quality": QUALITY_REQUIRE,
     "tests": TESTS_REQUIRE,
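
A minimal smoke-test sketch for the patched path follows; it is not part of the patch, and the model id, dtype, and prompt are illustrative choices. Running generate() on a patched architecture (llama, falcon, gpt2) exercises the prefill branch and the paged-cache decode branch of the rewritten attention.

# Illustrative only: assumes optimum-intel with this patch and intel-extension-for-pytorch >= 2.5 installed.
import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # hypothetical pick; any of the patched architectures applies
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Paged attention with IPEX 2.5", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))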