Remove bad check (#683)
magdyksaleh authored Nov 15, 2024
1 parent 7e8af35 commit e4394d6
Showing 2 changed files with 3 additions and 11 deletions.
@@ -501,7 +501,7 @@ def forward(


 class FlashGemma2ForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix, config, weights, causal=True):
         super().__init__()
         self.config = config
server/lorax_server/models/flash_causal_lm.py (12 changes: 2 additions & 10 deletions)
@@ -1114,15 +1114,6 @@ def __init__(
         weights._set_config(model_id, config)

         self._supports_embeddings = embedding_dim is not None
-        if (
-            not (weights.has_tensor("lm_head.weight") or weights.has_tensor("language_model.lm_head.weight"))
-            and not self._supports_embeddings
-        ):
-            raise ValueError(
-                "Model does not have lm head so it is presumed to be for embeddings."
-                "No embedding_dim was provided so we cannot load the model."
-                "Please pass in an embedding_dim to the model."
-            )

         prefix = ""
         model = model_cls(prefix, config, weights)
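The deleted guard raised a ValueError whenever the checkpoint exposed neither `lm_head.weight` nor `language_model.lm_head.weight` and no `embedding_dim` was provided, which presumably rejected valid causal-LM checkpoints whose head is stored under a different tensor name. After the change, `self._supports_embeddings` is determined by `embedding_dim` alone. A sketch of the removed check in isolation (the `weights` argument is a hypothetical stand-in exposing only `has_tensor`):

def removed_guard(weights, embedding_dim):
    """Sketch of the check this commit deletes; not the real loader code."""
    supports_embeddings = embedding_dim is not None
    has_lm_head = weights.has_tensor("lm_head.weight") or weights.has_tensor(
        "language_model.lm_head.weight"
    )
    if not has_lm_head and not supports_embeddings:
        # Any checkpoint whose LM head tensor uses another name ended up here,
        # even though it could have been loaded as a causal LM.
        raise ValueError(
            "Model does not have lm head so it is presumed to be for embeddings."
        )
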
@@ -1750,7 +1741,8 @@ def generate_token(
                # Only save tokens if we are done prefilling for this request
                batch.all_input_ids_tensor[
                    i,
-                    batch.cache_lengths_tensor[i] + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+                    batch.cache_lengths_tensor[i]
+                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
                    + batch.input_lengths[i]
                    + accepted_ids[i],
                ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
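This second hunk only re-wraps the slice bounds; the behavior is unchanged. For request `i`, the newly accepted tokens are written into `all_input_ids_tensor` starting right after the tokens already accounted for (`cache_lengths_tensor[i] + input_lengths[i]`) and spanning `accepted_ids[i]` positions, pulled from the flat `next_input_ids` buffer via the cumulative offsets `cu_accepted_ids`. A self-contained sketch of that indexing with made-up values (the names mirror the batch fields, but the data is illustrative):

import torch

# Illustrative batch of 2 requests; all values are made up.
cache_lengths = torch.tensor([4, 2])          # tokens already in the KV cache
input_lengths = torch.tensor([3, 5])          # tokens processed this step
accepted_ids = torch.tensor([2, 1])           # speculative tokens accepted per request
cu_accepted_ids = torch.tensor([0, 2, 3])     # cumulative offsets into next_input_ids
next_input_ids = torch.tensor([11, 12, 21])   # accepted tokens, flattened across requests

all_input_ids = torch.zeros((2, 16), dtype=torch.long)

for i in range(2):
    start = cache_lengths[i] + input_lengths[i]
    # Write this request's accepted tokens just past its existing tokens.
    all_input_ids[i, start : start + accepted_ids[i]] = next_input_ids[
        cu_accepted_ids[i] : cu_accepted_ids[i + 1]
    ]

print(all_input_ids[0, 7:9])  # tensor([11, 12])
print(all_input_ids[1, 7:8])  # tensor([21])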