Minor Bug Fixes - LLaMa Embedding (#12146)
* Minor Bug Fixes - LLaMa Embedding

Signed-off-by: Sam Oluwalana <[email protected]>

* Apply isort and black reformatting

Signed-off-by: artbataev <[email protected]>

* Add type checking

Signed-off-by: Sam Oluwalana <[email protected]>

---------

Signed-off-by: Sam Oluwalana <[email protected]>
Signed-off-by: artbataev <[email protected]>
Co-authored-by: artbataev <[email protected]>
soluwalana and artbataev authored Feb 12, 2025
1 parent d977f42 commit a682ea9
Showing 2 changed files with 5 additions and 2 deletions.
nemo/collections/llm/gpt/model/hf_llama_embedding.py (2 changes: 1 addition & 1 deletion)

@@ -156,7 +156,7 @@ def forward(

        loss = None
        if labels is not None:
-            labels = labels.to(logits.device)
+            labels = labels.to(pooled_logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
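The fix above moves labels onto the device of pooled_logits, the tensor that actually enters the loss computation, rather than the full per-token logits. Below is a minimal, self-contained sketch of the same pattern; the toy module, shapes, and last-token pooling are illustrative assumptions, not the NeMo implementation.

import torch
import torch.nn as nn

class ToySequenceClassifier(nn.Module):
    """Toy classifier that pools per-token logits before computing the loss (illustrative only)."""

    def __init__(self, hidden_size: int = 16, num_labels: int = 3):
        super().__init__()
        self.score = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states, labels=None):
        logits = self.score(hidden_states)   # (batch, seq_len, num_labels)
        pooled_logits = logits[:, -1, :]     # (batch, num_labels), last-token pooling
        loss = None
        if labels is not None:
            # Align labels with the tensor used in the loss, mirroring the change
            # from logits.device to pooled_logits.device in the diff above.
            labels = labels.to(pooled_logits.device)
            loss = nn.CrossEntropyLoss()(pooled_logits, labels)
        return loss, pooled_logits

model = ToySequenceClassifier()
loss, pooled = model(torch.randn(2, 5, 16), labels=torch.tensor([0, 2]))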
nemo/collections/llm/gpt/model/llama_embedding.py (5 changes: 4 additions & 1 deletion)

@@ -31,12 +31,15 @@
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io
from nemo.lightning.pytorch.utils import dtype_from_hf
from nemo.utils import logging
from nemo.utils.import_utils import safe_import

if TYPE_CHECKING:
    from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel

    from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
    from nemo.collections.llm.gpt.model.hf_llama_embedding import LlamaBidirectionalModel

_, HAVE_TE = safe_import("transformer_engine")


@@ -271,7 +274,7 @@ class LlamaEmbeddingExporter(io.ModelConnector[LlamaEmbeddingModel, "LlamaBidire
    Note that NV Embedding LLama uses customized LlamaBidirectionalConfig config.
    """

-    def init(self, dtype=torch.bfloat16) -> "LlamaForCausalLM":
+    def init(self, dtype=torch.bfloat16) -> "LlamaBidirectionalModel":
        from transformers.modeling_utils import no_init_weights

        from nemo.collections.llm.gpt.model.hf_llama_embedding import LlamaBidirectionalModel
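Taken together, the TYPE_CHECKING-only import added above and the corrected return annotation on init follow a standard Python pattern: the class name is visible to static type checkers through the guarded import, the quoted annotation stays an unevaluated forward reference at runtime, and the real import happens lazily inside the method. A generic, runnable sketch of that pattern, using Decimal as a stand-in class rather than the NeMo types:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers such as mypy or pyright; never executed at
    # runtime, so a heavy or circular dependency here adds no import cost.
    from decimal import Decimal  # stand-in for LlamaBidirectionalModel

class Exporter:
    def init(self, cents: int = 250) -> "Decimal":
        # The quoted annotation is a forward reference and is not evaluated at
        # runtime; the concrete class is imported where it is actually used.
        from decimal import Decimal
        return Decimal(cents) / 100

price = Exporter().init()  # Decimal('2.5')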
