From e9332e176ed2cc0ce19ed2449358ee05fd97b43f Mon Sep 17 00:00:00 2001
From: Isaac Wasserman
Date: Thu, 19 Sep 2024 15:24:58 -0400
Subject: [PATCH] added LangChain_Chat module

---
 pyproject.toml                        |  1 +
 src/vanna/flask/__init__.py           | 21 +++++++-
 src/vanna/langchain/__init__.py       |  1 +
 src/vanna/langchain/langchain_chat.py | 69 +++++++++++++++++++++++++++
 4 files changed, 90 insertions(+), 2 deletions(-)
 create mode 100644 src/vanna/langchain/__init__.py
 create mode 100644 src/vanna/langchain/langchain_chat.py

diff --git a/pyproject.toml b/pyproject.toml
index 25609a00..f2a0ab06 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,3 +53,4 @@ milvus = ["pymilvus[model]"]
 bedrock = ["boto3", "botocore"]
 weaviate = ["weaviate-client"]
 azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fastembed"]
+langchain = ["langchain>=0.3.0"]
diff --git a/src/vanna/flask/__init__.py b/src/vanna/flask/__init__.py
index 396f401a..56569331 100644
--- a/src/vanna/flask/__init__.py
+++ b/src/vanna/flask/__init__.py
@@ -1,3 +1,4 @@
+import importlib.metadata
 import json
 import logging
 import os
@@ -5,13 +6,13 @@
 import uuid
 from abc import ABC, abstractmethod
 from functools import wraps
-import importlib.metadata
 
 import flask
 import requests
 from flasgger import Swagger
 from flask import Flask, Response, jsonify, request, send_from_directory
 from flask_sock import Sock
+from langchain_core.messages import BaseMessage
 
 from ..base import VannaBase
 from .assets import css_content, html_content, js_content
@@ -190,7 +191,23 @@ def __init__(
         if self.debug:
 
             def log(message, title="Info"):
-                [ws.send(json.dumps({'message': message, 'title': title})) for ws in self.ws_clients]
+                if (
+                    isinstance(message, list)
+                    and len(message) > 0
+                    and isinstance(message[0], BaseMessage)
+                ):
+                    message = [dict(m) for m in message]
+                [
+                    ws.send(
+                        json.dumps(
+                            {
+                                "message": message,
+                                "title": title,
+                            }
+                        )
+                    )
+                    for ws in self.ws_clients
+                ]
 
             self.vn.log = log
 
diff --git a/src/vanna/langchain/__init__.py b/src/vanna/langchain/__init__.py
new file mode 100644
index 00000000..04110005
--- /dev/null
+++ b/src/vanna/langchain/__init__.py
@@ -0,0 +1 @@
+from .langchain_chat import LangChain_Chat
diff --git a/src/vanna/langchain/langchain_chat.py b/src/vanna/langchain/langchain_chat.py
new file mode 100644
index 00000000..cebe4a6b
--- /dev/null
+++ b/src/vanna/langchain/langchain_chat.py
@@ -0,0 +1,69 @@
+from typing import List
+
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+)
+
+from ..base import VannaBase
+
+
+class LangChain_Chat(VannaBase):
+    def __init__(self, chat_model: BaseChatModel, config=None):
+        VannaBase.__init__(self, config=config)
+        self.llm = chat_model
+        self.model_name = (
+            self.llm.model_name
+            if hasattr(self.llm, "model_name")
+            else type(self.llm).__name__
+        )
+
+    def system_message(self, message: str) -> any:
+        return SystemMessage(message)
+
+    def user_message(self, message: str) -> any:
+        return HumanMessage(message)
+
+    def assistant_message(self, message: str) -> any:
+        return AIMessage(message)
+
+    def count_prompt_tokens(
+        self, input_messages: List[BaseMessage], output_message: AIMessage
+    ) -> int:
+        # OpenAI
+        if (
+            "token_usage" in output_message.response_metadata
+            and "prompt_tokens" in output_message.response_metadata["token_usage"]
+        ):
+            return output_message.response_metadata["token_usage"]["prompt_tokens"]
+        # Anthropic
+        elif (
+            "usage" in output_message.response_metadata
+            and "input_tokens" in output_message.response_metadata["usage"]
+        ):
+            return output_message.response_metadata["usage"]["input_tokens"]
+        # Other
+        else:
+            num_tokens = 0
+            for message in input_messages:
+                num_tokens += len(message.content) / 4
+            return num_tokens
+
+    def submit_prompt(self, prompt: List[BaseMessage], **kwargs) -> str:
+        if prompt is None:
+            raise Exception("Prompt is None")
+
+        if len(prompt) == 0:
+            raise Exception("Prompt is empty")
+
+        response = self.llm.invoke(prompt)
+        num_tokens = self.count_prompt_tokens(prompt, response)
+        self.log(
+            f"Used model {self.model_name} for {num_tokens} tokens (approx)",
+            title="Model Used",
+        )
+
+        return response.content
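
Usage note for reviewers: below is a minimal sketch of how the new LangChain_Chat mixin could be combined with an existing Vanna vector store, assuming the ChromaDB backend and langchain-openai are installed. ChromaDB_VectorStore, ChatOpenAI, the model name, and the database path are illustrative assumptions and are not part of this patch; any LangChain BaseChatModel instance can be passed as chat_model.

# Sketch only; names below other than LangChain_Chat are assumptions for illustration.
from langchain_openai import ChatOpenAI  # any BaseChatModel implementation works here

from vanna.chromadb import ChromaDB_VectorStore  # assumed existing Vanna vector store
from vanna.langchain import LangChain_Chat


class MyVanna(ChromaDB_VectorStore, LangChain_Chat):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        # Pass any LangChain chat model; "gpt-4o-mini" is just an example choice.
        LangChain_Chat.__init__(
            self, chat_model=ChatOpenAI(model="gpt-4o-mini"), config=config
        )


vn = MyVanna()
vn.connect_to_sqlite("my-database.sqlite")  # hypothetical database path
print(vn.generate_sql("How many customers are there?"))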