Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions libs/langchain/langchain/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,29 @@ def _init_chat_model_helper(

return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
try:
from langchain_huggingface.chat_models import ChatHuggingFace
from langchain_huggingface.llms import HuggingFacePipeline
except ImportError as e:
import_error_msg = "Please install langchain-huggingface to use HuggingFace models."
raise ImportError(import_error_msg) from e

# The 'task' kwarg is required by from_model_id but not the base constructor.
# We pop it from kwargs to avoid the Pydantic 'extra_forbidden' error.
task = kwargs.pop("task", None)
if not task:
task_error_msg = "The 'task' keyword argument is required for HuggingFace models."
raise ValueError(task_error_msg)

# Initialize the base LLM pipeline with the model and arguments
llm = HuggingFacePipeline.from_model_id(
model_id=model,
task=task,
**kwargs, # Pass remaining kwargs like `device`
)

return ChatHuggingFace(model_id=model, **kwargs)
# Pass the initialized LLM to the chat wrapper
return ChatHuggingFace(llm=llm)
if model_provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
Expand Down Expand Up @@ -957,4 +976,4 @@ def with_structured_output(
schema: Union[dict, type[BaseModel]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
return self.__getattr__("with_structured_output")(schema, **kwargs)
return self.__getattr__("with_structured_output")(schema, **kwargs)
14 changes: 14 additions & 0 deletions libs/langchain/tests/unit_tests/chat_models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnableSequence
from langchain_huggingface.chat_models import ChatHuggingFace
from pydantic import SecretStr

from langchain.chat_models.base import __all__, init_chat_model
Expand Down Expand Up @@ -289,3 +290,16 @@ def test_configurable_with_default() -> None:
prompt = ChatPromptTemplate.from_messages([("system", "foo")])
chain = prompt | model_with_config
assert isinstance(chain, RunnableSequence)


def test_init_chat_model_huggingface() -> None:
    """Test that init_chat_model works with huggingface.

    Initializes a HuggingFace-backed chat model via the generic factory and
    verifies both the wrapper type and that the underlying pipeline was
    built for the requested model id.
    """
    # NOTE(review): this resolves a real model id ("google-bert/bert-base-uncased"),
    # which presumably requires network/model download in a unit test — confirm
    # this is intended for the unit-test suite rather than integration tests.
    model_id = "google-bert/bert-base-uncased"

    chat_model = init_chat_model(
        model=model_id,
        model_provider="huggingface",
        task="text-generation",
    )

    # The factory must return the ChatHuggingFace wrapper, and the wrapped
    # HuggingFacePipeline must carry the model id we asked for.
    assert isinstance(chat_model, ChatHuggingFace)
    assert chat_model.llm.model_id == model_id