feat: implement training loop running in background thread (#3, #8)

notnitsuj committed Feb 12, 2024
1 parent 20c9b2e commit 416b4aa
Showing 11 changed files with 253 additions and 40 deletions.
31 changes: 27 additions & 4 deletions src/backend/app.py
@@ -1,14 +1,18 @@
from enums import Status
from database.io import create_tasks, Logger, reorder_backlog_queue
from database.models import *
from database.db import instantiate_db, get_session
from typing import List
from threading import Thread

from fastapi import FastAPI, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from sqlmodel import Session, select
import uvicorn

from ml.utils import seed_everything
from database.io import create_tasks, reorder_backlog_queue, Logger
from database.models import *
from database.db import instantiate_db, get_session, engine
from service.train import start_training_thread
from enums import Status


app = FastAPI()
app.add_middleware(
@@ -19,9 +23,28 @@
)


class loop:
    """Mutable flag shared with the training thread; cleared on shutdown."""

    def __init__(self) -> None:
        self.is_loop = True


looper = loop()

logger = Logger()
thread = Thread(target=start_training_thread, args=(engine, logger, looper))


@app.on_event("startup")
def on_startup():
seed_everything()
instantiate_db()
thread.start()


@app.on_event("shutdown")
def on_shutdown():
looper.is_loop = False
thread.join()


@app.get("/jobs/", response_model=List[JobReadWithTasks])
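app.py imports start_training_thread from service.train, a file whose diff has not loaded on this page. A sketch of what a worker loop with that signature could look like, given the looper.is_loop flag and the engine/logger arguments above; the queue query and the Status.PENDING value are assumptions, not the commit's code:

import time

from sqlmodel import Session, select

from database.models import Task
from enums import Status
from ml.worker import Trainer


def start_training_thread(engine, logger, looper):
    # Poll for queued tasks until the shutdown handler clears looper.is_loop.
    while looper.is_loop:
        with Session(engine) as session:
            logger.register_session(session)
            # Assumed query and status value: pick the next task waiting to run.
            task = session.exec(
                select(Task).where(Task.status == Status.PENDING.value)
            ).first()
            if task is None:
                time.sleep(1)  # idle briefly; stay responsive to shutdown
                continue
            logger.register_task(task)  # marks it RUNNING and sets file paths
            Trainer(logger).train()     # blocks until the task finishes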
Empty file.
Empty file.
40 changes: 40 additions & 0 deletions src/backend/database/db.py
@@ -17,6 +17,46 @@ def instantiate_db():
if not os.path.exists(sqlite_file_name):
SQLModel.metadata.create_all(engine)

from .models import Job, Task

with Session(engine) as session:
job1 = Job(type=0, strategy=0, backlog_order=1)
job2 = Job(type=1, strategy=1, backlog_order=2)
job3 = Job(type=1, strategy=2, backlog_order=3)
job4 = Job(type=0, strategy=0, status=1, queue_order=1)
job5 = Job(type=1, strategy=1, status=1, queue_order=2)

session.add(job1)
session.add(job2)
session.add(job3)
session.add(job4)
session.add(job5)
session.commit()
session.refresh(job1)
session.refresh(job2)
session.refresh(job3)

task1 = Task(job_id=job1.id, train_batch_size=1024,
test_batch_size=512, use_gpu=True)
task2 = Task(job_id=job2.id, lr=1e-4, epoch=100,
accuracy=0.7, avg_precision=0.8, avg_recall=0.9, runtime=2323)
task3 = Task(job_id=job2.id, lr=1e-5, epoch=200,
accuracy=0.4, avg_precision=0.6, avg_recall=0.7, runtime=23231)
task4 = Task(job_id=job3.id, lr=1e-2, epoch=20,
accuracy=0.9, avg_precision=0.6, avg_recall=0.72, runtime=20232)
task5 = Task(job_id=job3.id, lr=3e-4,
epoch=150, train_batch_size=256, accuracy=0.85, avg_precision=0.43, avg_recall=0.5, runtime=23235)
task6 = Task(job_id=job3.id, lr=1e-2,
epoch=90, train_batch_size=15, accuracy=0.34, avg_precision=0.62, avg_recall=0.53, runtime=20012)
task7 = Task(job_id=job3.id, lr=5e-3, epoch=300, use_gpu=True,
accuracy=0.93, avg_precision=0.2, avg_recall=0.45, runtime=232392)
task8 = Task(job_id=job3.id, lr=1e-2, epoch=100, use_gpu=True,
accuracy=0.26, avg_precision=0.28, avg_recall=0.64, runtime=12432)

session.add_all((task1, task2, task3, task4,
task5, task6, task7, task8))
session.commit()


def get_session():
with Session(engine) as session:
81 changes: 80 additions & 1 deletion src/backend/database/io.py
@@ -1,3 +1,5 @@
import os
import json
from typing import Iterable

from sqlmodel import Session, select
@@ -7,7 +9,84 @@


class Logger:
...
def __init__(self):
self.__session = None
self.__db_dir = os.path.dirname(os.path.abspath(__file__))
self.data_dir = os.path.join(self.__db_dir, "data/")
self.checkpoint_dir = os.path.join(self.__db_dir, "checkpoints/")
self.log_dir = os.path.join(self.__db_dir, "logs/")

self.checkpoint_path = None
self.log_path = None

self.task = Task()

self.train_logs = []
self.test_logs = []
self.runtime = None

def register_session(self, session: Session):
self.__session = session

def register_task(self, task: Task):
self.task = task

self.checkpoint_path = self.checkpoint_dir + \
f"Job{self.task.job_id}_Task{self.task.id}.pt"
self.task.checkpoint = self.checkpoint_path

self.log_path = self.log_dir + \
f"Job{self.task.job_id}_Task{self.task.id}.json"
self.task.logs = self.log_path

self.task.status = Status.RUNNING.value

self.update_db()

def finish(self, runtime: float):
self.task.runtime = runtime
self.update_db()

self.save_logs()

self.checkpoint_path = None
self.log_path = None

self.task = Task()

self.train_logs = []
self.test_logs = []  # reset the same list that log_test/save_logs use
self.runtime = None

def log_train(self, epoch: int, step: int, loss: float, time: float):
self.train_logs.append({
"epoch": epoch,
"step": step,
"loss": loss,
"time": time
})

def log_test(self, epoch: int, step: int, loss: float, time: float):
self.test_logs.append({
"epoch": epoch,
"step": step,
"loss": loss,
"time": time
})

def update_db(self):
self.__session.add(self.task)
self.__session.commit()
self.__session.refresh(self.task)

def save_logs(self):
logs = {
"train": self.train_logs,
"test": self.test_logs
}

with open(self.log_path, 'w', encoding='utf-8') as f:
json.dump(logs, f, ensure_ascii=False, indent=4)


def create_tasks(job: Job) -> Iterable[object]:
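Read together, the Logger methods form a fixed lifecycle: bind a session, bind a task, stream step logs, then finish. A minimal usage sketch, where session and task stand in for the objects the training thread obtains:

logger = Logger()
logger.register_session(session)  # DB session opened by the worker thread
logger.register_task(task)        # sets checkpoint/log paths, marks the task RUNNING
logger.log_train(epoch=1, step=1, loss=0.93, time=12)
logger.log_test(epoch=1, step=1, loss=0.88, time=15)
logger.finish(runtime=37.2)       # persists the row, writes the JSON log, resets state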
Empty file.
6 changes: 3 additions & 3 deletions src/backend/database/models.py
@@ -47,9 +47,9 @@ class Task(TaskBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
checkpoint: Optional[str] = Field(default=None)
logs: Optional[str] = Field(default=None)
accuracy: Optional[float] = Field(default=None)
avg_precision: Optional[float] = Field(default=None)
avg_recall: Optional[float] = Field(default=None)
accuracy: Optional[float] = Field(default=0)
avg_precision: Optional[float] = Field(default=0)
avg_recall: Optional[float] = Field(default=0)
runtime: Optional[int] = Field(default=0)

job: Job = Relationship(back_populates="tasks")
11 changes: 5 additions & 6 deletions src/backend/ml/args.py
@@ -2,14 +2,13 @@


class TrainerArguments(BaseModel):
lr: int = 1e-3
train_batch_size: int = 64
test_batch_size: int = 128
lr: float = 6e-3  # float, not int: an int annotation would truncate the learning rate
train_batch_size: int = 512
test_batch_size: int = 1024
epoch: int = 50
dropout_rate: float | None = None
transform: list | None = None
optimizer: int | None = None
scheduler: tuple[int, dict[str, float]] | list[int, dict[str, float]] | None = None
# scheduler: list[int, dict[str, float]] | None = None
cleanlab: bool = False
use_gpu: bool = False
use_gpu: bool = True
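A note on the lr field above: typing it int is not cosmetic. Under pydantic v1 coercion, a float passed into an int field is truncated, so a learning rate of 6e-3 silently becomes 0 (pydantic v2 rejects the value instead). A quick sketch of the difference, with hypothetical model names:

from pydantic import BaseModel

class IntLR(BaseModel):
    lr: int = 6e-3   # mis-typed field

class FloatLR(BaseModel):
    lr: float = 6e-3

print(IntLR(lr=6e-3).lr)    # 0 under pydantic v1: the learning rate is zeroed
print(FloatLR(lr=6e-3).lr)  # 0.006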
19 changes: 19 additions & 0 deletions src/backend/ml/utils.py
@@ -0,0 +1,19 @@
import os
import random

import numpy as np
import torch


def seed_everything(seed: int = 2024):
# Set a fixed value for the hash seed
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# When running on the CuDNN backend, two further options must be set
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
g = torch.Generator()
g.manual_seed(seed)
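seed_everything seeds the hash seed, NumPy, random, and Torch, but the torch.Generator it builds is never returned or used. If DataLoader shuffling is also meant to be deterministic, the standard PyTorch recipe is to return g and hand it to the loader; a sketch under that assumption, with train_set as a placeholder dataset:

import random

import numpy as np
import torch


def seed_worker(worker_id: int):
    # Reseed numpy/random inside each DataLoader worker process.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


g = torch.Generator()
g.manual_seed(2024)

loader = torch.utils.data.DataLoader(
    train_set,                   # placeholder: any torch Dataset
    batch_size=64,
    shuffle=True,
    worker_init_fn=seed_worker,  # per-worker reseeding
    generator=g,                 # reproducible shuffle order
)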
75 changes: 49 additions & 26 deletions src/backend/ml/worker.py
@@ -1,34 +1,42 @@
import time

import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adadelta
from torch.optim.lr_scheduler import StepLR
from torchvision.datasets.vision import VisionDataset
import torchvision

from .args import TrainerArguments
from ..enums import OptimizerClass, SchedulerClass
from enums import OptimizerClass, SchedulerClass
from .model import SimpleModel
from database.io import Logger


class Trainer:
def __init__(self,
model: nn.Module,
train_set: VisionDataset,
test_set: VisionDataset,
trainer_args: TrainerArguments,
logger):

self.model = model
self.train_set = train_set
self.test_set = test_set
self.trainer_args = trainer_args
def __init__(self, logger: Logger):

self.model = SimpleModel()

self.train_set = torchvision.datasets.MNIST(
root=logger.data_dir, train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
]))
self.test_set = torchvision.datasets.MNIST(
root=logger.data_dir, train=False, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
]))

self.logger = logger

self.train_args = {"batch_size": self.trainer_args.train_batch_size}
self.test_args = {"batch_size": self.trainer_args.test_batch_size}
self.train_args = {"batch_size": self.logger.task.train_batch_size}
self.test_args = {"batch_size": self.logger.task.test_batch_size}

if self.trainer_args.use_gpu and torch.cuda.is_available():
if self.logger.task.use_gpu and torch.cuda.is_available():
self.device = torch.device("cuda")

cuda_args = {"num_workers": 1,
@@ -46,20 +54,25 @@ def __init__(self,
self.test_loader = torch.utils.data.DataLoader(
self.test_set, **self.test_args)

self.__register_optimizer(self.trainer_args.optimizer)
self.__register_scheduler(self.trainer_args.scheduler)
self.__register_optimizer(self.logger.task.optimizer)
self.__register_scheduler(self.logger.task.scheduler)

self.train_step = 1
self.test_step = 1

def train(self):

self.start_time = time.time()

for epoch in range(1, self.trainer_args.epoch + 1):
for epoch in range(1, self.logger.task.epoch + 1):
self.__train(epoch)
self.__test(epoch)
self.scheduler.step()

torch.save(self.model.state_dict(), self.logger.checkpoint_path)

self.logger.finish(runtime=time.time() - self.start_time)

def __train(self, epoch):

self.model.train()
@@ -74,7 +87,10 @@ def __train(self, epoch):

current_time = int(time.time() - self.start_time)
self.logger.log_train(
epoch=epoch, loss=loss.item(), time=current_time)
epoch=epoch, step=self.train_step, loss=loss.item(), time=current_time)
self.train_step += 1

print(f"Step {self.train_step}, loss {loss.item()}", flush=True)

def __test(self, epoch):

@@ -93,23 +109,30 @@
correct += pred.eq(target.view_as(pred)).sum().item()

current_time = int(time.time() - self.start_time)
self.logger.log_test(epoch=epoch, loss=loss, time=current_time)
self.logger.log_test(
epoch=epoch, step=self.test_step, loss=loss, time=current_time)
self.test_step += 1

print(f"Step {self.test_step}, loss {loss}", flush=True)

test_loss /= len(self.test_loader.dataset)
accuracy = correct / len(self.test_loader.dataset)

print(
f"Epoch {epoch}, test loss: {test_loss}, accuracy {accuracy}", flush=True)

current_time = int(time.time() - self.start_time)
self.logger.log_accuracy(
epoch=epoch, accuracy=accuracy, time=current_time)
epoch=epoch, avg_loss=test_loss, accuracy=accuracy, avg_precision=accuracy, avg_recall=accuracy) # TODO: Add precision and recall

def __register_optimizer(self, optim):
    if not optim:
        self.optimizer = Adadelta(self.model.parameters(),
                                  lr=self.trainer_args.lr)
                                  lr=self.logger.task.lr)
        return  # without this, OptimizerClass(None) below would raise

optimizer_class = OptimizerClass(optim).value
self.optimizer = optimizer_class(
self.model.parameters(), lr=self.trainer_args.lr)
self.model.parameters(), lr=self.logger.task.lr)

def __register_scheduler(self, scheduler):
if not scheduler:
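The log_accuracy call above passes accuracy in place of avg_precision and avg_recall, per the TODO. One way to fill it in is macro-averaged precision and recall from per-class confusion counts accumulated over the test loop; this is a sketch, none of it is in the commit:

import torch


def macro_precision_recall(preds: torch.Tensor, targets: torch.Tensor,
                           num_classes: int = 10):
    # Per-class true positives, false positives, false negatives.
    tp = torch.zeros(num_classes)
    fp = torch.zeros(num_classes)
    fn = torch.zeros(num_classes)
    for c in range(num_classes):
        tp[c] = ((preds == c) & (targets == c)).sum()
        fp[c] = ((preds == c) & (targets != c)).sum()
        fn[c] = ((preds != c) & (targets == c)).sum()
    # clamp avoids division by zero for classes never predicted or present.
    precision = (tp / (tp + fp).clamp(min=1)).mean().item()
    recall = (tp / (tp + fn).clamp(min=1)).mean().item()
    return precision, recall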