Commit
replace all legacy connected loaders with split ones
latekvo committed May 24, 2024
1 parent edb313b commit 4c543d5
Showing 6 changed files with 21 additions and 35 deletions.
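
Taken together, the diffs below replace the combined load_model() helper, which returned an (llm, embedder) tuple that callers often partially discarded, with the separate load_llm() and load_embedder() functions, and make use_faiss() hand back the embedder it loaded alongside the vector store. The following is a minimal sketch of the resulting call-site pattern, assembled from the hunks in this commit rather than copied from any single file:

    from arguments import get_runtime_config
    from core.tools.model_loader import load_llm, load_embedder
    from core.tools.utils import use_faiss

    embedder_config = get_runtime_config().embedder_config

    # previously: llm, embeddings = load_model()  /  _, embedder = load_model()
    llm = load_llm()              # completion model only
    embeddings = load_embedder()  # embedding model only

    # use_faiss() now returns the embedder together with the vector store,
    # so callers such as db_embeddings.py no longer load one themselves
    vector_db, embedder = use_faiss("embeddings", embedder_config.model_name)
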
7 changes: 2 additions & 5 deletions core/databases/db_embeddings.py
@@ -1,16 +1,13 @@
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 
 from arguments import get_runtime_config
-from core.tools.model_loader import load_model
-from core.tools.utils import use_faiss, is_text_junk, remove_characters
+from core.tools.utils import use_faiss, is_text_junk
 
 runtime_configuration = get_runtime_config()
 llm_config = runtime_configuration.llm_config
 embedder_config = runtime_configuration.embedder_config
 
-_, embedder = load_model()
-
-vector_db = use_faiss("embeddings", embedder_config.model_name)
+vector_db, embedder = use_faiss("embeddings", embedder_config.model_name)
 
 text_splitter = RecursiveCharacterTextSplitter(
     separators=embedder_config.buffer_stops,
5 changes: 3 additions & 2 deletions core/lookup.py
@@ -2,6 +2,7 @@
 from langchain_core.runnables import RunnableLambda
 
 from arguments import get_runtime_config
+from core.tools.model_loader import load_llm, load_embedder
 from core.tools.utils import purify_name
 
 from core.chainables.web import (
@@ -13,12 +14,12 @@
     web_wiki_lookup_prompt,
 )
 from core.tools.dbops import get_vec_db_by_name
-from core.tools.model_loader import load_model
 
 
 output_parser = StrOutputParser()
 
-llm, embeddings = load_model()
+llm = load_llm()
+embeddings = load_embedder()
 
 runtime_configuration = get_runtime_config()
 llm_config = runtime_configuration.llm_config
16 changes: 0 additions & 16 deletions core/tools/model_loader.py
@@ -25,10 +25,6 @@ def load_ollama_embedder():
     return OllamaEmbeddings(model=embedder_config.model_name)
 
 
-def load_ollama_model():
-    return load_ollama_llm(), load_ollama_embedder()
-
-
 def load_hf_llm():
     base_model_path = hf_hub_download(
         llm_config.model_file, filename=llm_config.model_name
@@ -53,18 +49,6 @@ def load_hf_embedder():
     )
 
 
-def load_hugging_face_model():
-    return load_hf_llm(), load_hf_embedder()
-
-
-def load_model():
-    # todo: split up into separate llm and embedder functions
-    if llm_config.supplier == "hugging_face":
-        return load_hugging_face_model()
-    else:
-        return load_ollama_model()
-
-
 def load_llm():
     if llm_config.supplier == "hugging_face":
         return load_hf_llm()
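
With load_ollama_model(), load_hugging_face_model(), and load_model() removed, model_loader.py is left with the per-supplier helpers plus the two split entry points. Only the first lines of load_llm() are visible in this hunk, so the following is a sketch of the presumed remaining dispatch logic, not the file's exact contents:

    def load_llm():
        if llm_config.supplier == "hugging_face":
            return load_hf_llm()
        else:
            return load_ollama_llm()  # assumed fallback, mirroring the removed load_model()


    def load_embedder():
        # assumed to mirror load_llm(); the real file may key off embedder_config instead
        if llm_config.supplier == "hugging_face":
            return load_hf_embedder()
        else:
            return load_ollama_embedder()
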
7 changes: 5 additions & 2 deletions core/tools/scraper.py
@@ -9,7 +9,7 @@
 from colorama import Fore, Style
 
 from arguments import get_runtime_config
-from core.tools.model_loader import load_model
+from core.tools.model_loader import load_embedder, load_llm
 from core.tools.utils import purify_name
 from core.tools.dbops import get_vec_db_by_name
 from core.classes.query import WebQuery
@@ -19,10 +19,13 @@
 output_parser = StrOutputParser()
 
 runtime_configuration = get_runtime_config()
+
 llm_config = runtime_configuration.llm_config
 embedder_config = runtime_configuration.embedder_config
 
-llm, embeddings = load_model()
+llm = load_llm()
+embeddings = load_embedder()
+
 embedding_model_safe_name = purify_name(embedder_config.model_name)
 
 
7 changes: 4 additions & 3 deletions core/tools/utils.py
@@ -10,10 +10,11 @@
 
 from core.databases import defaults
 from core.tools.dbops import get_vec_db_by_name
-from core.tools.model_loader import load_model
+from core.tools.model_loader import load_embedder
 
 
 def purify_name(name):
+    # fixme awful code lol
     return "_".join("_".join(name.split(":")).split("-"))
 
 
@@ -117,12 +118,12 @@ def use_faiss(db_name, model_name):
     if not os.path.exists(data_path):
         os.makedirs(data_path)
 
-    _, embedder = load_model()
+    embedder = load_embedder()
 
     db_full_name = gen_vec_db_full_name(db_name, model_name)
     db = get_vec_db_by_name(db_full_name, embedder)
 
-    return db
+    return db, embedder
 
 
 class hide_prints:
14 changes: 7 additions & 7 deletions workers/summarizer.py
@@ -18,7 +18,7 @@
     web_news_lookup_prompt,
     web_wiki_lookup_prompt,
 )
-from core.tools.model_loader import load_model
+from core.tools.model_loader import load_llm, load_embedder
 from langchain_core.output_parsers import StrOutputParser
 
 from tinydb import Query
@@ -27,7 +27,8 @@
 
 output_parser = StrOutputParser()
 
-llm, embeddings = load_model()
+llm = load_llm()
+embeddings = load_embedder()
 
 
 # even though a single task takes a long time to complete,
@@ -108,14 +109,13 @@ def get_context(_: dict):
 
 previous_queued_tasks = None
 
-# todo: 1. get a list of available tasks, in the backend they'll be automatically set as executing
-# 2. parse through all of them, until one that has all it's dependencies resolved appears
-# 3. once one is found to be ready, release all the other tasks (reset 'executing')
-# 4. proceed with normal execution
+# 1. get a list of available tasks, in the backend they'll be automatically set as executing
+# 2. parse through all of them, until one that has all it's dependencies resolved appears
+# 3. once one is found to be ready, release all the other tasks (reset 'executing')
+# 4. proceed with normal execution
 
 # todo: implement class-based task management system
 
-
 while True:
     db = use_tinydb("completion_tasks")
     db_query = Query()
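
The renumbered comment in the last hunk describes the worker's task-claiming loop. As an illustration only, here is a sketch of that pattern over the tinydb store this file already opens; the 'executing' field and the dependency check are assumptions, not the repository's actual schema:

    from tinydb import Query

    def claim_ready_task(db):
        task = Query()
        # 1. take the batch of tasks the backend has handed out (marked as executing)
        candidates = db.search(task.executing == True)
        ready = None
        for candidate in candidates:
            # 2. stop at the first task whose dependencies are all resolved
            if all_dependencies_resolved(candidate):  # hypothetical helper
                ready = candidate
                break
        # 3. release every other task so other workers can pick them up
        for candidate in candidates:
            if ready is None or candidate.doc_id != ready.doc_id:
                db.update({"executing": False}, doc_ids=[candidate.doc_id])
        # 4. the caller proceeds with normal execution of the returned task (or None)
        return ready
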
