diff --git a/docs/config.md b/docs/config.md index d6f3cdd6..a4781dba 100644 --- a/docs/config.md +++ b/docs/config.md @@ -167,6 +167,7 @@ The config file has three main sections: - `api_key`: (str) API key. The API key is masked when saved to config files. - `wandb_mode`: (str) "offline" if only local logging is required. Default: "None". - `prv_runid`: (str) Previous run ID if training should be resumed from a previous ckpt. *Default*: `None`. + - `group`: (str) Group name for the run. - `optimizer_name`: (str) Optimizer to be used. One of ["Adam", "AdamW"]. - `optimizer` - `lr`: (float) Learning rate of type float. *Default*: 1e-3 diff --git a/docs/config_centroid.yaml b/docs/config_centroid.yaml index defbf9b8..2c454daa 100644 --- a/docs/config_centroid.yaml +++ b/docs/config_centroid.yaml @@ -120,6 +120,7 @@ trainer_config: wandb_mode: '' api_key: '' prv_runid: + group: optimizer_name: Adam optimizer: lr: 0.0001 diff --git a/initial_config.yaml b/initial_config.yaml deleted file mode 100644 index e350f234..00000000 --- a/initial_config.yaml +++ /dev/null @@ -1,104 +0,0 @@ -data_config: - provider: LabelsReader - train_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp - val_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp - user_instances_only: true - chunk_size: 100 - preprocessing: - is_rgb: false - max_width: null - max_height: null - scale: 1.0 - crop_hw: - - 160 - - 160 - min_crop_size: null - use_augmentations_train: true - augmentation_config: - intensity: - contrast_p: 1.0 - geometric: - rotation: 180.0 - scale: null - translate_width: 0 - translate_height: 0 - affine_p: 0.5 -model_config: - init_weights: default - pre_trained_weights: null - backbone_type: unet - backbone_config: - in_channels: 1 - kernel_size: 3 - filters: 16 - filters_rate: 1.5 - max_stride: 8 - convs_per_block: 2 - stacks: 1 - stem_stride: null - middle_block: true - up_interpolate: false - head_configs: - single_instance: null - centroid: null - bottomup: null - centered_instance: - confmaps: - part_names: - - '0' - - '1' - anchor_part: 1 - sigma: 1.5 - output_stride: 2 -trainer_config: - train_data_loader: - batch_size: 1 - shuffle: true - num_workers: 2 - val_data_loader: - batch_size: 1 - num_workers: 0 - model_ckpt: - save_top_k: 1 - save_last: true - early_stopping: - stop_training_on_plateau: true - min_delta: 1.0e-08 - patience: 20 - trainer_devices: 1 - trainer_accelerator: cpu - enable_progress_bar: false - steps_per_epoch: null - max_epochs: 2 - seed: 1000 - use_wandb: false - save_ckpt: false - save_ckpt_path: null - bin_files_path: null - resume_ckpt_path: null - wandb: - entity: null - project: test - name: test_run - wandb_mode: offline - api_key: '' - prv_runid: null - log_params: - - trainer_config.optimizer_name - - trainer_config.optimizer.amsgrad - - trainer_config.optimizer.lr - - model_config.backbone_type - - model_config.init_weights - optimizer_name: Adam - optimizer: - lr: 0.0001 - amsgrad: false - lr_scheduler: - scheduler: ReduceLROnPlateau - reduce_lr_on_plateau: - threshold: 1.0e-07 - threshold_mode: rel - cooldown: 3 - patience: 5 - factor: 0.5 - min_lr: 1.0e-08 diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index 32787b06..8e557d5b 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -639,6 +639,7 @@ def train( name=wandb_config.name, save_dir=self.dir_path, id=self.config.trainer_config.wandb.prv_runid, + group=self.config.trainer_config.wandb.group, ) logger.append(wandb_logger) diff --git a/tests/assets/minimal_instance/initial_config.yaml b/tests/assets/minimal_instance/initial_config.yaml index b982e3e0..c3ed620b 100755 --- a/tests/assets/minimal_instance/initial_config.yaml +++ b/tests/assets/minimal_instance/initial_config.yaml @@ -75,6 +75,7 @@ trainer_config: wandb_mode: '' api_key: '' prv_runid: + group: optimizer_name: Adam optimizer: lr: 0.0001 diff --git a/tests/assets/minimal_instance/training_config.yaml b/tests/assets/minimal_instance/training_config.yaml index ac22d62d..d9b93502 100755 --- a/tests/assets/minimal_instance/training_config.yaml +++ b/tests/assets/minimal_instance/training_config.yaml @@ -88,6 +88,7 @@ trainer_config: wandb_mode: '' api_key: '' prv_runid: null + group: optimizer_name: Adam optimizer: lr: 0.0001 diff --git a/tests/assets/minimal_instance_bottomup/initial_config.yaml b/tests/assets/minimal_instance_bottomup/initial_config.yaml index ac12e6a2..e130bdfd 100755 --- a/tests/assets/minimal_instance_bottomup/initial_config.yaml +++ b/tests/assets/minimal_instance_bottomup/initial_config.yaml @@ -76,6 +76,7 @@ trainer_config: wandb_mode: '' api_key: '' prv_runid: + group: optimizer_name: Adam optimizer: lr: 0.0001 diff --git a/tests/assets/minimal_instance_bottomup/training_config.yaml b/tests/assets/minimal_instance_bottomup/training_config.yaml index 5111f65d..81587559 100755 --- a/tests/assets/minimal_instance_bottomup/training_config.yaml +++ b/tests/assets/minimal_instance_bottomup/training_config.yaml @@ -91,6 +91,7 @@ trainer_config: wandb_mode: '' api_key: '' prv_runid: + group: optimizer_name: Adam optimizer: lr: 0.0001 diff --git a/tests/assets/minimal_instance_centroid/initial_config.yaml b/tests/assets/minimal_instance_centroid/initial_config.yaml index f080352b..bd575a5a 100755 --- a/tests/assets/minimal_instance_centroid/initial_config.yaml +++ b/tests/assets/minimal_instance_centroid/initial_config.yaml @@ -70,6 +70,7 @@ trainer_config: wandb_mode: '' api_key: '' prv_runid: + group: optimizer_name: Adam optimizer: lr: 0.0001 diff --git a/tests/assets/minimal_instance_centroid/training_config.yaml b/tests/assets/minimal_instance_centroid/training_config.yaml index 04eb2642..f0edf4ab 100755 --- a/tests/assets/minimal_instance_centroid/training_config.yaml +++ b/tests/assets/minimal_instance_centroid/training_config.yaml @@ -81,6 +81,7 @@ trainer_config: wandb_mode: '' api_key: '' prv_runid: + group: optimizer_name: Adam optimizer: lr: 0.0001 diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index a9e11573..1256848f 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -145,13 +145,7 @@ def config(sleap_data_dir): "wandb_mode": "offline", "api_key": "", "prv_runid": None, - "log_params": [ - "trainer_config.optimizer_name", - "trainer_config.optimizer.amsgrad", - "trainer_config.optimizer.lr", - "model_config.backbone_type", - "model_config.init_weights", - ], + "group": None, }, "optimizer_name": "Adam", "optimizer": {"lr": 0.0001, "amsgrad": False},