diff --git a/nemo/collections/llm/gpt/model/hf_llama_embedding.py b/nemo/collections/llm/gpt/model/hf_llama_embedding.py
index ba89626ff45f..bbd27ce60507 100644
--- a/nemo/collections/llm/gpt/model/hf_llama_embedding.py
+++ b/nemo/collections/llm/gpt/model/hf_llama_embedding.py
@@ -156,7 +156,7 @@ def forward(
 
         loss = None
         if labels is not None:
-            labels = labels.to(logits.device)
+            labels = labels.to(pooled_logits.device)
             if self.config.problem_type is None:
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
diff --git a/nemo/collections/llm/gpt/model/llama_embedding.py b/nemo/collections/llm/gpt/model/llama_embedding.py
index 3d8edcc5121a..96f311acd0b8 100644
--- a/nemo/collections/llm/gpt/model/llama_embedding.py
+++ b/nemo/collections/llm/gpt/model/llama_embedding.py
@@ -31,12 +31,15 @@
 from nemo.collections.llm.utils import Config
 from nemo.lightning import OptimizerModule, io
 from nemo.lightning.pytorch.utils import dtype_from_hf
+from nemo.utils import logging
 from nemo.utils.import_utils import safe_import
 
 if TYPE_CHECKING:
     from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
 
     from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+    from nemo.collections.llm.gpt.model.hf_llama_embedding import LlamaBidirectionalModel
+
 
 _, HAVE_TE = safe_import("transformer_engine")
 
@@ -271,7 +274,7 @@ class LlamaEmbeddingExporter(io.ModelConnector[LlamaEmbeddingModel, "LlamaBidire
     Note that NV Embedding LLama uses customized LlamaBidirectionalConfig config.
     """
 
-    def init(self, dtype=torch.bfloat16) -> "LlamaForCausalLM":
+    def init(self, dtype=torch.bfloat16) -> "LlamaBidirectionalModel":
         from transformers.modeling_utils import no_init_weights
 
         from nemo.collections.llm.gpt.model.hf_llama_embedding import LlamaBidirectionalModel
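
For context, here is a minimal, self-contained sketch of the device-alignment pattern behind the first hunk. This is illustrative code, not NeMo's implementation; the function name pooled_classification_loss and the tensor shapes are assumptions. The point it shows: the loss is computed against the pooled logits, so labels should be moved to the device of pooled_logits rather than the full logits tensor.

    import torch
    import torch.nn as nn


    def pooled_classification_loss(logits: torch.Tensor,
                                   labels: torch.Tensor,
                                   last_token_index: torch.Tensor) -> torch.Tensor:
        """Pool the last non-pad token's logits per sequence and compute CE loss.

        Assumed shapes (illustration only):
          logits:           [batch, seq_len, num_labels]
          labels:           [batch] (int64 class ids)
          last_token_index: [batch] (index of the last non-pad token per sequence)
        """
        batch = torch.arange(logits.size(0), device=logits.device)
        last_token_index = last_token_index.to(logits.device)
        pooled_logits = logits[batch, last_token_index]  # [batch, num_labels]

        # The loss uses pooled_logits, so align labels with *its* device --
        # this mirrors the logits -> pooled_logits change in the hunk above.
        labels = labels.to(pooled_logits.device)
        return nn.functional.cross_entropy(pooled_logits, labels)

Since pooled_logits is derived from logits they normally share a device, but targeting the tensor actually consumed by the loss keeps the code correct if the pooling step ever moves data.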