From 231e5bbcc5e1ea5f253002488587eacacd8f5e55 Mon Sep 17 00:00:00 2001
From: q yao
Date: Wed, 9 Oct 2024 19:23:14 +0800
Subject: [PATCH] Fix llama3.2-1b inference error by handling
 tie_word_embedding (#2568)

---
 lmdeploy/pytorch/models/llama.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py
index 2641429683..a933e60825 100644
--- a/lmdeploy/pytorch/models/llama.py
+++ b/lmdeploy/pytorch/models/llama.py
@@ -371,6 +371,11 @@ def forward(
         )
         return hidden_states
 
+    def update_weights(self):
+        """update weights."""
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+
     def get_logits(self, hidden_states: torch.Tensor):
         """compute logits of the model output."""
         return self.lm_head(hidden_states)