Commit
Merge branch 'main' into lm_eval_llama
sywangyi authored Dec 13, 2024
2 parents 00f269b + 4e56c47 commit 0f1ee4e
Showing 1 changed file with 27 additions and 4 deletions.
optimum/habana/transformers/models/llama/modeling_llama.py (31 changes: 27 additions & 4 deletions)
@@ -30,7 +30,7 @@
from ...modeling_attn_mask_utils import (
    _gaudi_prepare_4d_causal_attention_mask,
)
-from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module
+from ..modeling_all_models import Matmul, apply_customized_rope_module
from .configuration_llama import LlamaConfig


@@ -385,7 +385,23 @@ def forward(
)


-class LlamaKVCache(KVCache):
+class KVCache(torch.nn.Module):
+    def __init__(self):
+        super(KVCache, self).__init__()
+        self.cache = None
+        self.inp_seq_len = -1
+
+    def allocate(self, inp_seq_len, dtype, device, shape):
+        if self.cache is None or self.cache.shape != shape:
+            self.inp_seq_len = inp_seq_len
+            self.cache = torch.zeros(shape, dtype=dtype, device=device)
+        else:
+            assert (
+                self.inp_seq_len == inp_seq_len
+            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            self.cache.fill_(0)
+
    @staticmethod
    def update(prev, cur, dim, idx, inp_seq_len):
        if inp_seq_len != -1:
@@ -405,6 +421,13 @@ def update(prev, cur, dim, idx, inp_seq_len):
        else:
            return torch.cat((prev, cur), dim=dim)

+    def get_shape(self):
+        if self.cache is None:
+            return None
+        return self.cache.shape
+
+    def forward(self, cur, dim, idx):
+        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)

def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
    if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
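
A note on the class re-added above: allocate() zero-fills a buffer of the full target shape once and merely resets it on later calls (asserting that inp_seq_len has not changed), get_shape() exposes that shape, and forward() delegates to the static update(), whose body is mostly collapsed in this view but ends in a torch.cat fallback. As a rough, self-contained illustration of the static KV-cache pattern this enables, with made-up shapes and variable names rather than code from the commit:

import torch

# Illustration of the pattern only; the collapsed body of KVCache.update() may
# differ in detail. A max-length buffer is allocated once, then written in place.
batch, num_kv_heads, max_seq_len, head_dim = 1, 8, 128, 64
cache = torch.zeros(batch, num_kv_heads, max_seq_len, head_dim)   # what allocate() builds

prompt_len = 16
prompt_keys = torch.randn(batch, num_kv_heads, prompt_len, head_dim)
cache[:, :, :prompt_len, :].copy_(prompt_keys)                     # prefill: bulk copy of the prompt

next_key = torch.randn(batch, num_kv_heads, 1, head_dim)
token_idx = torch.tensor([prompt_len])                             # slot for the newly generated token
cache.index_copy_(2, token_idx, next_key)                          # decode: in-place write, shape never grows

Keeping the cache shape fixed avoids growing the tensor (and, on HPU, triggering recompilation for new shapes) at every generated token; the torch.cat branch visible in update() covers the case where no pre-allocated cache is used.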
@@ -419,8 +442,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):

        self.matmul_qk = Matmul()
        self.matmul_av = Matmul()
-        self.k_cache = LlamaKVCache()
-        self.v_cache = LlamaKVCache()
+        self.k_cache = KVCache()
+        self.v_cache = KVCache()

        if hasattr(config, "fused_qkv") and config.fused_qkv:
            self.num_heads = config.num_attention_heads
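
The final hunk swaps the attention layer's two per-layer caches over to the locally defined KVCache. For orientation, a minimal sketch of how such a pair of caches would be sized and allocated before generation; the import assumes an optimum-habana installation that contains this version of the file, and the batch, head, and length values are placeholders:

import torch

from optimum.habana.transformers.models.llama.modeling_llama import KVCache

# Placeholder sizes; on Gaudi the device would usually be "hpu" and the dtype bfloat16.
batch_size, num_kv_heads, head_dim = 1, 8, 128
max_seq_len, inp_seq_len = 2048, 32
cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)

k_cache, v_cache = KVCache(), KVCache()
for cache in (k_cache, v_cache):
    cache.allocate(inp_seq_len, torch.bfloat16, "cpu", cache_shape)

print(k_cache.get_shape())   # torch.Size([1, 8, 2048, 128])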
