Skip to content

Commit

Permalink
fix(agent): langchain LLM instantiation (#977)
Browse files Browse the repository at this point in the history
* fix(agent): langchain LLM instantiation

* replace object with mock
  • Loading branch information
mspronesti authored Mar 1, 2024
1 parent 6587536 commit 8c79f9b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 16 deletions.
6 changes: 3 additions & 3 deletions pandasai/agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..helpers.logger import Logger
from ..helpers.memory import Memory
from ..llm.base import LLM
from ..llm.langchain import LangchainLLM
from ..llm.langchain import LangchainLLM, is_langchain_llm
from ..pipelines.chat.generate_chat_pipeline import (
GenerateChatPipeline,
)
Expand Down Expand Up @@ -141,7 +141,7 @@ def get_config(self, config: Union[Config, dict]):

config = load_config_from_json(config)

if isinstance(config, dict) and config.get("llm") is None:
if isinstance(config, dict) and config.get("llm") is not None:
config["llm"] = self.get_llm(config["llm"])

config = Config(**config)
Expand All @@ -161,7 +161,7 @@ def get_llm(self, llm: LLM) -> LLM:
BadImportError: If the LLM is a Langchain LLM but the langchain package
is not installed
"""
if LangchainLLM.is_langchain_llm(llm):
if is_langchain_llm(llm):
llm = LangchainLLM(llm)

return llm
Expand Down
2 changes: 0 additions & 2 deletions pandasai/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
"LLM",
"AzureOpenAI",
"OpenAI",
"Falcon",
"GoogleGemini",
"GooglePalm",
"GoogleVertexAI",
"HuggingFaceTextGen",
Expand Down
22 changes: 11 additions & 11 deletions pandasai/llm/langchain.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from __future__ import annotations

try:
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM

LANGCHAIN_AVAILABLE = True
except ImportError:
from unittest.mock import Mock

# Fallback definitions if langchain_core is not installed
BaseLLM = BaseChatModel = object
LANGCHAIN_AVAILABLE = False
BaseLanguageModel = BaseChatModel = Mock

from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING

from pandasai.prompts.base import BasePrompt

Expand All @@ -29,15 +29,19 @@
"""


def is_langchain_llm(llm) -> bool:
return isinstance(llm, BaseLanguageModel)


class LangchainLLM(LLM):
"""
Class to wrap Langchain LLMs and make PandasAI interoperable
with LangChain.
"""

langchain_llm = None
langchain_llm: BaseLanguageModel

def __init__(self, langchain_llm: Union[BaseLLM, BaseChatModel]):
def __init__(self, langchain_llm: BaseLanguageModel):
self.langchain_llm = langchain_llm

def call(
Expand All @@ -51,10 +55,6 @@ def call(
res = self.langchain_llm.invoke(prompt)
return res.content if isinstance(self.langchain_llm, BaseChatModel) else res

@staticmethod
def is_langchain_llm(llm: LLM) -> bool:
return hasattr(llm, "_llm_type")

@property
def type(self) -> str:
return f"langchain_{self.langchain_llm._llm_type}"

0 comments on commit 8c79f9b

Please sign in to comment.