
pytorch-boilerplate : flashlight

The OTHER pytorch boilerplate.


  • LightningModule [flashlight/runner/pl.py]

  • Trainer [flashlight/runner/main_pl.py]

  • Accelerators

  • Callback

  • Logging [flashlight/runner/pl.py]

  • Metrics

  • Plugins

Prerequisites for local setup [PRTM!]

  • Python 3.6 or later (the code uses f-strings, which Python 3.5 does not support)
  • PyTorch 1.5.0 and torchvision 0.6.0, built for your OS/CUDA version
  • the packages in requirements.txt: pip install -r requirements.txt
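Before running anything, it can help to confirm that the interpreter and the two pinned packages match the list above. A minimal sketch, assuming a hypothetical `check_environment` helper (it is not part of this repository) and Python 3.6 as the floor because the code uses f-strings. The helper itself avoids f-strings and variable annotations so it can still report a too-old interpreter instead of failing with a syntax error:

```python
import importlib
import sys
from typing import Dict, Optional

# Assumption: 3.6 is the real minimum, because the repository code uses f-strings.
MIN_PYTHON = (3, 6)


def check_environment() -> Dict[str, Optional[str]]:
    """Hypothetical helper: report Python, torch and torchvision versions.

    Raises RuntimeError if Python is older than MIN_PYTHON. A package that
    cannot be imported is reported as None instead of raising.
    """
    if sys.version_info < MIN_PYTHON:
        # %-formatting on purpose: this line must parse on old interpreters.
        raise RuntimeError(
            "Python %d.%d+ required, found %d.%d"
            % (MIN_PYTHON + tuple(sys.version_info[:2]))
        )

    report = {"python": "%d.%d.%d" % tuple(sys.version_info[:3])}
    for name in ("torch", "torchvision"):
        try:
            module = importlib.import_module(name)
        except ImportError:
            report[name] = None  # package not installed
        else:
            report[name] = str(getattr(module, "__version__", "unknown"))
    return report


if __name__ == "__main__":
    # Compare the output against the pinned versions listed above
    # (torch 1.5.0, torchvision 0.6.0; CUDA builds may add a suffix such as "+cu101").
    print(check_environment())
```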

Getting Started

The master branch runs MNIST classification (torchvision dataset) with SqueezeNet (torchvision model). For details, see config/config.py.

Run Single Experiment without NNI

  1. Prepare an environment: GPU Docker, a local Python env, or whatever you prefer.
  • If using Docker: docker pull davinnovation/pytorch-boilerplate:alpha
  2. Run python run.py (or python -W ignore run.py to suppress Python warnings).

  3. After the experiment finishes, inspect the results: tensorboard --logdir Logs

Run Experiments with NNI

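  1. Prepare the environment (as for a single experiment).

  2. Launch the experiment: nnictl create --config nni_config.yml

  3. Open localhost:8080 in a browser to follow the trials in the NNI web UI.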
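Each NNI trial receives one hyperparameter set chosen by the tuner; in NNI's trial API, `nni.get_next_parameter()` returns it as a dict and `nni.report_final_result()` sends the trial's metric back. How run.py consumes these values is not shown in this README. A minimal sketch, assuming a hypothetical `merge_params` helper and made-up default keys, of overlaying tuned values onto a default config while rejecting keys the config does not know:

```python
from typing import Any, Dict, Mapping


def merge_params(defaults: Mapping[str, Any], tuned: Mapping[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of `defaults` updated with `tuned`.

    Unknown keys raise ValueError, so a typo in the NNI search space
    fails loudly instead of being silently ignored.
    """
    unknown = set(tuned) - set(defaults)
    if unknown:
        raise ValueError(f"Unknown hyperparameters: {sorted(unknown)}")
    merged = dict(defaults)  # copy, so the defaults are never mutated
    merged.update(tuned)
    return merged


# Hypothetical defaults for illustration; the real values live in config/config.py.
DEFAULTS = {"lr": 0.001, "batch_size": 64, "network": "squeezenet"}

# Inside an NNI trial this dict would come from nni.get_next_parameter().
tuned = {"lr": 0.01}
config = merge_params(DEFAULTS, tuned)
```

Copying the defaults rather than updating them in place keeps each trial independent of any previous one run in the same process.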


Diving into Code


  • Adding Network

flashlight/network/__init__.py

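"""Network Define"""
import torchvision

# Register networks as {"Network Name": factory function that builds an nn.Module}.
# Store the factory itself, not an instantiated model.
def _get_squeezenet(num_classes, version: str = "1_0", pretrained=False, progress=True):
    # Map version strings to torchvision's SqueezeNet constructors.
    # Note: the pretrained ImageNet weights have a 1000-class head, so
    # pretrained=True combined with num_classes != 1000 fails to load.
    VERSION = {
        "1_0": torchvision.models.squeezenet1_0,
        "1_1": torchvision.models.squeezenet1_1,
    }

    return VERSION[version](pretrained=pretrained, progress=progress, num_classes=num_classes)

NETWORK_DICT = {
    "squeezenet": _get_squeezenet
}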
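Because the registry is a plain dict from name to factory, selecting a network by name reduces to a dictionary lookup plus a call. A minimal sketch of that lookup, assuming a hypothetical `build_network` helper and a stand-in `_get_dummy` factory (both for illustration only; real entries, like `_get_squeezenet`, return a torchvision model):

```python
from typing import Any, Callable, Dict


def _get_dummy(num_classes: int, width: int = 8) -> Dict[str, Any]:
    # Stand-in factory for illustration; a real entry would return an nn.Module.
    return {"arch": "dummy", "num_classes": num_classes, "width": width}


# Same shape as NETWORK_DICT: name -> factory (not an instance).
NETWORK_DICT: Dict[str, Callable[..., Any]] = {
    "dummy": _get_dummy,
}


def build_network(name: str, **kwargs: Any) -> Any:
    """Hypothetical helper: look up a factory by name and call it with kwargs."""
    try:
        factory = NETWORK_DICT[name]
    except KeyError:
        raise ValueError(
            f"Unknown network {name!r}; available: {sorted(NETWORK_DICT)}"
        ) from None
    return factory(**kwargs)


net = build_network("dummy", num_classes=10)
```

Storing factories rather than instances means no model is built (and no weights are downloaded) until a name is actually requested, and each call returns a fresh model.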
  • Adding Dataset

flashlight/dataloader/__init__.py

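""" Dataset """
import torchvision

# Register datasets as {Dataset Name: torch.utils.data.Dataset class}
DATA_DICT = {"MNIST": torchvision.datasets.MNIST}

""" Dataset Transform """
# MNIST is single-channel: replicate it to 3 channels so ImageNet-style
# models such as SqueezeNet accept it, then convert to a tensor.
transform = torchvision.transforms.Compose(
    [torchvision.transforms.Grayscale(num_output_channels=3), torchvision.transforms.ToTensor()]
)

def get_datalaoder(data, root="../datasets/", split="train"):
    """Return the requested split of `data` (a Dataset, despite the function name)."""
    if data in ["MNIST"]:  # torchvision datasets expose only train=True/False
        if split == "val":
            print(f"{data} dataset doesn't support a validation set. val replaced by train")
        if split in ["train", "val"]:
            return DATA_DICT[data](root=root, train=True, download=True, transform=transform)
        else:
            return DATA_DICT[data](root=root, train=False, download=True, transform=transform)
    raise ValueError(f"Unknown dataset: {data}")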
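torchvision's MNIST exposes only `train=True/False`, which is why `get_datalaoder` falls back to the training set for `val`. If a genuine validation set is needed, one common option (not something this repository does) is to hold out a fixed fraction of the training indices; `torch.utils.data.random_split` does this for `Dataset` objects. A stdlib sketch of the index logic, assuming a hypothetical `holdout_indices` helper:

```python
import random
from typing import List, Tuple


def holdout_indices(n: int, val_fraction: float = 0.1, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: split range(n) into disjoint (train, val) index lists.

    A seeded shuffle makes the split reproducible across runs.
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be in (0, 1)")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # deterministic for a given seed
    n_val = int(n * val_fraction)
    return indices[n_val:], indices[:n_val]


# The MNIST training set has 60,000 images.
train_idx, val_idx = holdout_indices(60000, val_fraction=0.1, seed=42)
```

With torch installed, each index list can be wrapped as `torch.utils.data.Subset(dataset, indices)` so the training and validation loaders never share samples.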
  • Change Loss, forward/backward... [Research Code]

flashlight/runner/pl.py

  • Change Logger, hardware options... [Engineering Code]

flashlight/runner/main_pl.py