Skip to content

Commit

Permalink
Upgrades LangChain to 0.0.277
Browse files Browse the repository at this point in the history
  • Loading branch information
3coins committed Sep 1, 2023
1 parent 1556a02 commit 7833fea
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
OpenAIEmbeddings,
)
from langchain.embeddings.base import Embeddings
from pydantic import BaseModel, Extra
from langchain.pydantic_v1 import BaseModel, Extra


class BaseEmbeddingsProvider(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Literal, Optional, get_args

import click
from pydantic import BaseModel
from langchain.pydantic_v1 import BaseModel

FORMAT_CHOICES_TYPE = Literal[
"code", "html", "image", "json", "markdown", "math", "md", "text"
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
)
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.utils import get_from_dict_or_env
from pydantic import BaseModel, Extra, root_validator


class EnvAuthStrategy(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"ipython",
"pydantic",
"importlib_metadata>=5.2.0",
"langchain==0.0.223",
"langchain==0.0.277",
"typing_extensions>=4.5.0",
"click~=8.0",
"jsonpath-ng>=1.5.3,<2",
Expand Down
19 changes: 15 additions & 4 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Awaitable, Coroutine, List, Optional, Tuple

from dask.distributed import Client as DaskClient
from jupyter_ai.config_manager import ConfigManager
from jupyter_ai.document_loaders.directory import get_embeddings, split
from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter
from jupyter_ai.models import (
Expand All @@ -29,7 +30,7 @@
METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, "metadata.json")


class LearnChatHandler(BaseChatHandler, BaseRetriever):
class LearnChatHandler(BaseChatHandler):
def __init__(
self, root_dir: str, dask_client_future: Awaitable[DaskClient], *args, **kwargs
):
Expand Down Expand Up @@ -266,9 +267,6 @@ def load_metadata(self):
j = json.loads(f.read())
self.metadata = IndexMetadata(**j)

def get_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError()

async def aget_relevant_documents(
self, query: str
) -> Coroutine[Any, Any, List[Document]]:
Expand All @@ -291,3 +289,16 @@ def get_embedding_model(self):
return None

return em_provider_cls(**em_provider_args)


class Retriever(BaseRetriever):
learn_chat_handler: LearnChatHandler = None

def _get_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError()

async def _aget_relevant_documents(
self, query: str
) -> Coroutine[Any, Any, List[Document]]:
docs = await self.learn_chat_handler.aget_relevant_documents(query)
return docs
6 changes: 3 additions & 3 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import time

from dask.distributed import Client as DaskClient
from jupyter_ai.chat_handlers.learn import Retriever
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_server.extension.application import ExtensionApp

Expand Down Expand Up @@ -93,9 +94,8 @@ def initialize_settings(self):
dask_client_future=dask_client_future,
)
help_chat_handler = HelpChatHandler(**chat_handler_kwargs)
ask_chat_handler = AskChatHandler(
**chat_handler_kwargs, retriever=learn_chat_handler
)
retriever = Retriever(learn_chat_handler=learn_chat_handler)
ask_chat_handler = AskChatHandler(**chat_handler_kwargs, retriever=retriever)
self.settings["jai_chat_handlers"] = {
"default": default_chat_handler,
"/ask": ask_chat_handler,
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jupyter_ai.chat_handlers import BaseChatHandler
from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
from jupyter_server.base.handlers import JupyterHandler
from pydantic import ValidationError
from langchain.pydantic_v1 import ValidationError
from tornado import web, websocket
from tornado.web import HTTPError

Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Literal, Optional, Union

from jupyter_ai_magics.providers import AuthStrategy, Field
from pydantic import BaseModel
from langchain.pydantic_v1 import BaseModel

DEFAULT_CHUNK_SIZE = 2000
DEFAULT_CHUNK_OVERLAP = 100
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"openai~=0.26",
"aiosqlite>=0.18",
"importlib_metadata>=5.2.0",
"langchain==0.0.223",
"langchain==0.0.277",
"tiktoken", # required for OpenAIEmbeddings
"jupyter_ai_magics",
"dask[distributed]",
Expand Down

0 comments on commit 7833fea

Please sign in to comment.