-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
393 additions
and
214 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import logging | ||
|
||
from sqlalchemy import create_engine | ||
from sqlalchemy.orm import DeclarativeBase | ||
|
||
engine = create_engine("sqlite://") | ||
|
||
logging.basicConfig() | ||
logging.getLogger("sqlalchemy.engine").setLevel(logging.CRITICAL) | ||
|
||
|
||
class Base(DeclarativeBase): | ||
pass | ||
|
||
|
||
def db_init(): | ||
Base.metadata.create_all(engine) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,79 +1,124 @@ | ||
from tinydb import Query | ||
from sqlalchemy import String, Integer, Boolean, select, update | ||
from sqlalchemy.orm import Mapped, mapped_column, Session, relationship | ||
|
||
from core.databases import defaults | ||
from core.databases.db_base import Base, engine | ||
from core.databases.db_crawl_tasks import CrawlTask | ||
from core.tools import utils | ||
from core.tools.utils import use_tinydb, gen_unix_time | ||
from core.tools.utils import gen_unix_time, page_to_range | ||
|
||
db = use_tinydb("completion_tasks") | ||
|
||
class CompletionTask(Base): | ||
__tablename__ = "completion_tasks" | ||
|
||
def db_add_completion_task(prompt, mode): | ||
uuid: Mapped[str] = mapped_column(primary_key=True) | ||
prompt: Mapped[str] = mapped_column(String()) # make sure postgres uses "TEXT" here | ||
mode: Mapped[str] = mapped_column(String(12)) | ||
timestamp: Mapped[int] = mapped_column(Integer()) # time added | ||
completion_result: Mapped[str] = mapped_column(String()) # "TEXT" type here as well | ||
|
||
executing: Mapped[bool] = mapped_column(Boolean()) | ||
execution_date: Mapped[int] = mapped_column(Integer()) # time started completion | ||
|
||
completed: Mapped[bool] = mapped_column(Boolean()) | ||
completion_date: Mapped[int] = mapped_column(Integer()) # time completed | ||
|
||
required_crawl_tasks: Mapped[list["CrawlTask"]] = relationship() | ||
|
||
|
||
def db_add_completion_task(prompt, mode) -> str: | ||
new_uuid = utils.gen_uuid() | ||
timestamp = utils.gen_unix_time() | ||
|
||
db.insert( | ||
{ | ||
"uuid": new_uuid, | ||
"prompt": prompt, | ||
"mode": mode, | ||
"completed": False, | ||
"completion_result": None, | ||
"executing": False, | ||
"required_crawl_tasks": [], # uuid list that has to be completed first | ||
"completion_date": 0, | ||
"execution_date": 0, | ||
"timestamp": timestamp, | ||
} | ||
) | ||
with Session(engine) as session: | ||
completion_task = CompletionTask( | ||
uuid=new_uuid, | ||
prompt=prompt, | ||
mode=mode, | ||
timestamp=timestamp, | ||
executing=False, | ||
execution_date=0, | ||
completed=False, | ||
completion_date=0, | ||
required_crawl_tasks=[], | ||
) | ||
|
||
session.add(completion_task) | ||
session.commit() | ||
|
||
return new_uuid | ||
|
||
|
||
def db_get_completion_tasks_by_page(page: int, per_page: int = defaults.ITEMS_PER_PAGE): | ||
# returns all as TinyDB does not support pagination | ||
# we'll be moving to SQLite or Cassandra soon enough | ||
results = db.all() | ||
def db_get_completion_tasks_by_page( | ||
page: int, per_page: int = defaults.ITEMS_PER_PAGE | ||
) -> list[CompletionTask]: | ||
session = Session(engine) | ||
|
||
start, stop = page_to_range(page, per_page) | ||
|
||
query = select(CompletionTask).slice(start, stop) | ||
|
||
results = list(session.scalars(query)) | ||
return results | ||
|
||
|
||
def db_get_completion_tasks_by_uuid(uuid: int): | ||
fields = Query() | ||
result = db.get(fields.uuid == uuid) | ||
def db_get_completion_task_by_uuid(uuid: int) -> CompletionTask: | ||
session = Session(engine) | ||
|
||
query = select(CompletionTask).where(CompletionTask.uuid.is_(uuid)) | ||
|
||
result = session.scalars(query).one() | ||
return result | ||
|
||
|
||
def db_set_completion_task_executing(uuid: str): | ||
fields = Query() | ||
db.update( | ||
{"executing": True, "execution_date": gen_unix_time()}, fields.uuid == uuid | ||
session = Session(engine) | ||
|
||
session.execute( | ||
update(CompletionTask) | ||
.where(CompletionTask.uuid.is_(uuid)) | ||
.values(executing=True, execution_date=gen_unix_time()) | ||
) | ||
|
||
session.commit() | ||
|
||
|
||
def db_get_incomplete_completion_tasks(amount: int = 1): | ||
fields = Query() | ||
session = Session(engine) | ||
|
||
query = ( | ||
select(CompletionTask).where(CompletionTask.completed.is_(False)).limit(amount) | ||
) | ||
|
||
results = db.search(fields.completed == False and fields.executing == False) | ||
results = results[:amount] | ||
results = list(session.scalars(query).all()) | ||
|
||
for task in results: | ||
db_set_completion_task_executing(task["uuid"]) | ||
db_set_completion_task_executing(task.uuid) | ||
|
||
return results | ||
|
||
|
||
def db_release_executing_tasks(uuid_list: list[str]): | ||
fields = Query() | ||
db.update({"executing": False}, fields.uuid.one_of(uuid_list)) | ||
session = Session(engine) | ||
|
||
session.execute( | ||
update(CompletionTask) | ||
.where(CompletionTask.uuid.in_(uuid_list)) | ||
.values(executing=False, execution_date=0) | ||
) | ||
|
||
session.commit() | ||
|
||
|
||
def db_update_completion_task_after_summarizing(summary: str, uuid: str): | ||
fields = Query() | ||
db.update( | ||
{ | ||
"completed": True, | ||
"completion_result": summary, | ||
"completion_date": gen_unix_time(), | ||
}, | ||
fields.uuid == uuid, | ||
session = Session(engine) | ||
|
||
session.execute( | ||
update(CompletionTask) | ||
.where(CompletionTask.uuid.is_(uuid)) | ||
.values( | ||
completed=True, completion_result=summary, completion_date=gen_unix_time() | ||
) | ||
) | ||
|
||
session.commit() |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.