Skip to content

Commit

Permalink
Change database systems (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
latekvo authored Jun 12, 2024
1 parent 4d68174 commit 53b0385
Show file tree
Hide file tree
Showing 13 changed files with 393 additions and 214 deletions.
17 changes: 17 additions & 0 deletions core/databases/db_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import logging

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase

# Shared SQLAlchemy engine for the whole package. The URL carries no file
# path, which SQLAlchemy treats as an in-memory SQLite database.
# NOTE(review): in-memory storage does not survive a process restart -
# confirm this is intended outside of development.
engine = create_engine("sqlite://")

# Install the default root logging handler, then restrict the
# "sqlalchemy.engine" logger to CRITICAL records only.
logging.basicConfig()
logging.getLogger("sqlalchemy.engine").setLevel(logging.CRITICAL)


class Base(DeclarativeBase):
    """Declarative base class that every ORM model in this package inherits.

    Models deriving from it register their tables on ``Base.metadata``.
    """


def db_init():
    """Create every table registered on ``Base.metadata`` using the shared engine."""
    Base.metadata.create_all(bind=engine)
129 changes: 87 additions & 42 deletions core/databases/db_completion_tasks.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,124 @@
from tinydb import Query
from sqlalchemy import String, Integer, Boolean, select, update
from sqlalchemy.orm import Mapped, mapped_column, Session, relationship

from core.databases import defaults
from core.databases.db_base import Base, engine
from core.databases.db_crawl_tasks import CrawlTask
from core.tools import utils
from core.tools.utils import use_tinydb, gen_unix_time
from core.tools.utils import gen_unix_time, page_to_range

db = use_tinydb("completion_tasks")

class CompletionTask(Base):
    """ORM model for one row of the ``completion_tasks`` table.

    A row tracks a single prompt through its lifecycle: added (``timestamp``),
    picked up for execution (``executing`` / ``execution_date``) and finished
    (``completed`` / ``completion_date`` / ``completion_result``).
    All ``*_date`` and ``timestamp`` fields are integers produced by
    ``gen_unix_time()``; ``0`` is used for "not yet happened".
    """

    __tablename__ = "completion_tasks"

    uuid: Mapped[str] = mapped_column(primary_key=True)
    prompt: Mapped[str] = mapped_column(String())  # make sure postgres uses "TEXT" here
    mode: Mapped[str] = mapped_column(String(12))
    timestamp: Mapped[int] = mapped_column(Integer())  # time added

    # "TEXT" type here as well.
    # Explicitly nullable: a freshly added task has no result yet, and a bare
    # ``Mapped[str]`` annotation would otherwise make this column NOT NULL and
    # reject the insert.
    completion_result: Mapped[str] = mapped_column(String(), nullable=True)

    executing: Mapped[bool] = mapped_column(Boolean())
    execution_date: Mapped[int] = mapped_column(Integer())  # time started completion

    completed: Mapped[bool] = mapped_column(Boolean())
    completion_date: Mapped[int] = mapped_column(Integer())  # time completed

    # Crawl tasks associated with this completion task.
    # NOTE(review): relationship() is given no arguments, so the join must be
    # derivable from CrawlTask's mapping, which is not visible here - confirm.
    required_crawl_tasks: Mapped[list["CrawlTask"]] = relationship()


def db_add_completion_task(prompt, mode) -> str:
    """Insert a new, not-yet-started completion task and return its uuid.

    Args:
        prompt: Prompt text to store on the task.
        mode: Mode string to store on the task (column is ``String(12)``).

    Returns:
        The uuid generated for the new task.
    """
    new_uuid = utils.gen_uuid()
    timestamp = utils.gen_unix_time()

    with Session(engine) as session:
        completion_task = CompletionTask(
            uuid=new_uuid,
            prompt=prompt,
            mode=mode,
            timestamp=timestamp,
            completion_result=None,  # no result until the task is summarized
            executing=False,
            execution_date=0,
            completed=False,
            completion_date=0,
            required_crawl_tasks=[],
        )

        session.add(completion_task)
        session.commit()

    return new_uuid


def db_get_completion_tasks_by_page(
    page: int, per_page: int = defaults.ITEMS_PER_PAGE
) -> list[CompletionTask]:
    """Return one page of completion tasks.

    Args:
        page: Page index, converted to a row range by ``page_to_range``.
        per_page: Number of tasks per page.

    Returns:
        The tasks that fall inside the requested row range, oldest first.
    """
    start, stop = page_to_range(page, per_page)

    # Paging is only stable with a deterministic order; without ORDER BY the
    # database may return rows in any order, so pages could overlap or skip
    # rows. uuid breaks ties between tasks added in the same second.
    query = (
        select(CompletionTask)
        .order_by(CompletionTask.timestamp, CompletionTask.uuid)
        .slice(start, stop)
    )

    # NOTE(review): the session is deliberately left open so the returned
    # objects stay attached to it; nothing closes it afterwards. Consider
    # eager-loading what callers need and closing the session here.
    session = Session(engine)

    return list(session.scalars(query))


def db_get_completion_task_by_uuid(uuid: str) -> CompletionTask:
    """Fetch the completion task with the given uuid.

    Args:
        uuid: Primary key of the task to load.

    Returns:
        The matching ``CompletionTask``.

    Raises:
        sqlalchemy.exc.NoResultFound: If no task has this uuid.
    """
    # Plain equality instead of ``.is_()``: ``IS`` is meant for NULL/boolean
    # comparisons and is not portable for string values (e.g. on PostgreSQL).
    query = select(CompletionTask).where(CompletionTask.uuid == uuid)

    # NOTE(review): the session is left open so the returned object stays
    # attached to it; nothing closes it afterwards.
    session = Session(engine)

    return session.scalars(query).one()


def db_set_completion_task_executing(uuid: str):
    """Mark a completion task as executing and stamp its execution start time.

    Args:
        uuid: Primary key of the task to mark.
    """
    # Context manager closes the session (and releases its connection) even
    # if the UPDATE or the commit raises.
    with Session(engine) as session:
        session.execute(
            update(CompletionTask)
            # ``==`` rather than ``.is_()``: ``IS`` against a string value is
            # not portable SQL.
            .where(CompletionTask.uuid == uuid)
            .values(executing=True, execution_date=gen_unix_time())
        )

        session.commit()


def db_get_incomplete_completion_tasks(amount: int = 1):
    """Claim up to ``amount`` tasks that are neither completed nor executing.

    Every returned task is flagged as executing in the database before this
    function returns.

    Args:
        amount: Maximum number of tasks to claim.

    Returns:
        A list of the claimed ``CompletionTask`` objects. They were loaded
        before the flag was written, so their in-memory ``executing``
        attribute may still read ``False``.
    """
    # Both conditions are required: filtering on ``completed`` alone would
    # hand out tasks that are already being executed.
    query = (
        select(CompletionTask)
        .where(
            CompletionTask.completed.is_(False),
            CompletionTask.executing.is_(False),
        )
        .limit(amount)
    )

    # NOTE(review): the session is left open so the returned objects stay
    # attached to it; nothing closes it afterwards.
    session = Session(engine)

    results = list(session.scalars(query).all())

    for task in results:
        db_set_completion_task_executing(task.uuid)

    return results


def db_release_executing_tasks(uuid_list: list[str]):
    """Clear the executing flag and execution date of the given tasks.

    Args:
        uuid_list: Primary keys of the tasks to release. An empty list is a
            no-op.
    """
    if not uuid_list:
        # Nothing to release; skip the database round-trip.
        return

    # Context manager closes the session even if the UPDATE or commit raises.
    with Session(engine) as session:
        session.execute(
            update(CompletionTask)
            .where(CompletionTask.uuid.in_(uuid_list))
            .values(executing=False, execution_date=0)
        )

        session.commit()


def db_update_completion_task_after_summarizing(summary: str, uuid: str):
    """Store the summary on a task and mark it completed.

    Args:
        summary: Text to save as the task's ``completion_result``.
        uuid: Primary key of the task to update.
    """
    # Context manager closes the session even if the UPDATE or commit raises.
    with Session(engine) as session:
        session.execute(
            update(CompletionTask)
            # ``==`` rather than ``.is_()``: ``IS`` against a string value is
            # not portable SQL.
            .where(CompletionTask.uuid == uuid)
            .values(
                completed=True,
                completion_result=summary,
                completion_date=gen_unix_time(),
            )
        )

        session.commit()
29 changes: 0 additions & 29 deletions core/databases/db_crawl_history.py

This file was deleted.

Loading

0 comments on commit 53b0385

Please sign in to comment.