diff --git a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py
index 2aca92c7d..904ec3366 100644
--- a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -501,7 +501,7 @@ def forward(
 
 
 class FlashGemma2ForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix, config, weights, causal=True):
         super().__init__()
 
         self.config = config
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index ee95a327a..e7e295fd7 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -1114,15 +1114,6 @@ def __init__(
         weights._set_config(model_id, config)
         self._supports_embeddings = embedding_dim is not None
 
-        if (
-            not (weights.has_tensor("lm_head.weight") or weights.has_tensor("language_model.lm_head.weight"))
-            and not self._supports_embeddings
-        ):
-            raise ValueError(
-                "Model does not have lm head so it is presumed to be for embeddings."
-                "No embedding_dim was provided so we cannot load the model."
-                "Please pass in an embedding_dim to the model."
-            )
 
         prefix = ""
         model = model_cls(prefix, config, weights)
@@ -1750,7 +1741,8 @@ def generate_token(
                 # Only save tokens if we are done prefilling for this request
                 batch.all_input_ids_tensor[
                     i,
-                    batch.cache_lengths_tensor[i] + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+                    batch.cache_lengths_tensor[i]
+                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
                     + batch.input_lengths[i]
                     + accepted_ids[i],
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
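
The reflowed slice in the generate_token hunk keeps the same semantics: for request i, the accepted speculative tokens are written into all_input_ids_tensor just past the tokens that were already cached and processed. Below is a minimal standalone sketch of that indexing with hypothetical toy values; the variable names mirror the hunk above, but the shapes and numbers are illustrative only, not how the batch object actually builds them.

    import torch

    # Toy setup: 2 requests, room for 16 token ids per request.
    all_input_ids_tensor = torch.zeros((2, 16), dtype=torch.long)

    cache_lengths_tensor = torch.tensor([4, 0])  # tokens already in the KV cache
    input_lengths = [3, 5]                       # tokens processed this step
    accepted_ids = torch.tensor([2, 1])          # speculative tokens accepted per request

    # Newly accepted token ids flattened across requests, with prefix sums as offsets.
    next_input_ids = torch.tensor([101, 102, 201])
    cu_accepted_ids = torch.tensor([0, 2, 3])

    for i in range(all_input_ids_tensor.shape[0]):
        start = cache_lengths_tensor[i] + input_lengths[i]
        end = start + accepted_ids[i]
        # Mirrors the slice assignment in the hunk above: write this request's
        # accepted tokens immediately after the tokens it has already seen.
        all_input_ids_tensor[i, start:end] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]

    print(all_input_ids_tensor)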