Skip to content

Commit

Permalink
Feat/support langchain 0 3 (#69)
Browse files Browse the repository at this point in the history
* feat: Add langchain 0.3 support
  • Loading branch information
valeriosofi authored Oct 7, 2024
1 parent 3682361 commit 4e58c1c
Show file tree
Hide file tree
Showing 4 changed files with 2,054 additions and 1,746 deletions.
19 changes: 19 additions & 0 deletions nebuly/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Event(abc.ABC):
module: str
start_time: datetime
end_time: datetime | None = None
input_llm: str | None = None

@property
@abc.abstractmethod
Expand Down Expand Up @@ -118,3 +119,21 @@ def delete_events(self, root_id: UUID) -> None:

for key in keys_to_delete:
self.events.pop(key)

def update_chain_input(self, prompt: str) -> None:
"""
This method overrides the input of the chain event
"""
chain_events = [
event for event in self.events.values() if event.data.type.value == "chain"
]
if len(chain_events) == 1:
chain_input = chain_events[0].input

if (
isinstance(chain_input, str)
and len(chain_input) > 0
and chain_input in prompt
and chain_input != prompt
):
chain_events[0].input_llm = prompt
29 changes: 25 additions & 4 deletions nebuly/providers/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,12 @@ class EventType(Enum):

@dataclass
class LangChainEvent(Event):
input_llm = None

@property
def input(self) -> str:
def input(self) -> str: # pylint: disable=too-many-return-statements
if self.input_llm is not None:
return self.input_llm
if self.data.kwargs is None:
raise ValueError("Event has no kwargs.")
if self.data.type is EventType.CHAT_MODEL:
Expand All @@ -221,6 +225,8 @@ def input(self) -> str:
raise ValueError("Event has no inputs.")
try:
chain = load(self.data.kwargs.get("serialized", {}))
if chain is None:
return _parse_langchain_data(inputs)
if isinstance(chain, RunnableSequence):
return _get_input_and_history_runnable_seq(chain, inputs).prompt
return _get_input_and_history(chain, inputs).prompt
Expand All @@ -240,6 +246,8 @@ def output(self) -> str:
raise ValueError("Event has no kwargs.")
try:
chain = load(self.data.kwargs.get("serialized", {}))
if chain is None:
return _parse_langchain_data(self.data.output)
if isinstance(chain, RunnableSequence):
return _parse_output(self.data.output)
return _get_output_chain(chain, self.data.output)
Expand All @@ -249,7 +257,9 @@ def output(self) -> str:
raise ValueError(f"Event type {self.data.type} not supported.")

@property
def history(self) -> list[HistoryEntry]:
def history( # pylint: disable=too-many-return-statements
self,
) -> list[HistoryEntry]:
if self.data.kwargs is None:
raise ValueError("Event has no kwargs.")

Expand Down Expand Up @@ -283,6 +293,8 @@ def history(self) -> list[HistoryEntry]:
raise ValueError("Event has no inputs.")
try:
chain = load(self.data.kwargs.get("serialized", {}))
if chain is None:
return _parse_langchain_history(inputs)
if isinstance(chain, RunnableSequence):
return _get_input_and_history_runnable_seq(chain, inputs).history
return _get_input_and_history(chain, inputs).history
Expand Down Expand Up @@ -317,7 +329,10 @@ def _get_function(self) -> str:
raise ValueError("Event has no kwargs.")
if self.data.type is EventType.TOOL:
return self.data.kwargs["serialized"]["name"] # type: ignore
return ".".join(self.data.kwargs["serialized"]["id"])
try:
return ".".join(self.data.kwargs["serialized"]["id"])
except TypeError:
return "unknown"

def _get_rag_source(self) -> str | None:
if self.data.kwargs is None or len(self.data.kwargs) == 0:
Expand Down Expand Up @@ -465,6 +480,10 @@ def on_llm_start( # pylint: disable=arguments-differ
) -> None:
if self.verbose:
logger.info("LLM model started with %d prompts", len(prompts))

if len(prompts) == 1:
self._events_storage.update_chain_input(prompts[0])

data = EventData(
type=EventType.LLM_MODEL,
kwargs={
Expand Down Expand Up @@ -558,7 +577,9 @@ def on_chain_end( # pylint: disable=arguments-differ
self._events_storage.delete_events(run_id)

def send_pending_interaction(self, output: dict[str, str]) -> None:
for run_id in self._events_storage.events: # pylint: disable=consider-using-dict-items
for (
run_id
) in self._events_storage.events: # pylint: disable=consider-using-dict-items
if self._events_storage.events[run_id].hierarchy is None:
self._events_storage.events[run_id].data.add_end_event_data(
kwargs={}, output=output
Expand Down
Loading

0 comments on commit 4e58c1c

Please sign in to comment.