Skip to content

Commit

Permalink
Merge pull request #17 from igorastashov/hw_3
Browse files Browse the repository at this point in the history
Hw 3
  • Loading branch information
igorastashov authored Dec 18, 2023
2 parents 62f23c4 + 359945c commit f4d7556
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ __pycache__/
# Folder
**/PokemonData
/runs
outputs

# Weights and optimizer
weights/model.pt
Expand Down
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@ Example script to train and evaluate model.
python main.py
```

**Hydra - hyperparameter management tool.**

To change hyperparameters such as `epoch_count`, `lr`, `batch_size`, or `momentum`,
edit the file `conf/config.yaml`.

## (A) Acknowledgments

This repository borrows partially from the [Isadrtdinov](https://github.com/isadrtdinov/intro-to-dl-hse/blob/2022-2023/seminars/201/seminar_04.ipynb) and [FUlyankin](https://github.com/FUlyankin/deep_learning_pytorch/tree/main/week08_fine_tuning) repositories.
Repository design taken from [v-goncharenko](https://github.com/v-goncharenko/data-science-template) and [PeterWang512](https://github.com/PeterWang512/CNNDetection).
Repository design taken from [v-goncharenko](https://github.com/v-goncharenko/data-science-template), [PeterWang512](https://github.com/PeterWang512/CNNDetection) and [ArjanCodes](https://github.com/ArjanCodes/2021-config).
9 changes: 9 additions & 0 deletions conf/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
defaults:
- _self_
paths:
data: ${hydra:runtime.cwd}/data/PokemonData
params:
epoch_count: 10
lr: 1e-2
batch_size: 32
momentum: 0.9
20 changes: 20 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from dataclasses import dataclass


@dataclass
class Paths:
    """Filesystem locations section of the configuration."""

    # Location of the dataset directory, stored as a plain string
    # (presumably the dataset root folder; verify against the callers).
    data: str


@dataclass
class Params:
    """Training hyper-parameters section of the configuration."""

    # Number of training epochs.
    epoch_count: int
    # Learning rate passed to the optimizer.
    lr: float
    # Batch size used when building the data loaders.
    batch_size: int
    # Momentum passed to the optimizer.
    momentum: float


@dataclass
class ConvNetConfig:
    """Top-level configuration schema grouping path and training settings."""

    # Filesystem locations (a Paths instance).
    paths: Paths
    # Training hyper-parameters (a Params instance).
    params: Params
11 changes: 6 additions & 5 deletions ds/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from tqdm import tqdm


def remove_bed_images(data_dir: Path):
bad_images = glob.glob(f"{data_dir}/*/*.svg")
def remove_bed_images(root_path: str):
data_path = Path(f"{root_path}")
bad_images = glob.glob(f"{data_path}/*/*.svg")
for bad_image in bad_images:
os.remove(bad_image)

Expand Down Expand Up @@ -109,21 +110,21 @@ def __getitem__(self, item):


def create_dataloader(
root: Path,
root_path: str,
batch_size: int,
load_to_ram: bool = False,
pin_memory: bool = True,
num_workers: int = 2,
) -> tuple[DataLoader[Any], DataLoader[Any]]:
train_dataset = PokemonDataset(
root=root,
root=Path(f"{root_path}"),
train=True,
load_to_ram=load_to_ram,
transform=prepare_train_data(),
)

test_dataset = PokemonDataset(
root=root,
root=Path(f"{root_path}"),
train=False,
load_to_ram=load_to_ram,
transform=prepare_test_data(),
Expand Down
40 changes: 22 additions & 18 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,44 @@
import pathlib

import hydra
import torch
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf

from config import ConvNetConfig
from ds.dataset import create_dataloader, remove_bed_images
from ds.models import ConvNet
from ds.runner import train


# Hyper parameters
EPOCH_COUNT = 10
LR = 1e-2
MOMENTUM = 0.9
BATCH_SIZE = 32

# Data configuration
DATA_DIR = pathlib.Path("data/PokemonData")

# Device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def main():
cs = ConfigStore.instance()
cs.store(name="ConfNet_config", node=ConvNetConfig)


@hydra.main(config_path="conf", config_name="config")
def main(cfg: ConvNetConfig):
print(OmegaConf.to_yaml(cfg))

# Model and Optimizer
model_name = "ConvNet"
model = ConvNet().to(device)
optimizer = torch.optim.SGD(model.parameters(), LR, MOMENTUM)
optimizer = torch.optim.SGD(
model.parameters(), lr=cfg.params.lr, momentum=cfg.params.momentum
)
criterion = torch.nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCH_COUNT)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer, T_max=cfg.params.epoch_count
)

# Remove bad images
remove_bed_images(DATA_DIR)
remove_bed_images(cfg.paths.data)

# Create the data loaders
train_loader, test_loader = create_dataloader(
root=DATA_DIR,
batch_size=BATCH_SIZE,
root_path=cfg.paths.data,
batch_size=cfg.params.batch_size,
load_to_ram=False,
pin_memory=True,
num_workers=2,
Expand All @@ -48,7 +52,7 @@ def main():
criterion,
train_loader,
test_loader,
num_epochs=EPOCH_COUNT,
num_epochs=cfg.params.epoch_count,
device=device,
title=model_name,
)
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ isort = "^5.13.0"
flake8 = "^6.1.0"
pre-commit = "^3.6.0"
dvc = {extras = ["gdrive"], version = "^3.33.4"}
hydra-core = "^1.3.2"

[tool.black]
line-length = 90
Expand Down
4 changes: 2 additions & 2 deletions weights/download_weights.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
gdown --id 1R7NCduqaDK_D2R4mgUSvgZubcX5DjSYr -O model.pt
gdown --id 1Ds3Hch72xXUyEV0_aiV8TwNJp8Pg3J9H -O optimizer.pt
gdown --id 1ROwjrRiM_EfqegJEl2RxW3pM9ieK80Cv -O model.pt
gdown --id 11FzTrBeYoj3MQ8bmnEZ6BDP6JtXjFylz -O optimizer.pt

0 comments on commit f4d7556

Please sign in to comment.