From 4328d32e0c4d0effd51af89e563664781a8b953b Mon Sep 17 00:00:00 2001
From: Timothy Wang
Date: Tue, 15 Oct 2024 12:43:54 -0400
Subject: [PATCH] Look for language model lm head

---
 server/lorax_server/models/flash_causal_lm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 99821bed7..2626d30b8 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -838,7 +838,7 @@ def __init__(
         weights._set_config(model_id, config)

         self._supports_embeddings = embedding_dim is not None
-        if not weights.has_tensor("lm_head.weight") and not self._supports_embeddings:
+        if not (weights.has_tensor("lm_head.weight") or weights.has_tensor("language_model.lm_head.weight")) and not self._supports_embeddings:
             raise ValueError(
                 "Model does not have lm head so it is presumed to be for embeddings."
                 "No embedding_dim was provided so we cannot load the model."
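
For context on what the one-line change does: some checkpoints (typically multimodal models that wrap a decoder) store the output head under a `language_model.` prefix rather than at the top level, so the guard now accepts either tensor name. Below is a minimal, self-contained sketch of the updated check; the `Weights` stub and `check_lm_head` helper are hypothetical stand-ins for illustration, not the real lorax classes.

```python
class Weights:
    """Hypothetical stub; the real lorax Weights class reads safetensors shards."""

    def __init__(self, tensor_names):
        self._names = set(tensor_names)

    def has_tensor(self, name: str) -> bool:
        # The real implementation checks the loaded shard index.
        return name in self._names


def check_lm_head(weights: Weights, embedding_dim=None) -> None:
    supports_embeddings = embedding_dim is not None
    # After the patch: accept the top-level name or the nested name used by
    # checkpoints that place the decoder under a `language_model.` prefix.
    if not (
        weights.has_tensor("lm_head.weight")
        or weights.has_tensor("language_model.lm_head.weight")
    ) and not supports_embeddings:
        raise ValueError(
            "Model does not have lm head so it is presumed to be for embeddings."
            "No embedding_dim was provided so we cannot load the model."
        )


# Both layouts now pass the guard; before the patch, the second raised.
check_lm_head(Weights(["lm_head.weight"]))
check_lm_head(Weights(["language_model.lm_head.weight"]))
```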