From d18d05b8afd9d1faeaaa3b00010222c96cb1f31b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ignacy=20=C5=81=C4=85tka?=
Date: Tue, 21 May 2024 00:06:34 +0200
Subject: [PATCH] Add crawl task completion tracking (#29)

---
 core/databases/db_completion_tasks.py | 16 ++++--
 core/databases/db_crawl_tasks.py      | 65 ++++++++++++++++++++++-
 core/databases/db_embeddings.py       |  4 ++
 core/databases/db_url_pool.py         |  3 +-
 workers/crawler.py                    | 21 ++++++--
 workers/embedder.py                   |  7 ++-
 workers/summarizer.py                 | 74 +++++++++++++++++++++------
 7 files changed, 162 insertions(+), 28 deletions(-)

diff --git a/core/databases/db_completion_tasks.py b/core/databases/db_completion_tasks.py
index a46ee29..4d6ee23 100644
--- a/core/databases/db_completion_tasks.py
+++ b/core/databases/db_completion_tasks.py
@@ -19,6 +19,7 @@ def db_add_completion_task(prompt, mode):
             "completed": False,
             "completion_result": None,
             "executing": False,
+            "required_crawl_tasks": [],  # uuids of crawl tasks that must be completed first
             "completion_date": 0,
             "execution_date": 0,
             "timestamp": timestamp,
@@ -43,16 +44,23 @@ def db_set_completion_task_executing(uuid: str):
     )


-def db_get_incomplete_completion_task():
+def db_get_incomplete_completion_tasks(amount: int = 1):
     fields = Query()
-    results = db.get(fields.completed == False and fields.executing == False)

-    if results is not None:
-        db_set_completion_task_executing(results["uuid"])
+    results = db.search((fields.completed == False) & (fields.executing == False))
+    results = results[:amount]
+
+    for task in results:
+        db_set_completion_task_executing(task["uuid"])

     return results


+def db_release_executing_tasks(uuid_list: list[str]):
+    fields = Query()
+    db.update({"executing": False}, fields.uuid.one_of(uuid_list))
+
+
 def db_update_completion_task_after_summarizing(summary: str, uuid: str):
     fields = Query()
     db.update(
diff --git a/core/databases/db_crawl_tasks.py b/core/databases/db_crawl_tasks.py
index 717a2df..f1cb2e3 100644
--- a/core/databases/db_crawl_tasks.py
+++ b/core/databases/db_crawl_tasks.py
@@ -1,3 +1,5 @@
+from typing import Literal
+
 from tinydb import Query

 from core.tools import utils
@@ -8,7 +10,8 @@
 # we have to heartbeat our workers once we run out of tasks, websocks should suffice


-def db_add_crawl_task(prompt):
+def db_add_crawl_task(prompt: str, mode: Literal["news", "wiki", "docs"] = "wiki"):
+    # todo: replace arguments with a single WebQuery
     new_uuid = utils.gen_uuid()
     timestamp = utils.gen_unix_time()

@@ -16,12 +19,14 @@ def db_add_crawl_task(prompt):
         {
             "uuid": new_uuid,
             "prompt": prompt,
-            "type": None,  # todo: choose 'news', 'wiki', 'docs', use WebQuery
+            "type": mode,
             "completed": False,
             "executing": False,
             "completion_date": 0,  # time completed
             "execution_date": 0,  # time started completion
             "timestamp": timestamp,  # time added
+            "base_amount_scheduled": 100,  # todo: replace with dynamically adjusted value
+            "embedding_progression": {},  # {model_name: count} | progress tracking
         }
     )

@@ -58,3 +63,59 @@ def db_get_incomplete_crawl_task():
     db.update({"executing": True}, fields.uuid == task.uuid)

     return task
+
+
+def db_is_task_completed(uuid: str):
+    fields = Query()
+    task = db.get(fields.uuid == uuid)
+
+    return task.completed
+
+
+def db_are_tasks_completed(uuid_list: list[str]):
+    # fixme: instead of multiple individual calls, make one composite one
+    # for our current usage this is not necessary
+
+    total_completeness = True
+
+    for uuid in uuid_list:
+        task_completeness = db_is_task_completed(uuid)
+        total_completeness = total_completeness and task_completeness
+
+    return total_completeness
+
+
+def db_is_crawl_task_fully_embedded(uuid: str, model_name: str):
+    fields = Query()
+    task = db.get(fields.uuid == uuid)
+
+    baseline_count = task.base_amount_scheduled
+    current_count = task.embedding_progression.get(model_name, 0)
+
+    return current_count >= baseline_count
+
+
+def db_are_crawl_tasks_fully_embedded(uuid_list: list[str], model_name: str):
+    # todo: replace this naive approach with a one-query solution
+    for uuid in uuid_list:
+        if db_is_crawl_task_fully_embedded(uuid, model_name) is False:
+            return False
+
+    return True
+
+
+def db_increment_task_embedding_progression(uuid: str, model_name: str):
+    fields = Query()
+    task = db.get(fields.uuid == uuid)
+
+    current_progression = task.embedding_progression
+    current_count = current_progression.get(model_name)
+
+    if current_count is not None:
+        current_count += 1
+    else:
+        current_count = 1
+
+    current_progression[model_name] = current_count
+
+    db.update({"embedding_progression": current_progression}, fields.uuid == task.uuid)
diff --git a/core/databases/db_embeddings.py b/core/databases/db_embeddings.py
index e633052..409ef78 100644
--- a/core/databases/db_embeddings.py
+++ b/core/databases/db_embeddings.py
@@ -18,6 +18,10 @@
 )


+def db_get_currently_used_vector_model():
+    return embedder_config.model_name
+
+
 def db_add_text_batch(text: str, db_full_name: str):
     # automatically splits text before embedding it
     chunks = text_splitter.split_text(text)
diff --git a/core/databases/db_url_pool.py b/core/databases/db_url_pool.py
index 89b441e..6f3a0f0 100644
--- a/core/databases/db_url_pool.py
+++ b/core/databases/db_url_pool.py
@@ -13,13 +13,14 @@
 # and a tiny global kv cache just to prevent duplicate urls


-def db_add_url(url: str, prompt: str, parent_uuid: str = None):
+def db_add_url(url: str, prompt: str, parent_uuid: str = None, task_uuid: str = None):
     new_uuid = utils.gen_uuid()
     timestamp = utils.gen_unix_time()

     new_url_object = {
         "uuid": new_uuid,
         "parent_uuid": parent_uuid,
+        "task_uuid": task_uuid,
         "prompt": prompt,
         "url": url,
         "text": None,
diff --git a/workers/crawler.py b/workers/crawler.py
index 1c68668..fabd66c 100644
--- a/workers/crawler.py
+++ b/workers/crawler.py
@@ -1,4 +1,3 @@
-import os
 from urllib.error import HTTPError

 from langchain_community.document_loaders import WebBaseLoader, PyPDFLoader
@@ -64,12 +63,21 @@ def rq_refill(seed_task, use_google: bool = True):
     if not quit_unexpectedly:
         try:
             for url in google_urls:
+                if db_url_pool.db_is_url_present(url):
+                    continue
+
                 prompt = seed_query.web_query
-                new_url_object = db_url_pool.db_add_url(url, prompt, None)
+                new_url_object = db_url_pool.db_add_url(
+                    url=url,
+                    prompt=prompt,
+                    parent_uuid=None,
+                    task_uuid=seed_task.uuid,
+                )
+
                 google_url_objects.append(new_url_object)
                 idx += 1
+
                 google_traffic_manager.report_no_timeout()
         except HTTPError:
             # google requires a long delay after getting timeout
@@ -90,13 +98,15 @@
     return


-def url_save(url: str, parent_id: str = None):
+def url_save(url: str, parent_uuid: str = None, task_uuid: str = None):
     # 0. check if url was already saved
     if db_url_pool.db_is_url_present(url):
         return

     # 1. add to the db
-    db_url_pool.db_add_url(url, "N/A", parent_id)
+    db_url_pool.db_add_url(
+        url=url, prompt="N/A", parent_uuid=parent_uuid, task_uuid=task_uuid
+    )


 def url_download_text(url: str):
@@ -130,6 +140,7 @@ def url_download(url_object):

 def process_url(url_object):
     url_uuid = url_object["uuid"]
+    url_task_uuid = url_object["task_uuid"]

     # 0. download article
     document_text = url_download(url_object)
@@ -143,7 +154,7 @@ def process_url(url_object):

     # 2. save all links
     for link in url_list:
-        url_save(link, url_uuid)
+        url_save(url=link, parent_uuid=url_uuid, task_uuid=url_task_uuid)


 def processing_iteration():
diff --git a/workers/embedder.py b/workers/embedder.py
index 546f10e..c151a46 100644
--- a/workers/embedder.py
+++ b/workers/embedder.py
@@ -4,7 +4,7 @@
 from tinydb import Query
 from tinydb.table import Document

-from core.databases import db_url_pool, db_embeddings
+from core.databases import db_url_pool, db_embeddings, db_crawl_tasks
 from core.models.configurations import load_llm_config
 from core.tools import utils

@@ -27,10 +27,15 @@ def processing_iteration():
     for url_object in embedding_queue:
         print("embedding document:", url_object)
         document = url_object["text"]
+        task_uuid = url_object["task_uuid"]

         db_full_name = utils.gen_vec_db_full_name("embeddings", embed_model_name)
+
         db_embeddings.db_add_text_batch(document, db_full_name)
         db_url_pool.db_set_url_embedded(url_object["uuid"], embed_model_name)
+        db_crawl_tasks.db_increment_task_embedding_progression(
+            task_uuid, embed_model_name
+        )

     print(f"{Fore.CYAN}Document vectorization completed.{Fore.RESET}")

diff --git a/workers/summarizer.py b/workers/summarizer.py
index f10019f..ead7b4a 100644
--- a/workers/summarizer.py
+++ b/workers/summarizer.py
@@ -1,7 +1,15 @@
-from core.databases.db_embeddings import db_search_for_similar_queries
+from core.databases.db_crawl_tasks import (
+    db_are_tasks_completed,
+    db_are_crawl_tasks_fully_embedded,
+)
+from core.databases.db_embeddings import (
+    db_search_for_similar_queries,
+    db_get_currently_used_vector_model,
+)
 from core.databases.db_completion_tasks import (
-    db_get_incomplete_completion_task,
+    db_get_incomplete_completion_tasks,
     db_update_completion_task_after_summarizing,
+    db_release_executing_tasks,
 )
 from langchain_core.runnables import RunnableLambda
 from core.classes.query import WebQuery
@@ -22,38 +30,66 @@
 llm, embeddings = load_model()

+# even though a single task takes a long time to complete,
+# as soon as one task is started, the remaining queued tasks are released again
+task_queue = []
+task_queue_limit = 10
+
+
+def extract_uuid(task):
+    return task["uuid"]
+
+
 def summarize():
+    global task_queue
+
+    queue_space = task_queue_limit - len(task_queue)
+    task_queue += db_get_incomplete_completion_tasks(queue_space)

-    task = db_get_incomplete_completion_task()
+    current_task = None

-    if task is None:
+    current_vec_db_model = db_get_currently_used_vector_model()
+
+    # find the first task ready for execution, dismiss the others
+    for task in task_queue:
+        # check all dependencies for completeness
+        dep_list = task["required_crawl_tasks"]
+
+        if db_are_crawl_tasks_fully_embedded(dep_list, current_vec_db_model):
+            current_task = task
+            task_queue.remove(task)
+            break
+
+    task_uuid_list = list(map(extract_uuid, task_queue))
+    db_release_executing_tasks(task_uuid_list)
+
+    if current_task is None:
         return

-    def get_query():
-        return WebQuery(task["mode"].lower(), prompt_core=task["prompt"])
+    task_query = WebQuery(
+        prompt_core=current_task["prompt"], query_type=current_task["mode"].lower()
+    )

-    context = db_search_for_similar_queries(get_query())
+    context = db_search_for_similar_queries(task_query)

     if context is None:
         return

     def interpret_prompt_mode():
-        if task["mode"] == "News":
+        if current_task["mode"] == "News":
             return web_news_lookup_prompt()
-        elif task["mode"] == "Docs":
+        elif current_task["mode"] == "Docs":
             return web_docs_lookup_prompt()
-        elif task["mode"] == "Wiki":
+        elif current_task["mode"] == "Wiki":
             return web_wiki_lookup_prompt()

     def get_user_prompt(_: dict):
-        return task["prompt"]
+        return current_task["prompt"]

     def get_context(_: dict):
         return context[0].page_content

     web_interpret_prompt_mode = interpret_prompt_mode()
-    print("Summarizing task with uuid: ", task["uuid"])
+    print("Summarizing task with uuid: ", current_task["uuid"])
     chain = (
         {
             "search_data": RunnableLambda(get_context),
@@ -64,14 +100,22 @@ def get_context(_: dict):
         | llm
         | output_parser
     )
-    summary = chain.invoke(task)
-    db_update_completion_task_after_summarizing(summary, task["uuid"])
+    summary = chain.invoke(current_task)
+    db_update_completion_task_after_summarizing(summary, current_task["uuid"])

-    print(f"{Fore.CYAN}Completed task with uuid: {Fore.RESET}", task["uuid"])
+    print(f"{Fore.CYAN}Completed task with uuid: {Fore.RESET}", current_task["uuid"])


 previous_queued_tasks = None

+# todo: 1. get a list of available tasks; in the backend they will automatically be set as executing
+# 2. parse through all of them until one appears that has all of its dependencies resolved
+# 3. once one is found to be ready, release all the other tasks (reset 'executing')
+# 4. proceed with normal execution
+
+# todo: implement a class-based task management system
+
+
 while True:
     db = use_tinydb("completion_tasks")
     db_query = Query()
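
Reviewer note (illustrative, not part of the diff above): the sketch below shows the shape of the per-task progress record this patch adds to crawl tasks, and mirrors the completeness check from db_is_crawl_task_fully_embedded on a plain dict. The model names, counts, and the helper name is_fully_embedded are placeholders, not identifiers from the codebase.

    # illustrative values only; model names and counts are made up
    example_task = {
        "uuid": "placeholder-uuid",
        "base_amount_scheduled": 100,  # documents the crawler aims to embed for this task
        "embedding_progression": {     # {model_name: embedded document count}
            "model-a": 100,
            "model-b": 42,
        },
    }


    def is_fully_embedded(task: dict, model_name: str) -> bool:
        # same comparison db_is_crawl_task_fully_embedded performs, on a plain dict
        progress = task["embedding_progression"].get(model_name, 0)
        return progress >= task["base_amount_scheduled"]


    assert is_fully_embedded(example_task, "model-a")
    assert not is_fully_embedded(example_task, "model-b")

Tracking the count per model name means a crawl task can be re-embedded for a newly configured vector model without resetting its crawl state; db_get_currently_used_vector_model tells the summarizer which model's count to check before it starts a dependent completion task.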