Loads vector store index lazily (#374)
3coins authored Sep 1, 2023
1 parent 93eb896 commit 1556a02
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -1,7 +1,7 @@
 import argparse
 import json
 import os
-from typing import Any, Awaitable, Coroutine, List
+from typing import Any, Awaitable, Coroutine, List, Optional, Tuple
 
 from dask.distributed import Client as DaskClient
 from jupyter_ai.document_loaders.directory import get_embeddings, split
@@ -59,10 +59,10 @@ def __init__(
         if not os.path.exists(INDEX_SAVE_DIR):
             os.makedirs(INDEX_SAVE_DIR)
 
-        self._load_or_create()
+        self._load()
 
-    def _load_or_create(self):
-        """Loads the vector store and creates a new one if none exists."""
+    def _load(self):
+        """Loads the vector store."""
         embeddings = self.get_embedding_model()
         if not embeddings:
             return
@@ -73,14 +73,12 @@ def _load_or_create(self):
             )
             self.load_metadata()
         except Exception as e:
-            self.create()
+            self.log.error("Could not load vector index from disk.")
 
     async def _process_message(self, message: HumanChatMessage):
-        if not self.index:
-            self._load_or_create()
-
-        # If index is not still there, embeddings are not present
-        if not self.index:
+        # If no embedding provider has been selected
+        em_provider_cls, em_provider_args = self.get_embedding_provider()
+        if not em_provider_cls:
             self.reply(
                 "Sorry, please select an embedding provider before using the `/learn` command."
             )
@@ -153,7 +151,11 @@ async def learn_dir(self, path: str, chunk_size: int, chunk_overlap: int):
         em_provider_cls, em_provider_args = self.get_embedding_provider()
         delayed = get_embeddings(doc_chunks, em_provider_cls, em_provider_args)
         embedding_records = await dask_client.compute(delayed)
-        self.index.add_embeddings(*embedding_records)
+        if self.index:
+            self.index.add_embeddings(*embedding_records)
+        else:
+            self.create(*embedding_records)
+
         self._add_dir_to_metadata(path, chunk_size, chunk_overlap)
         self.prev_em_id = em_provider_cls.id + ":" + em_provider_args["model_id"]
 
@@ -212,7 +214,6 @@ def delete(self):
         for path in paths:
             if os.path.isfile(path):
                 os.remove(path)
-        self.create()
 
     async def relearn(self, metadata: IndexMetadata):
         # Index all dirs in the metadata
@@ -234,15 +235,16 @@ async def relearn(self, metadata: IndexMetadata):
         You can ask questions about these docs by prefixing your message with **/ask**."""
         self.reply(message)
 
-    def create(self):
+    def create(
+        self,
+        embedding_records: List[Tuple[str, List[float]]],
+        metadatas: Optional[List[dict]] = None,
+    ):
         embeddings = self.get_embedding_model()
         if not embeddings:
             return
-        self.index = FAISS.from_texts(
-            [
-                "Jupyternaut knows about your filesystem, to ask questions first use the /learn command."
-            ],
-            embeddings,
+        self.index = FAISS.from_embeddings(
+            text_embeddings=embedding_records, embedding=embeddings, metadatas=metadatas
         )
         self.save()
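The load half of the new lifecycle: on startup the handler now only attempts to read a previously saved index from disk and logs a failure, instead of eagerly creating a placeholder index. A minimal sketch of that behavior, assuming langchain's FAISS.load_local; the save directory, index name, and load_index helper below are illustrative stand-ins, not jupyter-ai's actual API:

import logging
from typing import Optional

from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import FAISS

log = logging.getLogger(__name__)
INDEX_SAVE_DIR = "/tmp/jai_index"  # stand-in for the real save location
INDEX_NAME = "default"             # stand-in for the real index name

def load_index(embeddings) -> Optional[FAISS]:
    """Return a previously saved index, or None if none can be loaded."""
    try:
        return FAISS.load_local(INDEX_SAVE_DIR, embeddings, index_name=INDEX_NAME)
    except Exception:
        # As in this commit, a failed load is only reported; creating the
        # index is deferred until the first /learn actually needs it.
        log.error("Could not load vector index from disk.")
        return None

index = load_index(FakeEmbeddings(size=8))  # None on a fresh install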

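The create half: the first /learn now builds the index directly from precomputed (text, vector) records via FAISS.from_embeddings, and subsequent calls append to the existing index. A runnable sketch of that add-or-create pattern; FakeEmbeddings stands in for the user-selected embedding provider, and the module-level index and learn helper are illustrative only:

from typing import List, Optional, Tuple

from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import FAISS

embeddings = FakeEmbeddings(size=8)  # test helper; returns random vectors
index: Optional[FAISS] = None        # nothing exists until /learn runs

def learn(records: List[Tuple[str, List[float]]]) -> None:
    """Add (text, vector) records, creating the index lazily on first use."""
    global index
    if index:
        index.add_embeddings(records)
    else:
        index = FAISS.from_embeddings(text_embeddings=records, embedding=embeddings)

# Records arrive with vectors already attached (jupyter-ai computes them on
# dask workers), so from_embeddings never re-embeds the text.
learn([("first document", embeddings.embed_query("first document"))])
learn([("second document", embeddings.embed_query("second document"))])
print(index.similarity_search("first", k=1))

Building from real records is also what lets this commit drop the old placeholder sentence that FAISS.from_texts needed just to bootstrap an empty store.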
