From 52f145c1de7f108e080e98d08d8cbb721f312fc9 Mon Sep 17 00:00:00 2001
From: gitttt-1234
Date: Fri, 1 Nov 2024 14:56:19 -0700
Subject: [PATCH 01/11] Fix loading backbone ckpt

---
 tests/inference/test_predictors.py | 40 ++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/tests/inference/test_predictors.py b/tests/inference/test_predictors.py
index 7b06938e..ba9d5ee1 100644
--- a/tests/inference/test_predictors.py
+++ b/tests/inference/test_predictors.py
@@ -216,6 +216,46 @@ def test_topdown_predictor(
     )
     assert np.all(np.abs(head_layer_ckpt - model_weights) < 1e-6)
 
+    print(
+        f"centered instance model: ",
+        predictor.inference_model.instance_peaks.torch_model.model,
+    )
+
+    # check loading diff head ckpt for centroid
+    preprocess_config = {
+        "is_rgb": False,
+        "crop_hw": None,
+        "max_width": None,
+        "max_height": None,
+    }
+
+    predictor = Predictor.from_model_paths(
+        [minimal_instance_centroid_ckpt],
+        backbone_ckpt_path=Path(minimal_instance_ckpt) / "best.ckpt",
+        head_ckpt_path=Path(minimal_instance_centroid_ckpt) / "best.ckpt",
+        peak_threshold=0.03,
+        max_instances=6,
+        preprocess_config=OmegaConf.create(preprocess_config),
+    )
+
+    print(
+        f"centroid model: ", predictor.inference_model.centroid_crop.torch_model.model
+    )
+
+    ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
+    backbone_ckpt = ckpt["state_dict"][
+        "model.backbone.enc.encoder_stack.0.blocks.0.weight"
+    ][0, 0, :].numpy()
+
+    model_weights = (
+        next(predictor.inference_model.centroid_crop.torch_model.model.parameters())[
+            0, 0, :
+        ]
+        .detach()
+        .numpy()
+    )
+
+    assert np.all(np.abs(backbone_ckpt - model_weights) < 1e-6)
 
     # load only backbone and head ckpt as None - centered instance
     predictor = Predictor.from_model_paths(

From ca83182c0def67e920b00985f97ab440b3b4ace0 Mon Sep 17 00:00:00 2001
From: gitttt-1234
Date: Fri, 1 Nov 2024 15:26:10 -0700
Subject: [PATCH 02/11] Add more tests for backbone ckpt

---
 tests/inference/test_predictors.py | 54 ++++++++++++++++++++++++++----
 1 file changed, 48 insertions(+), 6 deletions(-)

diff --git a/tests/inference/test_predictors.py b/tests/inference/test_predictors.py
index ba9d5ee1..0f453b2d 100644
--- a/tests/inference/test_predictors.py
+++ b/tests/inference/test_predictors.py
@@ -216,11 +216,32 @@ def test_topdown_predictor(
     )
     assert np.all(np.abs(head_layer_ckpt - model_weights) < 1e-6)
 
-    print(
-        f"centered instance model: ",
-        predictor.inference_model.instance_peaks.torch_model.model,
+
+    # load only backbone and head ckpt as None - centered instance
+    predictor = Predictor.from_model_paths(
+        [minimal_instance_ckpt],
+        backbone_ckpt_path=Path(minimal_instance_ckpt) / "best.ckpt",
+        head_ckpt_path=None,
+        peak_threshold=0.03,
+        max_instances=6,
+        preprocess_config=OmegaConf.create(preprocess_config),
     )
 
+    ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
+    backbone_ckpt = ckpt["state_dict"][
+        "model.backbone.enc.encoder_stack.0.blocks.0.weight"
+    ][0, 0, :].numpy()
+
+    model_weights = (
+        next(predictor.inference_model.instance_peaks.torch_model.model.parameters())[
+            0, 0, :
+        ]
+        .detach()
+        .numpy()
+    )
+
+    assert np.all(np.abs(backbone_ckpt - model_weights) < 1e-6)
+
     # check loading diff head ckpt for centroid
     preprocess_config = {
         "is_rgb": False,
@@ -238,11 +259,32 @@ def test_topdown_predictor(
         preprocess_config=OmegaConf.create(preprocess_config),
     )
 
-    print(
-        f"centroid model: ", predictor.inference_model.centroid_crop.torch_model.model
+    ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
+    backbone_ckpt = ckpt["state_dict"][
+        "model.backbone.enc.encoder_stack.0.blocks.0.weight"
+    ][0, 0, :].numpy()
+
+    model_weights = (
+        next(predictor.inference_model.centroid_crop.torch_model.model.parameters())[
+            0, 0, :
+        ]
+        .detach()
+        .numpy()
     )
 
-    ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
+    assert np.all(np.abs(backbone_ckpt - model_weights) < 1e-6)
+
+    # load only backbone and head ckpt as None - centroid
+    predictor = Predictor.from_model_paths(
+        [minimal_instance_centroid_ckpt],
+        backbone_ckpt_path=Path(minimal_instance_centroid_ckpt) / "best.ckpt",
+        head_ckpt_path=None,
+        peak_threshold=0.03,
+        max_instances=6,
+        preprocess_config=OmegaConf.create(preprocess_config),
+    )
+
+    ckpt = torch.load(Path(minimal_instance_centroid_ckpt) / "best.ckpt")
     backbone_ckpt = ckpt["state_dict"][
         "model.backbone.enc.encoder_stack.0.blocks.0.weight"
     ][0, 0, :].numpy()
From 96dbb3a7f876a8a6ebb2697b2bb3dc45e31fa579 Mon Sep 17 00:00:00 2001
From: gitttt-1234
Date: Fri, 1 Nov 2024 17:07:00 -0700
Subject: [PATCH 03/11] Add option to reuse bin files

---
 sleap_nn/training/model_trainer.py   | 115 +++++++++++++++++----------
 tests/training/test_model_trainer.py |  58 +++++++++++---
 2 files changed, 122 insertions(+), 51 deletions(-)

diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py
index a243f1fd..4f1363ad 100644
--- a/sleap_nn/training/model_trainer.py
+++ b/sleap_nn/training/model_trainer.py
@@ -90,6 +90,7 @@ def __init__(self, config: OmegaConf):
         self.model = None
         self.train_data_loader = None
         self.val_data_loader = None
+        self.bin_files_path = None
         self.crop_hw = -1
 
         # check which head type to choose the model
@@ -110,25 +111,8 @@ def __init__(self, config: OmegaConf):
                 f"Cannot create a new folder in {self.dir_path}. Check the permissions to the given Checkpoint directory. \n {e}"
             )
 
-        self.bin_files_path = self.config.trainer_config.bin_files_path
-        if self.bin_files_path is None:
-            self.bin_files_path = self.dir_path
-
-        self.bin_files_path = f"{self.bin_files_path}/chunks_{datetime.strftime(datetime.now(), '%Y%m%d_%H-%M-%S-%f')}"
-        print(f"`.bin` files are saved in {self.bin_files_path}")
-
-        if not Path(self.bin_files_path).exists():
-            try:
-                Path(self.bin_files_path).mkdir(parents=True, exist_ok=True)
-            except OSError as e:
-                raise OSError(
-                    f"Cannot create a new folder in {self.bin_files_path}. Check the permissions to the given Checkpoint directory. \n {e}"
-                )
-
         OmegaConf.save(config=self.config, f=f"{self.dir_path}/initial_config.yaml")
 
-        self.config.trainer_config.saved_bin_files_path = self.bin_files_path
-
         # set seed
         torch.manual_seed(self.seed)
 
@@ -180,7 +164,11 @@ def __init__(self, config: OmegaConf):
         else:
             self.crop_hw = self.crop_hw[0]
 
-    def _create_data_loaders(self):
+    def _create_data_loaders(
+        self,
+        train_chunks_dir_path: Optional[str] = None,
+        val_chunks_dir_path: Optional[str] = None,
+    ):
         """Create a DataLoader for train, validation and test sets using the data_config."""
 
         def run_subprocess():
@@ -218,16 +206,50 @@ def run_subprocess():
             print("Standard Output:\n", stdout)
             print("Standard Error:\n", stderr)
 
-        try:
-            run_subprocess()
+        if train_chunks_dir_path is None:
+            try:
+                self.bin_files_path = self.config.trainer_config.bin_files_path
+                if self.bin_files_path is None:
+                    self.bin_files_path = self.dir_path
+
+                self.bin_files_path = f"{self.bin_files_path}/chunks_{datetime.strftime(datetime.now(), '%Y%m%d_%H-%M-%S-%f')}"
+                print(
+                    f"New dir is created and `.bin` files are saved in {self.bin_files_path}"
+                )
+
+                if not Path(self.bin_files_path).exists():
+                    try:
+                        Path(self.bin_files_path).mkdir(parents=True, exist_ok=True)
+                    except OSError as e:
+                        raise OSError(
+                            f"Cannot create a new folder in {self.bin_files_path}. Check the permissions to the given Checkpoint directory. \n {e}"
+                        )
+
+                self.config.trainer_config.saved_bin_files_path = self.bin_files_path
 
-        except Exception as e:
-            raise Exception(f"Error while creating the `.bin` files... {e}")
+                self.train_input_dir = (
+                    Path(self.bin_files_path) / "train_chunks"
+                ).as_posix()
+                self.val_input_dir = (
+                    Path(self.bin_files_path) / "val_chunks"
+                ).as_posix()
+
+                run_subprocess()
+
+            except Exception as e:
+                raise Exception(f"Error while creating the `.bin` files... {e}")
+
+        else:
+            print(
+                f"Using `.bin` files from {train_chunks_dir_path} for train dataset and from {val_chunks_dir_path} for val dataset."
+            )
+            self.train_input_dir = train_chunks_dir_path
+            self.val_input_dir = val_chunks_dir_path
 
         if self.model_type == "single_instance":
 
             train_dataset = SingleInstanceStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "train_chunks").as_posix(),
+                input_dir=self.train_input_dir,
                 shuffle=self.config.trainer_config.train_data_loader.shuffle,
                 apply_aug=self.config.data_config.use_augmentations_train,
                 augmentation_config=self.config.data_config.augmentation_config,
             )
 
             val_dataset = SingleInstanceStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "val_chunks").as_posix(),
+                input_dir=self.val_input_dir,
                 shuffle=False,
                 apply_aug=False,
                 confmap_head=self.config.model_config.head_configs.single_instance.confmaps,
             )
 
         elif self.model_type == "centered_instance":
 
             train_dataset = CenteredInstanceStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "train_chunks").as_posix(),
+                input_dir=self.train_input_dir,
                 shuffle=self.config.trainer_config.train_data_loader.shuffle,
                 apply_aug=self.config.data_config.use_augmentations_train,
                 augmentation_config=self.config.data_config.augmentation_config,
             )
 
             val_dataset = CenteredInstanceStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "val_chunks").as_posix(),
+                input_dir=self.val_input_dir,
                 shuffle=False,
                 apply_aug=False,
                 confmap_head=self.config.model_config.head_configs.centered_instance.confmaps,
             )
 
         elif self.model_type == "centroid":
 
             train_dataset = CentroidStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "train_chunks").as_posix(),
+                input_dir=self.train_input_dir,
                 shuffle=self.config.trainer_config.train_data_loader.shuffle,
                 apply_aug=self.config.data_config.use_augmentations_train,
                 augmentation_config=self.config.data_config.augmentation_config,
             )
 
             val_dataset = CentroidStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "val_chunks").as_posix(),
+                input_dir=self.val_input_dir,
                 shuffle=False,
                 apply_aug=False,
                 confmap_head=self.config.model_config.head_configs.centroid.confmaps,
             )
 
         elif self.model_type == "bottomup":
 
             train_dataset = BottomUpStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "train_chunks").as_posix(),
+                input_dir=self.train_input_dir,
                 shuffle=self.config.trainer_config.train_data_loader.shuffle,
                 apply_aug=self.config.data_config.use_augmentations_train,
                 augmentation_config=self.config.data_config.augmentation_config,
             )
 
             val_dataset = BottomUpStreamingDataset(
-                input_dir=(Path(self.bin_files_path) / "val_chunks").as_posix(),
+                input_dir=self.val_input_dir,
                 shuffle=False,
                 apply_aug=False,
                 confmap_head=self.config.model_config.head_configs.bottomup.confmaps,
@@ -355,6 +377,9 @@ def train(
         self,
         backbone_trained_ckpts_path: Optional[str] = None,
         head_trained_ckpts_path: Optional[str] = None,
+        delete_bin_files_after_training: bool = True,
+        train_chunks_dir_path: Optional[str] = None,
+        val_chunks_dir_path: Optional[str] = None,
     ):
         """Initiate the training by calling the fit method of Trainer."""
         logger = []
@@ -416,7 +441,7 @@ def train(
         # save the configs as yaml in the checkpoint dir
         OmegaConf.save(config=self.config, f=f"{self.dir_path}/training_config.yaml")
 
-        self._create_data_loaders()
+        self._create_data_loaders(train_chunks_dir_path, val_chunks_dir_path)
 
         # save the skeleton in the config
self.config["data_config"]["skeletons"] = {} @@ -468,17 +493,23 @@ def train( config=self.config, f=f"{self.dir_path}/training_config.yaml" ) # TODO: (ubuntu test failing (running for > 6hrs) with the below lines) - # print("Deleting training and validation files...") - # if (Path(self.dir_path) / "train_chunks").exists(): - # shutil.rmtree( - # (Path(self.dir_path) / "train_chunks").as_posix(), - # ignore_errors=True, - # ) - # if (Path(self.dir_path) / "val_chunks").exists(): - # shutil.rmtree( - # (Path(self.dir_path) / "val_chunks").as_posix(), - # ignore_errors=True, - # ) + if delete_bin_files_after_training: + print("Deleting training and validation files...") + if (Path(self.train_input_dir)).exists(): + shutil.rmtree( + (Path(self.train_input_dir)).as_posix(), + ignore_errors=True, + ) + if (Path(self.val_input_dir)).exists(): + shutil.rmtree( + (Path(self.val_input_dir)).as_posix(), + ignore_errors=True, + ) + if self.bin_files_path is not None: + shutil.rmtree( + (Path(self.bin_files_path)).as_posix(), + ignore_errors=True, + ) class TrainingModel(L.LightningModule): diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py index dd7d3647..99eb2c0b 100644 --- a/tests/training/test_model_trainer.py +++ b/tests/training/test_model_trainer.py @@ -70,7 +70,7 @@ def test_wandb(): reason="Flaky test (The training test runs on Ubuntu for a long time: >6hrs and then fails.)", ) # TODO: Revisit this test later (Failing on ubuntu) -def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): +def test_trainer(config, tmp_path: str): OmegaConf.update(config, "trainer_config.save_ckpt_path", None) model_trainer = ModelTrainer(config) assert model_trainer.dir_path == "." @@ -95,8 +95,6 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): assert not ( Path(config.trainer_config.save_ckpt_path).joinpath("best.ckpt").exists() ) - shutil.rmtree((Path(model_trainer.bin_files_path) / "train_chunks").as_posix()) - shutil.rmtree((Path(model_trainer.bin_files_path) / "val_chunks").as_posix()) ####### @@ -174,8 +172,6 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): assert abs(df.loc[0, "learning_rate"] - config.trainer_config.optimizer.lr) <= 1e-4 assert not df.val_loss.isnull().all() assert not df.train_loss.isnull().all() - shutil.rmtree((Path(model_trainer.bin_files_path) / "train_chunks").as_posix()) - shutil.rmtree((Path(model_trainer.bin_files_path) / "val_chunks").as_posix()) ####### @@ -199,8 +195,6 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): Path(config_copy.trainer_config.save_ckpt_path).joinpath("last.ckpt") ) assert checkpoint["epoch"] == 3 - shutil.rmtree((Path(trainer.bin_files_path) / "train_chunks").as_posix()) - shutil.rmtree((Path(trainer.bin_files_path) / "val_chunks").as_posix()) training_config = OmegaConf.load( f"{config_copy.trainer_config.save_ckpt_path}/training_config.yaml" @@ -235,8 +229,6 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str): Path(config_early_stopping.trainer_config.save_ckpt_path).joinpath("last.ckpt") ) assert checkpoint["epoch"] == 1 - shutil.rmtree((Path(trainer.bin_files_path) / "train_chunks").as_posix()) - shutil.rmtree((Path(trainer.bin_files_path) / "val_chunks").as_posix()) ####### @@ -384,6 +376,54 @@ def test_trainer_load_trained_ckpts(config, tmp_path, minimal_instance_ckpt): assert np.all(np.abs(head_layer_ckpt - model_ckpt) < 1e-6) +def test_reuse_bin_files(config, 
tmp_path: str): + """Test reusing `.bin` files.""" + # Centroid model + centroid_config = config.copy() + head_config = config.model_config.head_configs.centered_instance + OmegaConf.update(centroid_config, "model_config.head_configs.centroid", head_config) + del centroid_config.model_config.head_configs.centered_instance + del centroid_config.model_config.head_configs.centroid["confmaps"].part_names + + OmegaConf.update( + centroid_config, + "trainer_config.save_ckpt_path", + f"{tmp_path}/test_model_trainer/", + ) + + if (Path(centroid_config.trainer_config.save_ckpt_path) / "best.ckpt").exists(): + os.remove( + ( + Path(centroid_config.trainer_config.save_ckpt_path) / "best.ckpt" + ).as_posix() + ) + os.remove( + ( + Path(centroid_config.trainer_config.save_ckpt_path) / "last.ckpt" + ).as_posix() + ) + shutil.rmtree( + ( + Path(centroid_config.trainer_config.save_ckpt_path) / "lightning_logs" + ).as_posix() + ) + + OmegaConf.update(centroid_config, "trainer_config.save_ckpt", True) + OmegaConf.update(centroid_config, "trainer_config.use_wandb", False) + OmegaConf.update(centroid_config, "trainer_config.max_epochs", 1) + OmegaConf.update(centroid_config, "trainer_config.steps_per_epoch", 10) + + # test reusing bin files + trainer1 = ModelTrainer(centroid_config) + trainer1.train(delete_bin_files_after_training=False) + + trainer2 = ModelTrainer(centroid_config) + trainer2.train( + train_chunks_dir_path=trainer1.train_input_dir, + val_chunks_dir_path=trainer1.val_input_dir, + ) + + def test_topdown_centered_instance_model(config, tmp_path: str): # unet From f31b1e349fb82129561878daed5d7688e0ee689a Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Mon, 4 Nov 2024 07:49:58 -0800 Subject: [PATCH 04/11] Skip test on ubuntu --- tests/training/test_model_trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py index 99eb2c0b..8d153656 100644 --- a/tests/training/test_model_trainer.py +++ b/tests/training/test_model_trainer.py @@ -376,6 +376,11 @@ def test_trainer_load_trained_ckpts(config, tmp_path, minimal_instance_ckpt): assert np.all(np.abs(head_layer_ckpt - model_ckpt) < 1e-6) +@pytest.mark.skipif( + sys.platform.startswith("li"), + reason="Flaky test (The training test runs on Ubuntu for a long time: >6hrs and then fails.)", +) +# TODO: Revisit this test later (Failing on ubuntu) def test_reuse_bin_files(config, tmp_path: str): """Test reusing `.bin` files.""" # Centroid model From e94ca1016acd77a16c36094025e214d72ac118f2 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Tue, 5 Nov 2024 17:03:20 -0800 Subject: [PATCH 05/11] Fix chunk dir args --- sleap_nn/training/model_trainer.py | 19 ++++++++----------- tests/training/test_model_trainer.py | 3 +-- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index 4f1363ad..74532b55 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -166,8 +166,7 @@ def __init__(self, config: OmegaConf): def _create_data_loaders( self, - train_chunks_dir_path: Optional[str] = None, - val_chunks_dir_path: Optional[str] = None, + chunks_dir_path: Optional[str] = None, ): """Create a DataLoader for train, validation and test sets using the data_config.""" @@ -206,7 +205,7 @@ def run_subprocess(): print("Standard Output:\n", stdout) print("Standard Error:\n", stderr) - if train_chunks_dir_path is None: + if chunks_dir_path is None: try: self.bin_files_path = 
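Patch 04 reuses the skip marker that already guards `test_trainer`. Note that `sys.platform` is the string `"linux"` on Linux runners, so `startswith("li")` is just a terse spelling of `startswith("linux")`. For reference, a more explicit equivalent of the marker (`skip_on_ubuntu` is a hypothetical name):

```python
import sys

import pytest

# On Linux, sys.platform == "linux", so startswith("li") and
# startswith("linux") select the same platforms here.
skip_on_ubuntu = pytest.mark.skipif(
    sys.platform.startswith("linux"),
    reason="Flaky test (The training test runs on Ubuntu for a long time: >6hrs and then fails.)",
)
```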
From e94ca1016acd77a16c36094025e214d72ac118f2 Mon Sep 17 00:00:00 2001
From: gitttt-1234
Date: Tue, 5 Nov 2024 17:03:20 -0800
Subject: [PATCH 05/11] Fix chunk dir args

---
 sleap_nn/training/model_trainer.py   | 19 ++++++++-----------
 tests/training/test_model_trainer.py |  3 +--
 2 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py
index 4f1363ad..74532b55 100644
--- a/sleap_nn/training/model_trainer.py
+++ b/sleap_nn/training/model_trainer.py
@@ -166,8 +166,7 @@ def __init__(self, config: OmegaConf):
 
     def _create_data_loaders(
         self,
-        train_chunks_dir_path: Optional[str] = None,
-        val_chunks_dir_path: Optional[str] = None,
+        chunks_dir_path: Optional[str] = None,
     ):
         """Create a DataLoader for train, validation and test sets using the data_config."""
 
@@ -206,7 +205,7 @@ def run_subprocess():
             print("Standard Output:\n", stdout)
             print("Standard Error:\n", stderr)
 
-        if train_chunks_dir_path is None:
+        if chunks_dir_path is None:
             try:
                 self.bin_files_path = self.config.trainer_config.bin_files_path
                 if self.bin_files_path is None:
@@ -240,11 +239,10 @@ def run_subprocess():
                 raise Exception(f"Error while creating the `.bin` files... {e}")
 
         else:
-            print(
-                f"Using `.bin` files from {train_chunks_dir_path} for train dataset and from {val_chunks_dir_path} for val dataset."
-            )
-            self.train_input_dir = train_chunks_dir_path
-            self.val_input_dir = val_chunks_dir_path
+            print(f"Using `.bin` files from {chunks_dir_path}.")
+            self.train_input_dir = (Path(chunks_dir_path) / "train_chunks").as_posix()
+            self.val_input_dir = (Path(chunks_dir_path) / "val_chunks").as_posix()
+            self.config.trainer_config.saved_bin_files_path = chunks_dir_path
 
         if self.model_type == "single_instance":
 
@@ -378,8 +376,7 @@ def train(
         backbone_trained_ckpts_path: Optional[str] = None,
         head_trained_ckpts_path: Optional[str] = None,
         delete_bin_files_after_training: bool = True,
-        train_chunks_dir_path: Optional[str] = None,
-        val_chunks_dir_path: Optional[str] = None,
+        chunks_dir_path: Optional[str] = None,
     ):
         """Initiate the training by calling the fit method of Trainer."""
         logger = []
@@ -441,7 +438,7 @@ def train(
         # save the configs as yaml in the checkpoint dir
         OmegaConf.save(config=self.config, f=f"{self.dir_path}/training_config.yaml")
 
-        self._create_data_loaders(train_chunks_dir_path, val_chunks_dir_path)
+        self._create_data_loaders(chunks_dir_path)
 
         # save the skeleton in the config
diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py
index 8d153656..3bb1335a 100644
--- a/tests/training/test_model_trainer.py
+++ b/tests/training/test_model_trainer.py
@@ -424,8 +424,7 @@ def test_reuse_bin_files(config, tmp_path: str):
 
     trainer2 = ModelTrainer(centroid_config)
     trainer2.train(
-        train_chunks_dir_path=trainer1.train_input_dir,
-        val_chunks_dir_path=trainer1.val_input_dir,
+        chunks_dir_path=(trainer1.train_input_dir).split("train_chunks")[0],
    )
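With patch 05, reuse takes a single `chunks_dir_path`, which must be the parent directory holding the `train_chunks` and `val_chunks` folders. A sketch of the intended two-run workflow, assuming `config` is an OmegaConf training config like the one built in the tests; `Path(...).parent` resolves to the same directory as the `split("train_chunks")[0]` string surgery used in the updated test:

```python
from pathlib import Path

from sleap_nn.training.model_trainer import ModelTrainer

# First run: generate the `.bin` chunk files and keep them after training.
trainer1 = ModelTrainer(config)
trainer1.train(delete_bin_files_after_training=False)

# The chunks dir is the parent of the generated `train_chunks`/`val_chunks` folders.
chunks_dir = Path(trainer1.train_input_dir).parent.as_posix()

# Second run: skip chunk generation and stream from the existing files.
trainer2 = ModelTrainer(config)
trainer2.train(chunks_dir_path=chunks_dir)
```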
From 768407c83df5a8cc986be02972c3a862ccb989fd Mon Sep 17 00:00:00 2001
From: gitttt-1234
Date: Wed, 6 Nov 2024 13:20:57 -0800
Subject: [PATCH 06/11] Fix skeleton tests

---
 tests/inference/test_predictors.py | 34 ++++++++++++++++++++++++++----
 tests/inference/test_utils.py      |  9 +++++++-
 2 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/tests/inference/test_predictors.py b/tests/inference/test_predictors.py
index 0f453b2d..90d73579 100644
--- a/tests/inference/test_predictors.py
+++ b/tests/inference/test_predictors.py
@@ -34,7 +34,15 @@ def test_topdown_predictor(
     # check if the predicted labels have same video and skeleton as the ground truth labels
     gt_labels = sio.load_slp(minimal_instance)
     gt_lf = gt_labels[0]
-    assert pred_labels.skeletons == gt_labels.skeletons
+
+    skl = pred_labels.skeletons[0]
+    gt_skl = gt_labels.skeletons[0]
+    assert [a.name for a in skl.nodes] == [a.name for a in gt_skl.nodes]
+    assert len(skl.edges) == len(gt_skl.edges)
+    for a, b in zip(skl.edges, gt_skl.edges):
+        assert a[0].name == b[0].name and a[1].name == b[1].name
+    assert skl.symmetries == gt_skl.symmetries
+
     assert lf.frame_idx == gt_lf.frame_idx
     assert lf.instances[0].numpy().shape == gt_lf.instances[0].numpy().shape
     assert lf.instances[1].numpy().shape == gt_lf.instances[1].numpy().shape
@@ -421,7 +429,13 @@ def test_single_instance_predictor(
     # check if the predicted labels have same video and skeleton as the ground truth labels
     gt_labels = sio.load_slp(minimal_instance)
     gt_lf = gt_labels[0]
-    assert pred_labels.skeletons == gt_labels.skeletons
+    skl = pred_labels.skeletons[0]
+    gt_skl = gt_labels.skeletons[0]
+    assert [a.name for a in skl.nodes] == [a.name for a in gt_skl.nodes]
+    assert len(skl.edges) == len(gt_skl.edges)
+    for a, b in zip(skl.edges, gt_skl.edges):
+        assert a[0].name == b[0].name and a[1].name == b[1].name
+    assert skl.symmetries == gt_skl.symmetries
 
     assert lf.frame_idx == gt_lf.frame_idx
     assert lf.instances[0].numpy().shape == gt_lf.instances[0].numpy().shape
@@ -473,7 +487,13 @@ def test_single_instance_predictor(
 
     # check if the predicted labels have same skeleton as the GT labels
     gt_labels = sio.load_slp(minimal_instance)
-    assert pred_labels.skeletons == gt_labels.skeletons
+    skl = pred_labels.skeletons[0]
+    gt_skl = gt_labels.skeletons[0]
+    assert [a.name for a in skl.nodes] == [a.name for a in gt_skl.nodes]
+    assert len(skl.edges) == len(gt_skl.edges)
+    for a, b in zip(skl.edges, gt_skl.edges):
+        assert a[0].name == b[0].name and a[1].name == b[1].name
+    assert skl.symmetries == gt_skl.symmetries
     assert lf.frame_idx == 0
 
     # check if dictionaries are created when make labels is set to False
@@ -649,7 +669,13 @@ def test_bottomup_predictor(
     # check if the predicted labels have same video and skeleton as the ground truth labels
     gt_labels = sio.load_slp(minimal_instance)
     gt_lf = gt_labels[0]
-    assert pred_labels.skeletons == gt_labels.skeletons
+    skl = pred_labels.skeletons[0]
+    gt_skl = gt_labels.skeletons[0]
+    assert [a.name for a in skl.nodes] == [a.name for a in gt_skl.nodes]
+    assert len(skl.edges) == len(gt_skl.edges)
+    for a, b in zip(skl.edges, gt_skl.edges):
+        assert a[0].name == b[0].name and a[1].name == b[1].name
+    assert skl.symmetries == gt_skl.symmetries
 
     assert lf.frame_idx == gt_lf.frame_idx
     assert lf.instances[0].numpy().shape == gt_lf.instances[0].numpy().shape
diff --git a/tests/inference/test_utils.py b/tests/inference/test_utils.py
index 91830947..f4c2fe71 100644
--- a/tests/inference/test_utils.py
+++ b/tests/inference/test_utils.py
@@ -10,8 +10,15 @@ def test_get_skeleton_from_config(minimal_instance, minimal_instance_ckpt):
     training_config = OmegaConf.load(f"{minimal_instance_ckpt}/training_config.yaml")
     skeleton_config = training_config.data_config.skeletons
     skeletons = get_skeleton_from_config(skeleton_config)
+    skl = skeletons[0]
     labels = sio.load_slp(f"{minimal_instance}")
-    assert skeletons[0] == labels.skeletons[0]
+    gt_skl = labels.skeletons[0]
+
+    assert [a.name for a in skl.nodes] == [a.name for a in gt_skl.nodes]
+    assert len(skl.edges) == len(gt_skl.edges)
+    for a, b in zip(skl.edges, gt_skl.edges):
+        assert a[0].name == b[0].name and a[1].name == b[1].name
+    assert skl.symmetries == gt_skl.symmetries
 
 
 def test_interp1d():
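Patch 06 replaces the strict `Skeleton` equality check with a field-by-field comparison, but the same six assertion lines are now pasted into four predictor tests plus `test_get_skeleton_from_config`. A shared helper would keep them in sync; a sketch is shown below (`assert_skeletons_match` is a hypothetical name; the attribute access mirrors the assertions in the diff):

```python
import sleap_io as sio


def assert_skeletons_match(skl: sio.Skeleton, gt_skl: sio.Skeleton) -> None:
    """Compare two skeletons by node names, edge endpoints, and symmetries."""
    assert [node.name for node in skl.nodes] == [node.name for node in gt_skl.nodes]
    assert len(skl.edges) == len(gt_skl.edges)
    for edge, gt_edge in zip(skl.edges, gt_skl.edges):
        assert edge[0].name == gt_edge[0].name
        assert edge[1].name == gt_edge[1].name
    assert skl.symmetries == gt_skl.symmetries


# Usage in any of the tests above:
# assert_skeletons_match(pred_labels.skeletons[0], gt_labels.skeletons[0])
```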
Trainer.""" + """Initiate the training by calling the fit method of Trainer. + + Args: + backbone_trained_ckpts_path: Path of the `ckpt` file with which the backbone + is initialized. If `None`, random init is used. + head_trained_ckpts_path: Path of the `ckpt` file with which the head layers + are initialized. If `None`, random init is used. + delete_bin_files_after_training: If `False`, the `bin` files are retained after + training. Else, the `bin` files are deleted. + chunks_dir_path: Path to chunks dir (this dir should contain `train_chunks` + and `val_chunks` folder.). If `None`, `bin` files are generated. + + """ logger = [] if self.config.trainer_config.save_ckpt: @@ -454,7 +467,7 @@ def train( "symmetries": symm, } - trainer = L.Trainer( + self.trainer = L.Trainer( callbacks=callbacks, logger=logger, enable_checkpointing=self.config.trainer_config.save_ckpt, @@ -466,7 +479,7 @@ def train( ) try: - trainer.fit( + self.trainer.fit( self.model, self.train_data_loader, self.val_data_loader, From 5d7f7bbca4a51ee250966a964c437f30096b0e9e Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Wed, 6 Nov 2024 13:36:18 -0800 Subject: [PATCH 08/11] Fix saving skeleton --- sleap_nn/training/model_trainer.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index 567b666f..87735989 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -125,6 +125,20 @@ def __init__(self, config: OmegaConf): else True ) # TODO: defaults should be handles in config validation. self.skeletons = train_labels.skeletons + # save the skeleton in the config + self.config["data_config"]["skeletons"] = {} + for skl in self.skeletons: + if skl.symmetries: + symm = [list(s.nodes) for s in skl.symmetries] + else: + symm = None + skl_name = skl.name if skl.name is not None else "skeleton-0" + self.config["data_config"]["skeletons"][skl_name] = { + "nodes": skl.nodes, + "edges": skl.edges, + "symmetries": symm, + } + self.max_stride = self.config.model_config.backbone_config.max_stride self.edge_inds = train_labels.skeletons[0].edge_inds self.chunk_size = ( @@ -453,20 +467,6 @@ def train( self._create_data_loaders(chunks_dir_path) - # save the skeleton in the config - self.config["data_config"]["skeletons"] = {} - for skl in self.skeletons: - if skl.symmetries: - symm = [list(s.nodes) for s in skl.symmetries] - else: - symm = None - skl_name = skl.name if skl.name is not None else "skeleton-0" - self.config["data_config"]["skeletons"][skl_name] = { - "nodes": skl.nodes, - "edges": skl.edges, - "symmetries": symm, - } - self.trainer = L.Trainer( callbacks=callbacks, logger=logger, From 096b6d5070f1c7b75a95b9b0293a379d22950cb1 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Thu, 7 Nov 2024 15:31:52 -0800 Subject: [PATCH 09/11] Fix sio import --- sleap_nn/inference/utils.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/sleap_nn/inference/utils.py b/sleap_nn/inference/utils.py index fbf8355c..577aaf28 100644 --- a/sleap_nn/inference/utils.py +++ b/sleap_nn/inference/utils.py @@ -20,15 +20,13 @@ def get_skeleton_from_config(skeleton_config: OmegaConf): for name in skeleton_config.keys(): # create `sio.Node` object. - nodes = [ - sio.model.skeleton.Node(n["name"]) for n in skeleton_config[name].nodes - ] + nodes = [sio.Node(n["name"]) for n in skeleton_config[name].nodes] # create `sio.Edge` object. 
From 096b6d5070f1c7b75a95b9b0293a379d22950cb1 Mon Sep 17 00:00:00 2001
From: gitttt-1234
Date: Thu, 7 Nov 2024 15:31:52 -0800
Subject: [PATCH 09/11] Fix sio import

---
 sleap_nn/inference/utils.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/sleap_nn/inference/utils.py b/sleap_nn/inference/utils.py
index fbf8355c..577aaf28 100644
--- a/sleap_nn/inference/utils.py
+++ b/sleap_nn/inference/utils.py
@@ -20,15 +20,13 @@ def get_skeleton_from_config(skeleton_config: OmegaConf):
     for name in skeleton_config.keys():
 
         # create `sio.Node` object.
-        nodes = [
-            sio.model.skeleton.Node(n["name"]) for n in skeleton_config[name].nodes
-        ]
+        nodes = [sio.Node(n["name"]) for n in skeleton_config[name].nodes]
 
         # create `sio.Edge` object.
         edges = [
-            sio.model.skeleton.Edge(
-                sio.model.skeleton.Node(e["source"]["name"]),
-                sio.model.skeleton.Node(e["destination"]["name"]),
+            sio.Edge(
+                sio.Node(e["source"]["name"]),
+                sio.Node(e["destination"]["name"]),
             )
             for e in skeleton_config[name].edges
         ]
@@ -38,17 +36,17 @@ def get_skeleton_from_config(skeleton_config: OmegaConf):
             list_args = [
                 set(
                     [
-                        sio.model.skeleton.Node(s[0]["name"]),
-                        sio.model.skeleton.Node(s[1]["name"]),
+                        sio.Node(s[0]["name"]),
+                        sio.Node(s[1]["name"]),
                     ]
                 )
                 for s in skeleton_config[name].symmetries
            ]
-            symmetries = [sio.model.skeleton.Symmetry(x) for x in list_args]
+            symmetries = [sio.Symmetry(x) for x in list_args]
         else:
             symmetries = []
 
-        skeletons.append(sio.model.skeleton.Skeleton(nodes, edges, symmetries, name))
+        skeletons.append(sio.Skeleton(nodes, edges, symmetries, name))
 
     return skeletons

From 4432b429316bec4d46d599b8dcc5e3123423067f Mon Sep 17 00:00:00 2001
From: Talmo Pereira
Date: Thu, 5 Dec 2024 09:44:21 -0800
Subject: [PATCH 10/11] Fix for skeleton serialization after inference (#117)

---
 sleap_nn/inference/utils.py | 39 ++++++++++---------------------------
 1 file changed, 10 insertions(+), 29 deletions(-)

diff --git a/sleap_nn/inference/utils.py b/sleap_nn/inference/utils.py
index 577aaf28..6bf70259 100644
--- a/sleap_nn/inference/utils.py
+++ b/sleap_nn/inference/utils.py
@@ -17,36 +17,17 @@ def get_skeleton_from_config(skeleton_config: OmegaConf):
     """
     skeletons = []
 
-    for name in skeleton_config.keys():
-
-        # create `sio.Node` object.
-        nodes = [sio.Node(n["name"]) for n in skeleton_config[name].nodes]
-
-        # create `sio.Edge` object.
-        edges = [
-            sio.Edge(
-                sio.Node(e["source"]["name"]),
-                sio.Node(e["destination"]["name"]),
-            )
-            for e in skeleton_config[name].edges
-        ]
-
-        # create `sio.Symmetry` object.
-        if skeleton_config[name].symmetries:
-            list_args = [
-                set(
-                    [
-                        sio.Node(s[0]["name"]),
-                        sio.Node(s[1]["name"]),
-                    ]
-                )
-                for s in skeleton_config[name].symmetries
-            ]
-            symmetries = [sio.Symmetry(x) for x in list_args]
-        else:
-            symmetries = []
+    for name, skel_cfg in skeleton_config.items():
+
+        skel = sio.Skeleton(nodes=[n["name"] for n in skel_cfg.nodes], name=name)
+        skel.add_edges(
+            [(e["source"]["name"], e["destination"]["name"]) for e in skel_cfg.edges]
+        )
 
-        skeletons.append(sio.Skeleton(nodes, edges, symmetries, name))
+        if skel_cfg.symmetries:
+            for n1, n2 in skel_cfg.symmetries:
+                skel.add_symmetry(n1["name"], n2["name"])
+
+        skeletons.append(skel)
 
     return skeletons

From de2d680bad595e4bce442a20b93aeb33e02974c1 Mon Sep 17 00:00:00 2001
From: gitttt-1234
Date: Thu, 5 Dec 2024 13:23:09 -0800
Subject: [PATCH 11/11] Fix sleap-io version

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 85b32965..7037a4e3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [
     "av",
     "kornia",
     "hydra-core",
-    "sleap-io==0.1.10",
+    "sleap-io>=0.1.10",
 ]
 
 dynamic = ["version", "readme"]
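Patch 10 collapses the manual node/edge/symmetry construction onto the public `sio.Skeleton` API (`add_edges`, `add_symmetry`), which also fixes serialization of the resulting skeletons after inference. A round-trip sketch with a hand-built config in the shape `get_skeleton_from_config` reads; the nested `{"name": ...}` dicts mirror the keys the function accesses, and the import path is assumed from the diff header:

```python
from omegaconf import OmegaConf

from sleap_nn.inference.utils import get_skeleton_from_config

# Minimal skeleton config in the shape the trainer saves and this function expects.
skeleton_config = OmegaConf.create(
    {
        "skeleton-0": {
            "nodes": [{"name": "head"}, {"name": "thorax"}],
            "edges": [{"source": {"name": "head"}, "destination": {"name": "thorax"}}],
            "symmetries": None,
        }
    }
)

(skeleton,) = get_skeleton_from_config(skeleton_config)
assert skeleton.name == "skeleton-0"
assert [node.name for node in skeleton.nodes] == ["head", "thorax"]
```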