Add option to reuse .bin files #116

Merged
merged 11 commits into main from divya/reuse-bin-files on Dec 5, 2024
Conversation

Collaborator

@gitttt-1234 gitttt-1234 commented Nov 2, 2024

This PR introduces an option to enable/disable the auto-deletion of `.bin` files (data chunks for training) generated by `ld.optimize`. Additionally, it provides the flexibility to load existing train and validation chunks into any training process by passing the paths of the `.bin` folders to the `ModelTrainer.train()` function.
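
A minimal usage sketch (the config path and chunk directories below are placeholders; the parameter names follow the updated `ModelTrainer.train()` signature):

```python
from omegaconf import OmegaConf

from sleap_nn.training.model_trainer import ModelTrainer

config = OmegaConf.load("training_config.yaml")  # placeholder config path

# First run: keep the generated .bin chunks instead of deleting them after training.
trainer = ModelTrainer(config)
trainer.train(delete_bin_files_after_training=False)

# Later run: reuse the existing train/val chunk directories instead of regenerating them.
trainer_reuse = ModelTrainer(config)
trainer_reuse.train(
    train_chunks_dir_path="path/to/existing/train_chunks",  # placeholder paths
    val_chunks_dir_path="path/to/existing/val_chunks",
)
```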

Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced flexibility in model training with improved handling of input directories and checkpoint paths.
    • Support for multiple learning rate schedulers, allowing for better training configuration.
  • Bug Fixes

    • Improved validation in the Predictor tests to ensure accurate predictions and model loading.
  • Documentation

    • Expanded docstrings to clarify new parameters and functionalities in the model trainer.
  • Tests

    • Added comprehensive test cases for model training and prediction processes, ensuring robustness against various configurations.
  • Chores

    • Updated dependency specification for sleap-io to allow for greater version flexibility.

Contributor

coderabbitai bot commented Nov 2, 2024

Walkthrough

The pull request introduces several modifications to the ModelTrainer class in sleap_nn/training/model_trainer.py, enhancing its flexibility in handling training and validation dataset paths. It updates method signatures to include new parameters for checkpoint paths and file handling. The test suites for the Predictor and ModelTrainer classes are also improved with additional assertions and new tests to validate model loading and inference processes. Furthermore, the pyproject.toml file is updated to allow greater flexibility in the sleap-io package versioning.

Changes

| File Path | Change Summary |
| --- | --- |
| `sleap_nn/training/model_trainer.py` | Added `self.trainer` and refactored `self.bin_files_path` initialization. Updated `_create_data_loaders` to accept `chunks_dir_path`. Enhanced `train` method with new parameters for checkpoint paths and bin file deletion. |
| `tests/inference/test_predictors.py` | Enhanced `test_topdown_predictor`, `test_single_instance_predictor`, and `test_bottomup_predictor` with additional assertions for skeleton structure validation. |
| `tests/training/test_model_trainer.py` | Updated `test_trainer` to remove a parameter and added new tests for bin file reuse and trained checkpoint loading. Refined assertions and streamlined cleanup operations. |
| `pyproject.toml` | Updated `sleap-io` dependency from `==0.1.10` to `>=0.1.10` for greater flexibility. |
| `sleap_nn/inference/utils.py` | Simplified `get_skeleton_from_config` to create `sio.Skeleton` objects more efficiently. |
| `tests/inference/test_utils.py` | Enhanced `test_get_skeleton_from_config` with detailed assertions for skeleton structure and properties. |

Possibly related PRs

Suggested reviewers

  • talmo

Poem

In the burrow deep, where changes bloom,
The ModelTrainer sheds its gloom.
With paths anew and tests so bright,
We hop along, in pure delight! 🐇✨
For every weight and file we save,
A better model, we shall pave!



@gitttt-1234 gitttt-1234 changed the base branch from main to divya/load-head-ckpt-inference November 2, 2024 00:13
@gitttt-1234 gitttt-1234 requested a review from talmo November 2, 2024 00:13
Contributor

@coderabbitai coderabbitai bot left a comment

Caution

Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments.

Actionable comments posted: 20

🧹 Outside diff range and nitpick comments (35)
tests/assets/minimal_instance_centroid/initial_config.yaml (1)

78-85: Review learning rate scheduler parameters.

The ReduceLROnPlateau configuration uses very small values:

  • threshold: 1.0e-07 and min_lr: 1.0e-08 are quite low
  • patience: 5 with cooldown: 3 means LR changes can occur frequently

Consider if these values are appropriate for your model and training dynamics. Too-frequent LR changes might destabilize training.

Recommendations:

  1. Consider increasing the threshold to ~1e-4 for more stable training
  2. You might want to increase cooldown to prevent too-frequent LR changes
  3. Document why these specific values were chosen
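
For reference, this is the equivalent PyTorch scheduler call with the more conservative threshold suggested above (a sketch with a dummy optimizer; the values are suggestions, not the committed config):

```python
import torch

# Dummy optimizer just to construct the scheduler for illustration.
optimizer = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,      # halve the LR when the monitored metric plateaus
    patience=5,      # epochs without improvement before reducing
    threshold=1e-4,  # suggested: larger than the current 1e-7
    cooldown=3,      # epochs to wait before resuming plateau checks
    min_lr=1e-8,
)
```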
tests/assets/minimal_instance_bottomup/initial_config.yaml (1)

70-71: Document .bin file deletion behavior.

Since this PR introduces the option to reuse .bin files, it would be helpful to add a comment or parameter that explicitly indicates whether .bin files will be automatically deleted after training. This would make the behavior more transparent to users.

Consider adding a parameter like delete_bin_files: true with appropriate documentation.

tests/assets/minimal_instance_centroid/training_config.yaml (1)

89-96: Consider separating LR scheduler changes into a different PR

The learning rate scheduler changes appear to be unrelated to the main objective of this PR (reusing .bin files). While the configuration looks well-tuned, it might be clearer to handle these changes in a separate PR focused on training optimizations.

The current settings are reasonable:

  • Reduction factor of 0.5
  • Patience of 5 epochs with 3 epochs cooldown
  • Minimum learning rate of 1e-8
tests/assets/minimal_instance_bottomup/training_config.yaml (2)

85-85: Document the bin_files_path usage.

This parameter enables the reuse of .bin files as described in the PR objectives. Consider:

  1. Adding a comment explaining the expected path format
  2. Documenting whether relative paths are supported
  3. Clarifying what happens when the path is invalid

Add a YAML comment above this line:

+  # Path to existing .bin files for reusing training data chunks. Set to null to generate new chunks.
   bin_files_path:

99-106: LGTM! Well-structured learning rate scheduler configuration.

The ReduceLROnPlateau configuration has sensible defaults:

  • Gradual reduction (factor: 0.5)
  • Reasonable patience (5 epochs) and cooldown (3 epochs)
  • Safe minimum learning rate (1e-8)

Consider adding comments to explain the threshold_mode options (abs vs rel) for future maintainers.

tests/assets/minimal_instance/training_config.yaml (2)

5-6: Document the new data configuration parameters.

New parameters have been added without documentation explaining their purpose and impact:

  • user_instances_only
  • chunk_size
  • min_crop_size

Please add documentation describing:

  • What each parameter controls
  • Expected values/ranges
  • Default behavior

Also applies to: 15-15


Line range hint 1-103: Consider splitting changes into focused PRs.

The current changes mix multiple concerns:

  1. .bin file reuse functionality (primary objective)
  2. Learning rate scheduler modifications
  3. Data preprocessing parameters

This makes the changes harder to review and maintain. Consider:

  1. Keeping only the .bin file reuse changes in this PR
  2. Moving LR scheduler and preprocessing changes to separate PRs
  3. Adding comprehensive documentation for all new parameters

Additionally, please add validation to ensure the bin_files_path points to a valid directory when provided.

docs/config_bottomup.yaml (2)

6-6: Add documentation for the chunk_size parameter.

Since this parameter is crucial for the new feature of managing training data chunks, please add a comment explaining its purpose, impact on training, and any constraints on its value.

 data_config:
   provider: LabelsReader
   train_labels_path: minimal_instance.pkg.slp
   val_labels_path: minimal_instance.pkg.slp
   user_instances_only: True
+  # Size of data chunks for training. Controls how many frames are processed
+  # together when generating .bin files
   chunk_size: 100

Line range hint 92-92: Document the bin_files_path parameter.

Since this PR introduces the ability to reuse .bin files, please add documentation explaining:

  1. The purpose of this parameter
  2. Expected path format
  3. Behavior when the path is empty vs. when it contains a value
   save_ckpt: true
   save_ckpt_path: min_inst_bottomup1
+  # Path to existing .bin files for reusing training data chunks
+  # Leave empty to generate new chunks, or specify path to reuse existing ones
   bin_files_path:
docs/config_topdown_centered_instance.yaml (3)

6-6: Document the chunk_size parameter's purpose and impact.

The newly added chunk_size parameter aligns with the PR's objective of managing training data chunks, but its purpose and impact should be documented for users.

Consider adding a comment explaining:

  • What this parameter controls
  • How it affects the training process
  • Any recommended values or constraints

107-114: LGTM! Well-structured learning rate scheduler configuration.

The ReduceLROnPlateau configuration is comprehensive with appropriate parameters for stable training. It pairs well with the early stopping configuration.

Consider adding inline comments explaining the purpose of each parameter, especially threshold_mode: abs vs. rel, to help users customize these values for their needs.


Line range hint 92-93: Add parameter to control .bin files deletion.

The PR aims to allow users to control the automatic deletion of .bin files, but this configuration is missing a parameter to enable/disable this feature.

Consider adding a parameter like retain_bin_files: false near the bin_files_path configuration to control this behavior.

  bin_files_path:
+ retain_bin_files: false  # Controls whether to keep .bin files after training
🧰 Tools
🪛 yamllint

[error] 8-8: trailing spaces

(trailing-spaces)


[error] 9-9: trailing spaces

(trailing-spaces)

initial_config.yaml (2)

77-77: Document the bin_files_path parameter.

This new parameter aligns with the PR objective to reuse .bin files, but needs documentation about:

  • Expected directory structure
  • File naming conventions
  • Any requirements for the binary files

Consider adding a comment above this parameter explaining its usage:

+  # Path to directory containing pre-generated .bin training chunks
+  # If provided, these chunks will be reused instead of generating new ones
   bin_files_path: null

73-73: Add warning about GPU determinism.

While setting a fixed seed is good for reproducibility, users should be aware that complete determinism requires additional settings when using GPUs.

Consider adding a warning comment:

+  # Note: For complete determinism with GPUs, additional PyTorch settings are required
   seed: 1000
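
For example, the additional settings typically involved look like this (a sketch; exact requirements vary with the PyTorch/CUDA version):

```python
import os

import torch

torch.manual_seed(1000)
torch.use_deterministic_algorithms(True)   # raise an error on nondeterministic ops
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Some CUDA versions also require this before any cuBLAS call:
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
```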
docs/config_centroid.yaml (2)

6-6: Consider documenting chunk_size parameter and its implications.

While the chunk_size parameter has been added, its relationship to the .bin files and data chunking process should be documented in comments.

 data_config:
   provider: LabelsReader
   train_labels_path: minimal_instance.pkg.slp
   val_labels_path: minimal_instance.pkg.slp
   user_instances_only: True
-  chunk_size: 100
+  # Size of data chunks used for training. Affects the size of generated .bin files
+  chunk_size: 100

Line range hint 89-89: Add configuration for .bin file retention.

The PR aims to add an option to control .bin file deletion, but there's no visible configuration parameter for this feature. Consider adding a parameter to control this behavior.

   save_ckpt: true
   save_ckpt_path: 'min_inst_centroid'
   bin_files_path:
+  # Set to true to retain .bin files after training (default: false)
+  retain_bin_files: false
   resume_ckpt_path:
🧰 Tools
🪛 yamllint

[error] 8-8: trailing spaces

(trailing-spaces)


[error] 9-9: trailing spaces

(trailing-spaces)

sleap_nn/data/streaming_datasets.py (2)

154-155: Consider adding input validation for crop dimensions.

Since the crop dimensions are critical for proper functioning and are now calculated once during initialization, consider adding validation to ensure the scaled dimensions remain valid.

Add input validation:

 # Re-crop to original crop size
+if not all(x > 0 for x in self.crop_hw):
+    raise ValueError(f"Invalid crop dimensions after scaling: {self.crop_hw}. Check input_scale: {self.input_scale}")
 self.crop_hw = [int(x * self.input_scale) for x in self.crop_hw]

154-155: Document the relationship with .bin files.

Since this PR focuses on .bin file reuse, it would be helpful to document how the crop size affects the data chunks stored in .bin files.

Add documentation:

 # Re-crop to original crop size
+# Note: These dimensions affect the data chunks stored in .bin files.
+# When reusing .bin files, ensure consistent crop_hw and input_scale values
 self.crop_hw = [int(x * self.input_scale) for x in self.crop_hw]
docs/config.md (2)

Line range hint 1-188: Documentation missing for .bin files reuse feature.

The PR introduces the ability to reuse .bin files, but the documentation only mentions the bin_files_path parameter for saving them. Please add documentation for:

  • Parameters controlling automatic deletion of .bin files
  • How to specify existing .bin files for reuse in training

Would you like me to help draft the documentation for these new parameters?

🧰 Tools
🪛 LanguageTool

[uncategorized] ~179-~179: Loose punctuation mark.
Context: ...ReduceLROnPlateau". - step_lr: - step_size`: (int) Period...

(UNLIKELY_OPENING_PUNCTUATION)


[uncategorized] ~182-~182: Loose punctuation mark.
Context: ...*: 0.1. - reduce_lr_on_plateau: - threshold: (float) Thre...

(UNLIKELY_OPENING_PUNCTUATION)

🪛 Markdownlint

170-191: Unordered list indentation (MD007, ul-indent): expected 2/4/6 spaces, actual 4/8/12 throughout this range.

184-184: Spaces inside emphasis markers (MD037, no-space-in-emphasis), reported twice.


178-188: Enhance scheduler configuration documentation with examples.

The scheduler configuration is well-documented but could benefit from practical examples. Consider adding:

  1. Example configurations for both scheduler types
  2. Common use cases for each parameter
  3. Guidelines for choosing between StepLR and ReduceLROnPlateau

Example addition:

Example configurations:
```yaml
# StepLR: Reduce learning rate by half every 10 epochs
lr_scheduler:
  scheduler: "StepLR"
  step_lr:
    step_size: 10
    gamma: 0.5

# ReduceLROnPlateau: Reduce learning rate when validation loss plateaus
lr_scheduler:
  scheduler: "ReduceLROnPlateau"
  reduce_lr_on_plateau:
    patience: 5
    factor: 0.1
    threshold: 1e-4

```

🧰 Tools
🪛 LanguageTool

[uncategorized] ~179, ~182: Loose punctuation mark (UNLIKELY_OPENING_PUNCTUATION).

🪛 Markdownlint

178-188: Unordered list indentation (MD007, ul-indent). 184-184: Spaces inside emphasis markers (MD037, no-space-in-emphasis), reported twice.

sleap_nn/inference/predictors.py (3)

`94-95`: **Add parameter validation for checkpoint paths.**

Consider adding validation to ensure that if `head_ckpt_path` is provided, `backbone_ckpt_path` must also be provided, as loading only head weights without backbone weights might lead to unexpected behavior.

```diff
 def from_model_paths(
     cls,
     model_paths: List[Text],
     backbone_ckpt_path: Optional[str] = None,
     head_ckpt_path: Optional[str] = None,
 ):
+    # Validate the combination of checkpoint paths.
+    if head_ckpt_path is not None and backbone_ckpt_path is None:
+        raise ValueError("backbone_ckpt_path must be provided when head_ckpt_path is provided")
```

1622-1626: Improve parameter documentation in docstring.

The docstring for the new parameters could be more descriptive and include examples:

-        backbone_ckpt_path: (str) To run inference on any `.ckpt` other than `best.ckpt`
-                from the `model_paths` dir, the path to the `.ckpt` file should be passed here.
-        head_ckpt_path: (str) Path to `.ckpt` file if a different set of head layer weights
-                are to be used. If `None`, the `best.ckpt` from `model_paths` dir is used (or the ckpt
-                from `backbone_ckpt_path` if provided.)
+        backbone_ckpt_path: Optional path to a checkpoint file containing backbone weights.
+                If provided, these weights will be used instead of the backbone weights from
+                `best.ckpt`. This allows mixing weights from different checkpoints.
+                Example: "/path/to/backbone_v2.ckpt"
+        head_ckpt_path: Optional path to a checkpoint file containing head layer weights.
+                Can only be used if backbone_ckpt_path is also provided. This enables using
+                different head weights while maintaining the same backbone.
+                Example: "/path/to/head_specialized.ckpt"

596-597: Add error handling for checkpoint loading.

Consider adding try-except blocks when loading checkpoints to handle potential errors gracefully:

  • File not found
  • Invalid checkpoint format
  • Incompatible state dict structure
+def _safe_load_checkpoint(path: str) -> dict:
+    """Safely load a checkpoint file with error handling."""
+    try:
+        return torch.load(path)
+    except FileNotFoundError:
+        raise ValueError(f"Checkpoint file not found: {path}")
+    except Exception as e:
+        raise ValueError(f"Error loading checkpoint {path}: {str(e)}")

Also applies to: 606-607, 610-611

tests/training/test_model_trainer.py (2)

216-218: Simplify the OmegaConf.update call formatting

The OmegaConf.update call can be condensed into a single line for readability, as it doesn't exceed typical line length limits.

Apply this diff to improve formatting:

-OmegaConf.update(
-    config_early_stopping, "trainer_config.lr_scheduler.scheduler", None
-)
+OmegaConf.update(config_early_stopping, "trainer_config.lr_scheduler.scheduler", None)

330-333: Clarify intentional use of invalid scheduler name

The scheduler name "ReduceLR" is likely intended to be invalid to test exception handling. Consider adding a comment to clarify this for future maintainability.

Apply this diff to add a clarifying comment:

+ # Intentionally using an invalid scheduler name to test exception handling
 OmegaConf.update(config, "trainer_config.lr_scheduler.scheduler", "ReduceLR")
 with pytest.raises(ValueError):
     trainer = ModelTrainer(config)
tests/inference/test_predictors.py (3)

3-3: Remove unused import Text

The Text class from the typing module is imported but not used in the code. Removing this unused import will clean up the code.

Apply this diff to remove the unused import:

-from typing import Text
🧰 Tools
🪛 Ruff

3-3: typing.Text imported but unused

Remove unused import: typing.Text

(F401)


692-693: Remove unnecessary debug print statements

The print statements at lines 692-693 and 701-702 seem to be leftover from debugging. Removing these will keep test output clean and focus on relevant information.

Apply this diff to remove the print statements:

-    print(f"head_layer_ckpt: {head_layer_ckpt}")
-    print(model_weights)

Also applies to: 701-702



sleap_nn/training/model_trainer.py (7)

376-383: Add docstrings for new parameters in train method

The train method now accepts additional parameters that are not documented. Providing docstrings for these parameters enhances code readability and user understanding.

Apply this diff to add the parameter documentation:

 def train(
     self,
     backbone_trained_ckpts_path: Optional[str] = None,
     head_trained_ckpts_path: Optional[str] = None,
     delete_bin_files_after_training: bool = True,
     train_chunks_dir_path: Optional[str] = None,
     val_chunks_dir_path: Optional[str] = None,
 ):
+    """
+    Initiate the training process.
+
+    Args:
+        backbone_trained_ckpts_path: Path to a trained backbone checkpoint for model initialization.
+        head_trained_ckpts_path: Path to a trained head checkpoint for model initialization.
+        delete_bin_files_after_training: Whether to delete `.bin` files after training. Defaults to True.
+        train_chunks_dir_path: Path to existing training chunks directory. If None, new chunks will be created.
+        val_chunks_dir_path: Path to existing validation chunks directory. If None, new chunks will be created.
+    """
     logger = []

354-358: Add parameter descriptions to the _initialize_model method docstring

The _initialize_model method accepts new parameters that aren't documented. Adding descriptions will improve code clarity.

Apply this diff to add the parameter documentation:

 def _initialize_model(
     self,
     backbone_trained_ckpts_path: Optional[str] = None,
     head_trained_ckpts_path: Optional[str] = None,
 ):
+    """
+    Initialize the model with optional pretrained checkpoints.
+
+    Args:
+        backbone_trained_ckpts_path: Path to a pretrained backbone checkpoint.
+        head_trained_ckpts_path: Path to a pretrained head checkpoint.
+    """
     models = {

528-529: Correct typos in docstring parameter descriptions

In the docstring, replace "ckpts" with "checkpoints" for clarity.

Apply this diff to correct the typos:

             model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-            backbone_trained_ckpts_path: Path to trained ckpts for backbone.
-            head_trained_ckpts_path: Path to trained ckpts for head layer.
+            backbone_trained_ckpts_path: Path to trained checkpoints for backbone.
+            head_trained_ckpts_path: Path to trained checkpoints for head layer.
         """

743-744: Correct typos in docstring parameter descriptions

In the docstring of SingleInstanceModel, replace "ckpts" with "checkpoints" for clarity.

Apply this diff to correct the typos:

             model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-            backbone_trained_ckpts_path: Path to trained ckpts for backbone.
-            head_trained_ckpts_path: Path to trained ckpts for head layer.
+            backbone_trained_ckpts_path: Path to trained checkpoints for backbone.
+            head_trained_ckpts_path: Path to trained checkpoints for head layer.
         """

824-825: Correct typos in docstring parameter descriptions

In the docstring of TopDownCenteredInstanceModel, replace "ckpts" with "checkpoints" for clarity.

Apply this diff to correct the typos:

             model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-            backbone_trained_ckpts_path: Path to trained ckpts for backbone.
-            head_trained_ckpts_path: Path to trained ckpts for head layer.
+            backbone_trained_ckpts_path: Path to trained checkpoints for backbone.
+            head_trained_ckpts_path: Path to trained checkpoints for head layer.
         """

905-906: Correct typos in docstring parameter descriptions

In the docstring of CentroidModel, replace "ckpts" with "checkpoints" for clarity.

Apply this diff to correct the typos:

             model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-            backbone_trained_ckpts_path: Path to trained ckpts for backbone.
-            head_trained_ckpts_path: Path to trained ckpts for head layer.
+            backbone_trained_ckpts_path: Path to trained checkpoints for backbone.
+            head_trained_ckpts_path: Path to trained checkpoints for head layer.
         """

986-987: Correct typos in docstring parameter descriptions

In the docstring of BottomUpModel, replace "ckpts" with "checkpoints" for clarity.

Apply this diff to correct the typos:

             model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-            backbone_trained_ckpts_path: Path to trained ckpts for backbone.
-            head_trained_ckpts_path: Path to trained ckpts for head layer.
+            backbone_trained_ckpts_path: Path to trained checkpoints for backbone.
+            head_trained_ckpts_path: Path to trained checkpoints for head layer.
         """
🛑 Comments failed to post (20)
tests/assets/minimal_instance/training_config.yaml (1)

82-82: 💡 Codebase verification

Based on the search results, I can now generate a final response as I found the implementation details in sleap_nn/training/model_trainer.py:

Clarify bin_files_path behavior in configuration documentation

The implementation shows that bin_files_path is used as follows:

  • When null (default), it uses the input directory path
  • When set, creates a timestamped subdirectory under the specified path ({bin_files_path}/chunks_{timestamp})
  • The path can be absolute or relative, and proper directory creation with permissions is handled

Please add these details to the configuration documentation to help users understand:

  • The default behavior (null = use input directory)
  • The automatic timestamped subdirectory creation
  • That both absolute and relative paths are supported
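
A rough Python sketch of the behavior described above (illustrative only; the function name, timestamp format, and variables are not the actual implementation):

```python
from datetime import datetime
from pathlib import Path


def resolve_chunks_dir(bin_files_path, input_dir):
    """Illustrative sketch of where .bin chunks end up based on `bin_files_path`."""
    if bin_files_path is None:
        # Default: fall back to the input directory path.
        return Path(input_dir)
    # When set: create a timestamped subdirectory under the given path.
    chunks_dir = Path(bin_files_path) / f"chunks_{datetime.now():%Y%m%d_%H%M%S}"
    chunks_dir.mkdir(parents=True, exist_ok=True)
    return chunks_dir
```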
🔗 Analysis chain

Enhance bin_files_path configuration.

The new bin_files_path parameter is crucial for the .bin file reuse feature, but needs more clarity:

  1. Document the expected path format (absolute/relative)
  2. Clarify what happens when the value is null
  3. Consider adding path validation

Let's check if there's any documentation or validation for this parameter:


Let me try another search strategy to find any documentation or validation related to bin_files_path:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for documentation or validation related to bin_files_path
rg -A 5 "bin_files_path"

Length of output: 195


Script:

#!/bin/bash
# Search for bin_files_path in Python files and YAML files
fd -e py -e yaml --exec rg -l "bin_files_path" {}

# Get content of files containing bin_files_path
fd -e py -e yaml --exec rg -B 5 -A 5 "bin_files_path" {}

Length of output: 12827

initial_config.yaml (3)

1-104: ⚠️ Potential issue

Fix newline characters in the file.

The file uses incorrect newline characters which can cause issues across different operating systems.

Convert all line endings to Unix-style (LF) newlines. You can do this using:

  • Git: git config --global core.autocrlf input
  • Editor: Most modern editors can handle this automatically
  • Command line: dos2unix initial_config.yaml
🧰 Tools
🪛 yamllint

[error] 1-1: wrong new line character: expected \n

(new-lines)


84-84: ⚠️ Potential issue

Remove API key field from configuration.

Storing API keys in configuration files, even if empty, is a security risk as it might encourage users to commit sensitive information.

Consider:

  1. Removing the API key field entirely and using environment variables
  2. Adding a warning comment about not committing API keys
-    api_key: ''
+    # Configure API key using WANDB_API_KEY environment variable
📝 Committable suggestion


    # Configure API key using WANDB_API_KEY environment variable

3-4: ⚠️ Potential issue

Replace hardcoded Windows paths with platform-independent paths.

The current paths have several issues:

  1. Windows-specific backslashes make it non-portable
  2. Absolute paths are not suitable for version control
  3. Using the same file for both training and validation could lead to overfitting

Consider using:

  1. Relative paths with forward slashes
  2. Different datasets for training and validation
-  train_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
-  val_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
+  train_labels_path: data/train/minimal_instance.pkg.slp
+  val_labels_path: data/val/minimal_instance.pkg.slp
📝 Committable suggestion


  train_labels_path: data/train/minimal_instance.pkg.slp
  val_labels_path: data/val/minimal_instance.pkg.slp
docs/config.md (1)

178-188: ⚠️ Potential issue

Fix markdown formatting issues.

There are several formatting inconsistencies in the scheduler configuration section:

  1. List indentation is inconsistent
  2. Unnecessary spaces in emphasis markers

Apply these formatting fixes:

-        - `scheduler`: (str) Name of the scheduler to use. Valid schedulers: `"StepLR"`, `"ReduceLROnPlateau"`.
-        - `step_lr`:
-            - `step_size`: (int) Period of learning rate decay. If `step_size`=10, then every 10 epochs, learning rate will be reduced by a factor of `gamma`.
+    - `scheduler`: (str) Name of the scheduler to use. Valid schedulers: `"StepLR"`, `"ReduceLROnPlateau"`.
+    - `step_lr`:
+      - `step_size`: (int) Period of learning rate decay. If `step_size`=10, then every 10 epochs, learning rate will be reduced by a factor of `gamma`.

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 LanguageTool

[uncategorized] ~179-~179: Loose punctuation mark.
Context: ...ReduceLROnPlateau". - step_lr: - step_size`: (int) Period...

(UNLIKELY_OPENING_PUNCTUATION)


[uncategorized] ~182-~182: Loose punctuation mark.
Context: ...*: 0.1. - reduce_lr_on_plateau: - threshold: (float) Thre...

(UNLIKELY_OPENING_PUNCTUATION)

🪛 Markdownlint

178-188: Unordered list indentation (MD007, ul-indent): expected 4/6 spaces, actual 8/12.

184-184: Spaces inside emphasis markers (MD037, no-space-in-emphasis), reported twice.

sleap_nn/inference/predictors.py (3)

594-617: 🛠️ Refactor suggestion

Refactor duplicated weight loading logic.

The weight loading logic is duplicated between the centroid and confmap models. Consider extracting this into a helper function to improve maintainability and reduce code duplication.

+def _load_model_weights(model, backbone_ckpt_path: Optional[str], head_ckpt_path: Optional[str]) -> None:
+    """Load backbone and head weights into the model.
+    
+    Args:
+        model: The model to load weights into
+        backbone_ckpt_path: Path to backbone checkpoint
+        head_ckpt_path: Path to head checkpoint
+    """
+    if backbone_ckpt_path is not None and head_ckpt_path is not None:
+        print(f"Loading backbone weights from `{backbone_ckpt_path}` ...")
+        ckpt = torch.load(backbone_ckpt_path)
+        ckpt["state_dict"] = {
+            k: v for k, v in ckpt["state_dict"].items()
+            if ".backbone" in k
+        }
+        model.load_state_dict(ckpt["state_dict"], strict=False)
+    elif backbone_ckpt_path is not None:
+        print(f"Loading weights from `{backbone_ckpt_path}` ...")
+        ckpt = torch.load(backbone_ckpt_path)
+        model.load_state_dict(ckpt["state_dict"], strict=False)
+
+    if head_ckpt_path is not None:
+        print(f"Loading head weights from `{head_ckpt_path}` ...")
+        ckpt = torch.load(head_ckpt_path)
+        ckpt["state_dict"] = {
+            k: v for k, v in ckpt["state_dict"].items()
+            if ".head_layers" in k
+        }
+        model.load_state_dict(ckpt["state_dict"], strict=False)

Then use this helper function:

-        if backbone_ckpt_path is not None and head_ckpt_path is not None:
-            print(f"Loading backbone weights from `{backbone_ckpt_path}` ...")
-            ckpt = torch.load(backbone_ckpt_path)
-            ckpt["state_dict"] = {
-                k: ckpt["state_dict"][k]
-                for k in ckpt["state_dict"].keys()
-                if ".backbone" in k
-            }
-            centroid_model.load_state_dict(ckpt["state_dict"], strict=False)
-
-        elif backbone_ckpt_path is not None:
-            print(f"Loading weights from `{backbone_ckpt_path}` ...")
-            ckpt = torch.load(backbone_ckpt_path)
-            centroid_model.load_state_dict(ckpt["state_dict"], strict=False)
-
-        if head_ckpt_path is not None:
-            print(f"Loading head weights from `{head_ckpt_path}` ...")
-            ckpt = torch.load(head_ckpt_path)
-            ckpt["state_dict"] = {
-                k: ckpt["state_dict"][k]
-                for k in ckpt["state_dict"].keys()
-                if ".head_layers" in k
-            }
-            centroid_model.load_state_dict(ckpt["state_dict"], strict=False)
+        _load_model_weights(centroid_model, backbone_ckpt_path, head_ckpt_path)

Also applies to: 636-659

🧰 Tools
🪛 Ruff

599-599: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


614-614: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


1349-1372: 🛠️ Refactor suggestion

Apply consistent improvements across all predictor classes.

The same improvements suggested for other predictor classes should be applied here:

  1. Simplify dictionary key checks
  2. Use the shared weight loading helper function

This ensures consistency across the codebase and reduces maintenance overhead.

🧰 Tools
🪛 Ruff

1354-1354: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


1369-1369: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


988-1011: 🛠️ Refactor suggestion

Simplify dictionary key checks and apply consistent weight loading pattern.

  1. Simplify the dictionary key checks by removing unnecessary .keys() calls
  2. Use the same helper function suggested for TopDownPredictor to handle weight loading
-                k: ckpt["state_dict"][k]
-                for k in ckpt["state_dict"].keys()
+                k: v for k, v in ckpt["state_dict"].items()
                 if ".backbone" in k

Apply the same _load_model_weights helper function here to maintain consistency and reduce code duplication.

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff

993-993: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


1008-1008: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

tests/training/test_model_trainer.py (2)

338-376: 🛠️ Refactor suggestion

Avoid hardcoded indices when accessing model parameters

Directly accessing model parameters with hardcoded indices like [0, 0, :] can lead to maintenance issues if the model architecture changes. Consider using parameter names or iterating over the parameters for a more robust approach.
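
A sketch of the more robust pattern (the helper name and tolerance are illustrative):

```python
import torch


def assert_params_close(model_a: torch.nn.Module, model_b: torch.nn.Module, atol: float = 1e-6):
    """Compare corresponding parameters by name rather than by hardcoded indices like `[0, 0, :]`."""
    params_b = dict(model_b.named_parameters())
    for name, param_a in model_a.named_parameters():
        torch.testing.assert_close(param_a, params_b[name], atol=atol, rtol=0.0)
```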


379-424: ⚠️ Potential issue

Add assertions to verify reuse of .bin files

The test test_reuse_bin_files sets up for reusing .bin files but lacks assertions to confirm that the files are indeed reused. Adding assertions will strengthen the test by ensuring the files are not regenerated.

Consider adding these assertions:

assert os.path.exists(trainer1.train_input_dir)
assert os.path.exists(trainer1.val_input_dir)
assert os.path.exists(trainer2.train_input_dir)
assert os.path.exists(trainer2.val_input_dir)
assert trainer1.train_input_dir == trainer2.train_input_dir
assert trainer1.val_input_dir == trainer2.val_input_dir
tests/inference/test_predictors.py (5)

188-218: 🛠️ Refactor suggestion

Refactor duplicate test code into a helper function

The code block from lines 188 to 218 appears to be duplicated in other test functions (e.g., lines 447-496 and 670-704). Refactoring this repeated code into a helper function would enhance maintainability and reduce code duplication.
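
One possible shape for such a shared helper (illustrative only: the name, arguments, and the `.head_layers` key filter are assumptions based on the assertions these tests already make):

```python
import numpy as np
import torch


def assert_head_weights_loaded(model: torch.nn.Module, head_ckpt_path: str, atol: float = 1e-6):
    """Hypothetical shared helper: check that head-layer weights in the loaded
    model match those stored in the given checkpoint."""
    ckpt = torch.load(head_ckpt_path, map_location="cpu")
    head_state = {k: v for k, v in ckpt["state_dict"].items() if ".head_layers" in k}
    model_state = model.state_dict()
    for name, expected in head_state.items():
        np.testing.assert_allclose(
            model_state[name].detach().cpu().numpy(), expected.cpu().numpy(), atol=atol
        )
```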


447-496: 🛠️ Refactor suggestion

Refactor duplicate test code into a helper function

The code from lines 447 to 496 is similar to code in other test sections. Refactoring this repeated logic into a shared helper function can improve code maintainability and readability.


670-704: 🛠️ Refactor suggestion

Refactor duplicate test code into a helper function

The code block from lines 670 to 704 is repeated in other tests. Refactoring into a helper function will reduce redundancy and simplify future maintenance.


491-491: 🛠️ Refactor suggestion

Use np.testing.assert_allclose for numerical assertions

For consistency and better error handling, use np.testing.assert_allclose instead of manual assertions when comparing numerical arrays.

Apply this diff:

-assert np.all(np.abs(head_layer_ckpt - model_weights) < 1e-6)
+np.testing.assert_allclose(head_layer_ckpt, model_weights, atol=1e-6)
📝 Committable suggestion


        np.testing.assert_allclose(head_layer_ckpt, model_weights, atol=1e-6)

218-218: 🛠️ Refactor suggestion

Use np.testing.assert_allclose for numerical assertions

Instead of manually checking numerical closeness with assert np.all(np.abs(...)) < tolerance, consider using np.testing.assert_allclose for better readability and error messages.

Apply this diff to improve the assertion:

-assert np.all(np.abs(head_layer_ckpt - model_weights) < 1e-6)
+np.testing.assert_allclose(head_layer_ckpt, model_weights, atol=1e-6)
📝 Committable suggestion


    np.testing.assert_allclose(head_layer_ckpt, model_weights, atol=1e-6)
sleap_nn/training/model_trainer.py (5)

380-380: ⚠️ Potential issue

Clarify default behavior of delete_bin_files_after_training parameter

The parameter delete_bin_files_after_training defaults to True, meaning .bin files will be deleted after training. This could be unexpected for users who want to reuse these files. Consider setting the default to False or clearly documenting this behavior.


240-240: ⚠️ Potential issue

Use raise ... from e when re-raising exceptions to preserve traceback

When re-raising an exception, include the original exception using from e to maintain the traceback.

Apply this diff to modify the raise statement:

-            raise Exception(f"Error while creating the `.bin` files... {e}")
+            raise Exception("Error while creating the `.bin` files...") from e
📝 Committable suggestion


                raise Exception("Error while creating the `.bin` files...") from e
🧰 Tools
🪛 Ruff

240-240: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


224-226: ⚠️ Potential issue

Use raise ... from e when re-raising exceptions to preserve traceback

When raising a new exception within an except block, it's best practice to use raise ... from e to maintain the original traceback.

Apply this diff to modify the raise statement:

-            raise OSError(
-                f"Cannot create a new folder in {self.bin_files_path}. Check the permissions to the given Checkpoint directory. \n {e}"
-            )
+            raise OSError(
+                f"Cannot create a new folder in {self.bin_files_path}. Check the permissions to the given Checkpoint directory."
+            ) from e
📝 Committable suggestion


                        raise OSError(
                            f"Cannot create a new folder in {self.bin_files_path}. Check the permissions to the given Checkpoint directory."
                        ) from e
🧰 Tools
🪛 Ruff

224-226: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


613-613: ⚠️ Potential issue

Simplify dictionary key iteration by removing .keys()

When iterating over a dictionary's keys, you can omit .keys() for simplicity.

Apply this diff to simplify the comprehension:

     ckpt["state_dict"] = {
         k: ckpt["state_dict"][k]
-        for k in ckpt["state_dict"].keys()
+        for k in ckpt["state_dict"]
         if ".backbone" in k
     }

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff

613-613: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


624-624: ⚠️ Potential issue

Simplify dictionary key iteration by removing .keys()

When iterating over a dictionary's keys, you can omit .keys() for simplicity.

Apply this diff to simplify the comprehension:

         ckpt["state_dict"] = {
             k: ckpt["state_dict"][k]
-            for k in ckpt["state_dict"].keys()
+            for k in ckpt["state_dict"]
             if ".head_layers" in k
         }

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff

624-624: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


codecov bot commented Nov 2, 2024

Codecov Report

Attention: Patch coverage is 86.00000% with 7 lines in your changes missing coverage. Please review.

Project coverage is 97.51%. Comparing base (f093ce2) to head (de2d680).
Report is 25 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| sleap_nn/training/model_trainer.py | 88.37% | 5 Missing ⚠️ |
| sleap_nn/inference/utils.py | 71.42% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #116      +/-   ##
==========================================
+ Coverage   96.64%   97.51%   +0.86%     
==========================================
  Files          23       38      +15     
  Lines        1818     3777    +1959     
==========================================
+ Hits         1757     3683    +1926     
- Misses         61       94      +33     

☔ View full report in Codecov by Sentry.

@gitttt-1234 gitttt-1234 force-pushed the divya/load-head-ckpt-inference branch 2 times, most recently from 9a20582 to b843410 Compare December 5, 2024 20:59
@gitttt-1234 gitttt-1234 force-pushed the divya/reuse-bin-files branch from 25da774 to fc1dee5 Compare December 5, 2024 21:16
@gitttt-1234 gitttt-1234 force-pushed the divya/reuse-bin-files branch from fc1dee5 to 4432b42 Compare December 5, 2024 21:19
@gitttt-1234 gitttt-1234 changed the base branch from divya/load-head-ckpt-inference to main December 5, 2024 21:19
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (3)
sleap_nn/inference/utils.py (1)

20-30: LGTM! Well-optimized skeleton creation logic.

The refactoring improves code readability and efficiency by:

  1. Using direct skeleton construction
  2. Employing list comprehensions for edge creation
  3. Simplifying symmetry handling

One suggestion to make the code even more robust:

Consider adding validation for the input config structure:

 def get_skeleton_from_config(skeleton_config: OmegaConf):
+    if not skeleton_config:
+        raise ValueError("Skeleton configuration cannot be empty")
+
     skeletons = []
     for name, skel_cfg in skeleton_config.items():
+        if not hasattr(skel_cfg, "nodes") or not skel_cfg.nodes:
+            raise ValueError(f"Skeleton '{name}' must have nodes defined")
+
         skel = sio.Skeleton(nodes=[n["name"] for n in skel_cfg.nodes], name=name)
sleap_nn/training/model_trainer.py (2)

128-141: Consider using a ternary operator for better readability

The symmetry extraction logic can be simplified using a ternary operator.

-            if skl.symmetries:
-                symm = [list(s.nodes) for s in skl.symmetries]
-            else:
-                symm = None
+            symm = [list(s.nodes) for s in skl.symmetries] if skl.symmetries else None
🧰 Tools
🪛 Ruff (0.8.0)

131-134: Use ternary operator symm = [list(s.nodes) for s in skl.symmetries] if skl.symmetries else None instead of if-else-block

Replace if-else-block with symm = [list(s.nodes) for s in skl.symmetries] if skl.symmetries else None

(SIM108)


Line range hint 187-222: Good practice: Capturing subprocess output

The implementation properly captures both stdout and stderr from the subprocess, which is essential for debugging and logging. Consider adding logging statements for better traceability.

Consider using a logging framework instead of print statements:

import logging

logging.info("Standard Output:\n %s", stdout)
logging.error("Standard Error:\n %s", stderr)
🧰 Tools
🪛 Ruff (0.8.0)

238-240: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


254-254: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 5fc0c8d and de2d680.

📒 Files selected for processing (6)
  • pyproject.toml (1 hunks)
  • sleap_nn/inference/utils.py (1 hunks)
  • sleap_nn/training/model_trainer.py (15 hunks)
  • tests/inference/test_predictors.py (5 hunks)
  • tests/inference/test_utils.py (1 hunks)
  • tests/training/test_model_trainer.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.8.0)
sleap_nn/training/model_trainer.py

131-134: Use ternary operator symm = [list(s.nodes) for s in skl.symmetries] if skl.symmetries else None instead of if-else-block

Replace if-else-block with symm = [list(s.nodes) for s in skl.symmetries] if skl.symmetries else None

(SIM108)


238-240: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


254-254: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

🔇 Additional comments (7)
tests/inference/test_utils.py (1)

13-21: LGTM! Enhanced test coverage with granular assertions.

The changes improve the test by breaking down the skeleton comparison into specific attributes (nodes, edges, symmetries), making it easier to identify issues when tests fail.

pyproject.toml (1)

31-31: Verify compatibility with newer sleap-io versions.

While allowing newer versions provides flexibility, we should ensure that:

  1. The feature to reuse .bin files works consistently across supported versions
  2. There are no breaking changes in newer versions that could affect functionality
✅ Verification successful

Based on the analysis of the codebase, I can now generate the final response:

Version change to >=0.1.10 is safe for .bin file functionality

The sleap-io dependency is used primarily for loading and processing .slp files to generate .bin files during training. The codebase's interaction with sleap-io is limited to basic operations:

  • Loading .slp files using sio.load_slp()
  • Accessing basic properties like videos, instances, and labels

These are core functionalities that are unlikely to have breaking changes in minor version updates. The .bin file generation process is handled internally by the codebase and doesn't rely on complex sleap-io features.
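
A minimal sketch of the kind of sleap-io usage involved (the `.slp` path is a placeholder taken from the test assets):

```python
import sleap_io as sio

labels = sio.load_slp("minimal_instance.pkg.slp")  # placeholder .slp file

print(len(labels.videos))  # basic property access: videos
print(len(labels))         # number of labeled frames
for lf in labels:
    for instance in lf.instances:
        pass  # instances/points are what get packed into .bin chunks
```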

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check latest versions and their release dates
curl -s https://pypi.org/pypi/sleap-io/json | jq -r '.releases | keys[] | select(startswith("0.1.")) | . + " " + (.releases[.][0].upload_time)' | sort -V

# Check for any compatibility warnings in the changelog
gh api repos/talmolab/sleap-io/contents/CHANGELOG.md | jq -r '.content' | base64 -d | grep -A 5 "Breaking Changes"

Length of output: 461


Script:

#!/bin/bash
# Let's try a different approach to check the repository for usage patterns and version information

# Search for sleap-io imports and usage
rg -l "import.*sleap.*io" || true
rg "from.*sleap.*io.*import" -A 2 || true

# Look for any version-specific code or compatibility handling
rg "sleap.*io.*version" -i || true

# Check if there are any tests related to .bin files
rg "\.bin" -l || true

# Look for documentation or comments about sleap-io version requirements
rg "sleap.?io.*0\.1\." || true

Length of output: 5785


Script:

#!/bin/bash
# Let's examine the .bin file handling code and documentation to understand version dependencies

# Check documentation about .bin files
cat docs/config.md || true

# Examine the bin file handling implementation
cat sleap_nn/training/get_bin_files.py || true

# Look for any version-specific handling in model trainer
cat sleap_nn/training/model_trainer.py | grep -A 5 -B 5 "\.bin" || true

# Check tests related to bin files
cat tests/training/test_model_trainer.py | grep -A 5 -B 5 "\.bin" || true

Length of output: 37192

sleap_nn/training/model_trainer.py (2)

Line range hint 391-522: LGTM: Well-implemented bin file management

The changes to the train method are well-structured:

  1. Clear documentation of new parameters
  2. Optional deletion of bin files
  3. Proper existence checks before cleanup
  4. Graceful error handling with ignore_errors=True

Line range hint 525-725: LGTM: Flexible model initialization

The TrainingModel class changes provide good flexibility:

  1. Separate checkpoint loading for backbone and head
  2. Clear validation of scheduler types
  3. Proper error messages for invalid configurations
tests/training/test_model_trainer.py (2)

384-429: LGTM: Comprehensive test for bin file reuse

The test thoroughly validates the bin file reuse functionality:

  1. Proper setup of test configuration
  2. Validation of training with and without reuse
  3. Appropriate cleanup of test artifacts

Line range hint 38-44: LGTM: Thorough skeleton structure validation

The added assertions provide comprehensive validation of the skeleton structure:

  1. Node name comparison
  2. Edge count verification
  3. Edge structure validation
  4. Symmetry verification
tests/inference/test_predictors.py (1)

432-438: LGTM: Consistent skeleton validation across tests

The skeleton validation is consistently implemented across different test functions:

  1. test_single_instance_predictor
  2. test_bottomup_predictor
  3. Matches the implementation in test_topdown_predictor

Also applies to: 490-496, 672-678

sleap_nn/training/model_trainer.py (review comment resolved)
@gitttt-1234 gitttt-1234 merged commit 1529209 into main Dec 5, 2024
6 of 7 checks passed
@gitttt-1234 gitttt-1234 deleted the divya/reuse-bin-files branch December 5, 2024 22:13
@coderabbitai coderabbitai bot mentioned this pull request Dec 9, 2024