From 4c543d58f0335c98874c7f6801d2d5d133d81283 Mon Sep 17 00:00:00 2001 From: LatekVon Date: Fri, 24 May 2024 13:18:51 +0200 Subject: [PATCH] replace all legacy connected loaders with split ones --- core/databases/db_embeddings.py | 7 ++----- core/lookup.py | 5 +++-- core/tools/model_loader.py | 16 ---------------- core/tools/scraper.py | 7 +++++-- core/tools/utils.py | 7 ++++--- workers/summarizer.py | 14 +++++++------- 6 files changed, 21 insertions(+), 35 deletions(-) diff --git a/core/databases/db_embeddings.py b/core/databases/db_embeddings.py index eb4b2b9..d2dda3b 100644 --- a/core/databases/db_embeddings.py +++ b/core/databases/db_embeddings.py @@ -1,16 +1,13 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter from arguments import get_runtime_config -from core.tools.model_loader import load_model -from core.tools.utils import use_faiss, is_text_junk, remove_characters +from core.tools.utils import use_faiss, is_text_junk runtime_configuration = get_runtime_config() llm_config = runtime_configuration.llm_config embedder_config = runtime_configuration.embedder_config -_, embedder = load_model() - -vector_db = use_faiss("embeddings", embedder_config.model_name) +vector_db, embedder = use_faiss("embeddings", embedder_config.model_name) text_splitter = RecursiveCharacterTextSplitter( separators=embedder_config.buffer_stops, diff --git a/core/lookup.py b/core/lookup.py index 32ee1d8..9d74c2f 100644 --- a/core/lookup.py +++ b/core/lookup.py @@ -2,6 +2,7 @@ from langchain_core.runnables import RunnableLambda from arguments import get_runtime_config +from core.tools.model_loader import load_llm, load_embedder from core.tools.utils import purify_name from core.chainables.web import ( @@ -13,12 +14,12 @@ web_wiki_lookup_prompt, ) from core.tools.dbops import get_vec_db_by_name -from core.tools.model_loader import load_model output_parser = StrOutputParser() -llm, embeddings = load_model() +llm = load_llm() +embeddings = load_embedder() 
runtime_configuration = get_runtime_config() llm_config = runtime_configuration.llm_config diff --git a/core/tools/model_loader.py b/core/tools/model_loader.py index e4328e6..6d82093 100644 --- a/core/tools/model_loader.py +++ b/core/tools/model_loader.py @@ -25,10 +25,6 @@ def load_ollama_embedder(): return OllamaEmbeddings(model=embedder_config.model_name) -def load_ollama_model(): - return load_ollama_llm(), load_ollama_embedder() - - def load_hf_llm(): base_model_path = hf_hub_download( llm_config.model_file, filename=llm_config.model_name @@ -53,18 +49,6 @@ def load_hf_embedder(): ) -def load_hugging_face_model(): - return load_hf_llm(), load_hf_embedder() - - -def load_model(): - # todo: split up into separate llm and embedder functions - if llm_config.supplier == "hugging_face": - return load_hugging_face_model() - else: - return load_ollama_model() - - def load_llm(): if llm_config.supplier == "hugging_face": return load_hf_llm() diff --git a/core/tools/scraper.py b/core/tools/scraper.py index e370470..683fd0f 100644 --- a/core/tools/scraper.py +++ b/core/tools/scraper.py @@ -9,7 +9,7 @@ from colorama import Fore, Style from arguments import get_runtime_config -from core.tools.model_loader import load_model +from core.tools.model_loader import load_embedder, load_llm from core.tools.utils import purify_name from core.tools.dbops import get_vec_db_by_name from core.classes.query import WebQuery @@ -19,10 +19,13 @@ output_parser = StrOutputParser() runtime_configuration = get_runtime_config() + llm_config = runtime_configuration.llm_config embedder_config = runtime_configuration.embedder_config -llm, embeddings = load_model() +llm = load_llm() +embeddings = load_embedder() + embedding_model_safe_name = purify_name(embedder_config.model_name) diff --git a/core/tools/utils.py b/core/tools/utils.py index a6154c8..b20cac0 100644 --- a/core/tools/utils.py +++ b/core/tools/utils.py @@ -10,10 +10,11 @@ from core.databases import defaults from core.tools.dbops import 
get_vec_db_by_name -from core.tools.model_loader import load_model +from core.tools.model_loader import load_embedder def purify_name(name): + # FIXME: replace the chained split/join calls with a single regex substitution return "_".join("_".join(name.split(":")).split("-")) @@ -117,12 +118,12 @@ def use_faiss(db_name, model_name): if not os.path.exists(data_path): os.makedirs(data_path) - _, embedder = load_model() + embedder = load_embedder() db_full_name = gen_vec_db_full_name(db_name, model_name) db = get_vec_db_by_name(db_full_name, embedder) - return db + return db, embedder class hide_prints: diff --git a/workers/summarizer.py b/workers/summarizer.py index ead7b4a..91c1f56 100644 --- a/workers/summarizer.py +++ b/workers/summarizer.py @@ -18,7 +18,7 @@ web_news_lookup_prompt, web_wiki_lookup_prompt, ) -from core.tools.model_loader import load_model +from core.tools.model_loader import load_llm, load_embedder from langchain_core.output_parsers import StrOutputParser from tinydb import Query @@ -27,7 +27,8 @@ output_parser = StrOutputParser() -llm, embeddings = load_model() +llm = load_llm() +embeddings = load_embedder() # even though a single task takes a long time to complete, @@ -108,14 +109,13 @@ def get_context(_: dict): previous_queued_tasks = None -# todo: 1. get a list of available tasks, in the backend they'll be automatically set as executing -# 2. parse through all of them, until one that has all it's dependencies resolved appears -# 3. once one is found to be ready, release all the other tasks (reset 'executing') -# 4. proceed with normal execution +# 1. get a list of available tasks, in the backend they'll be automatically set as executing +# 2. parse through all of them, until one that has all its dependencies resolved appears +# 3. once one is found to be ready, release all the other tasks (reset 'executing') +# 4. proceed with normal execution # todo: implement class-based task management system - while True: db = use_tinydb("completion_tasks") db_query = Query()