From fe6fe6488b9934491897929d7cb7e998d40ba7ff Mon Sep 17 00:00:00 2001 From: Drew Oldag Date: Mon, 23 Sep 2024 13:48:52 -0700 Subject: [PATCH] Updating to match changes in fibad. --- example_config.toml | 49 ++++++++++++++++++++++++++++---------- src/kbmod_ml/models/cnn.py | 4 ++-- 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/example_config.toml b/example_config.toml index 768ef00..ddf500c 100644 --- a/example_config.toml +++ b/example_config.toml @@ -16,19 +16,21 @@ log_destination = "stderr" log_level = "info" # Emit informational messages, warnings and all errors # log_level = "debug" # Very verbose, emit all log messages. +data_dir = "/home/drew/code/fibad/data/" + [download] sw = "22asec" sh = "22asec" filter = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"] type = "coadd" rerun = "pdr3_wide" -username = "mtauraso@local" -password = "cCw+nX53lmNLHMy+JbizpH/dl4t7sxljiNm6a7k1" -max_connections = 2 -fits_file = "../hscplay/temp.fits" -cutout_dir = "../hscplay/cutouts/" +username = false +password = false +num_sources = -1 # Values below 1 here indicate all sources in the catalog will be downloaded offset = 0 -num_sources = 500 +concurrent_connections = 4 +stats_print_interval = 60 +fits_file = "./catalog.fits" # These control the downloader's HTTP requests and retries # `retry_wait` How long to wait before retrying a failed HTTP request in seconds. Default 30s @@ -38,7 +40,14 @@ retries = 3 # `timepout` How long should we wait to get a full HTTP response from the server. Default 3600s (1hr) timeout = 3600 # `chunksize` How many sky location rectangles should we request in a single request. Default is 990 -chunksize = 990 +chunk_size = 990 + +# Whether to retrieve the image layer +image = true +# Whether to retrieve the variance layer +variance = false +# Whether to retrieve the mask layer +mask = false [model] # The name of the built-in model to use or the libpath to an external model @@ -48,19 +57,35 @@ name = "kbmod_ml.models.cnn.CNN" weights_filepath = "example_model.pth" epochs = 10 +base_channel_size = 32 +latent_dim =64 + [data_loader] # Name of the built-in data loader to use or the libpath to an external data loader # e.g. "user_package.submodule.ExternalDataLoader" or "HSCDataLoader" name = "CifarDataLoader" -# name = "HSCDataLoader" -# Directory path where the data is stored -path = "/home/drew/code/fibad/data/" + +# Pixel dimensions used to crop all images prior to loading. Will prune any images that are too small. +# +# If not provided by user, the default of 'false' scans the directory for the smallest dimensioned files, and +# uses those pixel dimensions as the crop size. +# +#crop_to = [100,100] +crop_to = false + +# Limit data loader to only particular filters when there are more in the data set. +# +# When not provided by the user, the number of filters will be automatically gleaned from the data set. +# Defaults behavior is produced by the false value. +# +#filters = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"] +filters = false # Default PyTorch DataLoader parameters -batch_size = 10 +batch_size = 4 shuffle = true -num_workers = 10 +num_workers = 2 [predict] batch_size = 32 diff --git a/src/kbmod_ml/models/cnn.py b/src/kbmod_ml/models/cnn.py index 3b62715..9d6d39e 100644 --- a/src/kbmod_ml/models/cnn.py +++ b/src/kbmod_ml/models/cnn.py @@ -15,7 +15,7 @@ @fibad_model class CNN(nn.Module): - def __init__(self, model_config, shape): + def __init__(self, config, shape): logger.info("This is an external model, not in FIBAD!!!") super().__init__() self.conv1 = nn.Conv2d(3, 6, 5) @@ -25,7 +25,7 @@ def __init__(self, model_config, shape): self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) - self.config = model_config + self.config = config # Optimizer and criterion could be set directly, i.e. `self.optimizer = optim.SGD(...)` # but we define them as methods as a way to allow for more flexibility in the future.