Shard input_mask for Llama #905

Draft · stbaione wants to merge 1 commit into main

Conversation

@stbaione (Contributor) commented Feb 3, 2025

Currently, the 405b model OOMs when using long input prompts (issue here).

This PR implements the suggested fix of sharding the input_mask, which makes sense since the OOM depends on the length of the input.

I'm new to sharktank and am still seeing the issue with the current implementation, so I wanted to have it double-checked to make sure it's properly implemented.

MLIR here
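
For context, a minimal sketch of what replicating the mask across devices could look like; the helper name, the use of ops.replicate, and the shard-count plumbing are assumptions for illustration, not taken from this PR:

```python
import torch
from sharktank import ops
from sharktank.types import ReplicatedTensor


# Hypothetical helper (not from this PR): build the boolean padding mask on the
# host, then replicate it so each shard holds its own copy of the
# [bs, batch_seq_len] mask rather than pulling it through a single device.
def make_replicated_input_mask(
    seq_lens: torch.Tensor, batch_seq_len: int, shard_count: int
) -> ReplicatedTensor:
    positions = torch.arange(batch_seq_len).unsqueeze(0)  # [1, batch_seq_len]
    mask = positions >= seq_lens.unsqueeze(1)             # [bs, batch_seq_len], True = padding
    return ops.replicate(mask, count=shard_count)
```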

@sogartar (Contributor) left a comment


One thing that seems unaddressed is scaled_dot_product_attention, which also applies the mask. How would that be reconciled? If is_causal=False and attention_mask=None, does it then apply no mask at all?
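
For reference, in torch's F.scaled_dot_product_attention (shown standalone here, independent of any sharktank wrapper), passing is_causal=False with attn_mask=None applies no mask at all, so every query position attends to every key position, padding included. A small illustration with arbitrary shapes:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 8, 16)  # [bs, heads, seq_len, head_dim]
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

# No mask of any kind: attention is a plain softmax over all 8 key positions.
unmasked = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)

# The usual alternative: an additive bias with -inf at positions to exclude,
# e.g. treating the last two key positions as padding.
bias = torch.zeros(1, 1, 8, 8)
bias[..., 6:] = float("-inf")
masked = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
```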

@@ -129,20 +129,25 @@ def prefill(
# [bs, batch_seq_len]
tokens: Union[torch.Tensor, ReplicatedTensor],
*,
# [bs, batch_seq_len]
input_mask: Union[torch.Tensor, ReplicatedTensor],

This input needs a check so that it is mutually exclusive with the attention_mask arg.
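
One possible shape for that guard, sketched with this diff's argument names (the error type and message are assumptions):

```python
# Hypothetical guard (not from this PR): reject calls that pass both masks,
# since it would be ambiguous which one drives the attention computation.
if input_mask is not None and attention_mask is not None:
    raise ValueError(
        "input_mask and attention_mask are mutually exclusive; pass only one."
    )
```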

@@ -166,6 +171,8 @@ def decode(
# [bs, 1]
tokens: Union[torch.Tensor, ReplicatedTensor],
*,
# [bs, 1]
input_mask: Union[torch.Tensor, ReplicatedTensor],

This input needs a check so that it is mutually exclusive with the attention_mask arg.

self._assert_device(attention_mask, dtype=self.activation_dtype)
self._assert_device(seq_block_ids)
self._assert_device(*cache_state, dtype=self.activation_dtype)

h = self.token_embedding(tokens)
self.trace_tensor("llama.token_embedding", h)

h *= input_mask.unsqueeze(-1)

I don't think the math checks out if we assume this is meant to substitute for the attention mask.
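
Concretely, zeroing the embeddings at padded positions is not equivalent to masking those positions in attention: a zeroed key still receives exp(0)/Z softmax weight, which dilutes the attention paid to real tokens (and later bias/normalization layers reintroduce nonzero activations anyway). A standalone torch sketch of the gap:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 1, 4, 8)  # one head, 4 positions; last two are "padding"
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# "Masking" by zeroing the padded key/value vectors, analogous to zeroing h:
k0, v0 = k.clone(), v.clone()
k0[..., 2:, :] = 0
v0[..., 2:, :] = 0
zeroed = F.scaled_dot_product_attention(q, k0, v0)

# A real attention mask removes those positions from the softmax entirely:
bias = torch.zeros(1, 1, 1, 4)
bias[..., 2:] = float("-inf")
masked = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

print(torch.allclose(zeroed, masked))  # False: the zeroed keys still get weight
```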
