From c2ce047b602dc3efaa14145e0095a8d7f74ffed6 Mon Sep 17 00:00:00 2001 From: drewoldag <47493171+drewoldag@users.noreply.github.com> Date: Tue, 17 Sep 2024 10:15:12 -0700 Subject: [PATCH] Initial commit. --- example_config.toml | 72 +++++++++++++++++++++++++++++++++++++ pyproject.toml | 2 ++ src/kbmod_ml/models/cnn.py | 74 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 148 insertions(+) create mode 100644 example_config.toml create mode 100644 src/kbmod_ml/models/cnn.py diff --git a/example_config.toml b/example_config.toml new file mode 100644 index 0000000..fa1bb61 --- /dev/null +++ b/example_config.toml @@ -0,0 +1,72 @@ +[general] +use_gpu = true + +# Destination of log messages +# 'stderr' and 'stdout' specify the console. +log_destination = "stderr" +# A path name specifies a file e.g. +# log = "fibad_log.txt" + +# Lowest log level to emit. +# As you go down the list, fibad will become more verbose in the log. +# +# log_level = "critical" # Only emit the most severe of errors +# log_level = "error" # Emit all errors +# log_level = "warning" # Emit warnings and all errors +log_level = "info" # Emit informational messages, warnings and all errors +# log_level = "debug" # Very verbose, emit all log messages. + +[download] +sw = "22asec" +sh = "22asec" +filter = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"] +type = "coadd" +rerun = "pdr3_wide" +username = "mtauraso@local" +password = "cCw+nX53lmNLHMy+JbizpH/dl4t7sxljiNm6a7k1" +max_connections = 2 +fits_file = "../hscplay/temp.fits" +cutout_dir = "../hscplay/cutouts/" +offset = 0 +num_sources = 500 + +# These control the downloader's HTTP requests and retries +# `retry_wait` How long to wait before retrying a failed HTTP request in seconds. Default 30s +retry_wait = 30 +# `retries` How many times to retry a failed HTTP request before moving on to the next one. Default 3 times +retries = 3 +# `timepout` How long should we wait to get a full HTTP response from the server. Default 3600s (1hr) +timeout = 3600 +# `chunksize` How many sky location rectangles should we request in a single request. Default is 990 +chunksize = 990 + +[model] +# name = "ExampleCNN" +# name = "ExampleAutoencoder" + +# An example of requesting an external model class +# external_class = "user_package.submodule.ExternalModel" +external_cls = "kbmod_ml.models.cnn.CNN" + +weights_filepath = "example_model.pth" +epochs = 10 + +[data_loader] +# Name of data loader to use +name = "CifarDataLoader" +# name = "HSCDataLoader" + +# An example of requesting an external data loader class +# external_class = "user_package.submodule.ExternalDataLoader" + +# Directory path where the data is stored +path = "/Users/drew/code/fibad/data/cifar-10-batches-py" +# path = "/Users/drew/code/fibad/data/hsc-samples" + +# Default PyTorch DataLoader parameters +batch_size = 10 +shuffle = true +num_workers = 10 + +[predict] +batch_size = 32 diff --git a/pyproject.toml b/pyproject.toml index f45ec2d..d34ca62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,8 @@ classifiers = [ dynamic = ["version"] requires-python = ">=3.9" dependencies = [ + "torch", # PyTorch + # "fibad", when it is available on PyPI ] [project.urls] diff --git a/src/kbmod_ml/models/cnn.py b/src/kbmod_ml/models/cnn.py new file mode 100644 index 0000000..3b62715 --- /dev/null +++ b/src/kbmod_ml/models/cnn.py @@ -0,0 +1,74 @@ +# ruff: noqa: D101, D102 + +# This example model is taken from the PyTorch CIFAR10 tutorial: +# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa N812 +import torch.optim as optim +from fibad.models.model_registry import fibad_model + +logger = logging.getLogger(__name__) + + +@fibad_model +class CNN(nn.Module): + def __init__(self, model_config, shape): + logger.info("This is an external model, not in FIBAD!!!") + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + self.config = model_config + + # Optimizer and criterion could be set directly, i.e. `self.optimizer = optim.SGD(...)` + # but we define them as methods as a way to allow for more flexibility in the future. + self.optimizer = self._optimizer() + self.criterion = self._criterion() + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + def train_step(self, batch): + """This function contains the logic for a single training step. i.e. the + contents of the inner loop of a ML training process. + + Parameters + ---------- + batch : tuple + A tuple containing the inputs and labels for the current batch. + + Returns + ------- + Current loss value + The loss value for the current batch. + """ + inputs, labels = batch + + self.optimizer.zero_grad() + outputs = self(inputs) + loss = self.criterion(outputs, labels) + loss.backward() + self.optimizer.step() + return {"loss": loss.item()} + + def _criterion(self): + return nn.CrossEntropyLoss() + + def _optimizer(self): + return optim.SGD(self.parameters(), lr=0.001, momentum=0.9) + + def save(self): + torch.save(self.state_dict(), self.config.get("weights_filepath", "example_cnn.pth"))