From 911e157c54fe87bb1f235d5619d76429d1acecf9 Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Thu, 16 Jan 2025 11:46:29 +0000
Subject: [PATCH] [llama] Generate causal mask better

---
 sharktank/sharktank/layers/causal_llm.py  | 22 +++++-----------------
 sharktank/sharktank/models/llama/llama.py |  1 -
 2 files changed, 5 insertions(+), 18 deletions(-)

diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py
index 8ace77981..3ce88cdf8 100644
--- a/sharktank/sharktank/layers/causal_llm.py
+++ b/sharktank/sharktank/layers/causal_llm.py
@@ -42,13 +42,6 @@ def __init__(
         self.context_length = context_length
         self.fake_quant = fake_quant
 
-        if static_tables:
-            self.register_buffer(
-                "causal_context_mask", self.generate_causal_context_mask()
-            )
-        else:
-            self.causal_context_mask = None
-
     def _assert_device(self, *ts: torch.Tensor, dtype: Optional[torch.dtype] = None):
         if self.device is not None:
             for t in ts:
@@ -67,11 +60,10 @@ def _maximally_negative_value(self, dtype):
         """
         return float("-inf")
 
-    def generate_causal_context_mask(self) -> torch.Tensor:
-        context_length = self.context_length
+    def generate_causal_context_mask(self, batch_seqlen: int) -> torch.Tensor:
         unary_broadcast_ones = torch.ones([1, 1], dtype=torch.bool, device=self.device)
         context_broadcast_ones = unary_broadcast_ones.expand(
-            context_length, context_length
+            batch_seqlen, batch_seqlen
         )
         causal_context_mask = torch.triu(
             context_broadcast_ones,
@@ -117,18 +109,14 @@ def attention_mask(
         Since this is a bool tensor of context_length^2, different deployment
         scenarios can benefit from managing this in different ways.
         """
-        if causal_context_mask is None:
-            # Try to use the statically generated.
-            causal_context_mask = self.causal_context_mask
+        _, batch_seqlen = input_mask.shape
         if causal_context_mask is None:
             # Fallback to dynamically generated.
-            causal_context_mask = self.generate_causal_context_mask()
+            causal_context_mask = self.generate_causal_context_mask(batch_seqlen)
 
         # Combine the causal context mask and input mask.
         dtype = self.attention_dtype
-        _, batch_seq_len = input_mask.shape
-        causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len]
-        boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :])
+        boolean_mask = torch.logical_or(causal_context_mask, input_mask[:, None, None, :])
         numeric_mask = torch.where(
             boolean_mask, self._maximally_negative_value(dtype), 0
         ).to(dtype)
diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index 0a9a6f1c3..676a77a6e 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -67,7 +67,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         super().__init__(
             theta,
             context_length=config.hp.context_length,
-            static_tables=config.static_tables,
             device=config.device,
            activation_dtype=config.activation_dtype,
             attention_dtype=config.attention_dtype,
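
Note (reviewer sketch, not part of the patch): a minimal, self-contained
illustration of the masking scheme this change moves to, where the causal mask
is built at the batch's sequence length instead of being cached at the full
context_length. The seq_lens values, dtype, and helper name below are assumed
for illustration only and are not sharktank APIs.

# Sketch under the assumptions above; mirrors the shape/broadcast behavior of
# the patched attention_mask path.
import torch

def causal_context_mask(batch_seqlen: int, device=None) -> torch.Tensor:
    # True above the diagonal: positions a query token must not attend to.
    ones = torch.ones([1, 1], dtype=torch.bool, device=device)
    return torch.triu(ones.expand(batch_seqlen, batch_seqlen), diagonal=1)

bs, batch_seqlen = 2, 5
seq_lens = torch.tensor([3, 5])                      # assumed per-sequence lengths
positions = torch.arange(batch_seqlen)[None, :]
input_mask = positions >= seq_lens[:, None]          # True where padded, [bs, sl]

causal = causal_context_mask(batch_seqlen)           # [sl, sl], broadcasts below
boolean_mask = torch.logical_or(causal, input_mask[:, None, None, :])
numeric_mask = torch.where(boolean_mask, float("-inf"), 0.0).to(torch.float16)
print(numeric_mask.shape)                            # torch.Size([2, 1, 5, 5])

Building the mask per batch avoids registering a context_length^2 boolean
buffer up front; the mask is only as large as the sequence actually being
processed.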