[llama] Generate causal mask better #832

Draft · wants to merge 1 commit into base: main
[llama] Generate causal mask better
Groverkss committed Jan 16, 2025
commit 911e157c54fe87bb1f235d5619d76429d1acecf9
22 changes: 5 additions & 17 deletions sharktank/sharktank/layers/causal_llm.py
@@ -42,13 +42,6 @@ def __init__(
         self.context_length = context_length
         self.fake_quant = fake_quant
 
-        if static_tables:
-            self.register_buffer(
-                "causal_context_mask", self.generate_causal_context_mask()
-            )
-        else:
-            self.causal_context_mask = None
-
     def _assert_device(self, *ts: torch.Tensor, dtype: Optional[torch.dtype] = None):
         if self.device is not None:
             for t in ts:
@@ -67,11 +60,10 @@ def _maximally_negative_value(self, dtype):
"""
return float("-inf")

def generate_causal_context_mask(self) -> torch.Tensor:
context_length = self.context_length
def generate_causal_context_mask(self, batch_seqlen: int) -> torch.Tensor:
unary_broadcast_ones = torch.ones([1, 1], dtype=torch.bool, device=self.device)
context_broadcast_ones = unary_broadcast_ones.expand(
context_length, context_length
batch_seqlen, batch_seqlen
)
causal_context_mask = torch.triu(
context_broadcast_ones,
@@ -117,18 +109,14 @@ def attention_mask(
         Since this is a bool tensor of context_length^2, different deployment
         scenarios can benefit from managing this in different ways.
         """
-        if causal_context_mask is None:
-            # Try to use the statically generated.
-            causal_context_mask = self.causal_context_mask
+        _, batch_seqlen = input_mask.shape
         if causal_context_mask is None:
             # Fallback to dynamically generated.
-            causal_context_mask = self.generate_causal_context_mask()
+            causal_context_mask = self.generate_causal_context_mask(batch_seqlen)
 
         # Combine the causal context mask and input mask.
         dtype = self.attention_dtype
-        _, batch_seq_len = input_mask.shape
-        causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len]
-        boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :])
+        boolean_mask = torch.logical_or(causal_context_mask, input_mask[:, None, None, :])
         numeric_mask = torch.where(
             boolean_mask, self._maximally_negative_value(dtype), 0
         ).to(dtype)
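
For context, the sketch below mirrors the behavior after this change: the causal mask is built per batch to match batch_seqlen instead of being pre-registered at context_length and sliced down. It is a minimal standalone illustration, not the module's actual API; the free-function signatures, the device keyword, and the padding-mask convention (True = padded position) are simplifying assumptions.

    import torch

    def generate_causal_context_mask(batch_seqlen: int, device=None) -> torch.Tensor:
        # Strictly upper-triangular boolean mask of shape [1, 1, batch_seqlen, batch_seqlen];
        # True marks future positions that must be blocked.
        ones = torch.ones(batch_seqlen, batch_seqlen, dtype=torch.bool, device=device)
        return torch.triu(ones, diagonal=1)[None, None, :, :]

    def attention_mask(input_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
        # input_mask: [bs, batch_seqlen] boolean padding mask (True = padded position).
        _, batch_seqlen = input_mask.shape
        causal = generate_causal_context_mask(batch_seqlen, device=input_mask.device)
        # Block a position if it is in the causal future or padded.
        boolean_mask = torch.logical_or(causal, input_mask[:, None, None, :])
        return torch.where(boolean_mask, float("-inf"), 0.0).to(dtype)

    # Example: batch of 1, sequence length 4, last token padded.
    mask = attention_mask(torch.tensor([[False, False, False, True]]))
    print(mask.shape)  # torch.Size([1, 1, 4, 4])

Generating the mask per call sizes it to the actual batch sequence length and removes the fixed context_length × context_length buffer, which is why the static_tables plumbing in the model below can also be dropped.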
1 change: 0 additions & 1 deletion sharktank/sharktank/models/llama/llama.py
@@ -67,7 +67,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         super().__init__(
             theta,
             context_length=config.hp.context_length,
-            static_tables=config.static_tables,
             device=config.device,
             activation_dtype=config.activation_dtype,
             attention_dtype=config.attention_dtype,