Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to load trained weights for Head layers #114

Merged
merged 1 commit into from
Dec 5, 2024

Conversation

gitttt-1234
Copy link
Collaborator

@gitttt-1234 gitttt-1234 commented Oct 31, 2024

This PR adds an option to load trained weights for head layers right before training (instead of random initialization) for transfer learning. The .ckpt paths can be passed to ModelTrainer.train(backbone_trained_ckpts_path=<your_path>, head_trained_ckpts_path=<your_path>).

Summary by CodeRabbit

  • New Features

    • Enhanced model training with separate weight handling for backbone and head components.
    • Updated configuration settings for improved data processing and model training.
  • Bug Fixes

    • Improved error handling for loading model weights.
  • Tests

    • Introduced a new test for verifying loading of trained weights for backbone and head layers.
    • Updated existing tests to align with new checkpoint loading logic.

Copy link
Contributor

coderabbitai bot commented Oct 31, 2024

Walkthrough

The changes in this pull request involve modifications to the ModelTrainer and TrainingModel classes in the sleap_nn/training/model_trainer.py file. The updates primarily focus on adjusting method signatures to accept separate parameters for backbone and head model weights, enhancing the weight initialization logic. Additionally, the configuration settings in tests/fixtures/datasets.py have been restructured, and a new test has been added to tests/training/test_model_trainer.py to verify the loading of these weights. Overall, the changes aim to improve the model's weight handling and testing framework.

Changes

File Change Summary
sleap_nn/training/model_trainer.py Updated method signatures in ModelTrainer and TrainingModel to accept separate paths for backbone and head weights. Enhanced weight loading logic and error handling.
tests/fixtures/datasets.py Modified configuration settings: updated provider, added user_instances_only and chunk_size, introduced min_crop_size, and restructured lr_scheduler.
tests/training/test_model_trainer.py Added test_trainer_load_trained_ckpts to verify loading of backbone and head weights. Adjusted existing test logic and removed redundant calls.

Possibly related PRs

  • Add function to load trained weights for backbone model #95: This PR adds functionality to load trained weights specifically for the backbone of a model, which is directly related to the changes in the main PR that modify the _initialize_model and train methods to handle separate paths for backbone and head weights.
  • Refactor model pipeline #51: This PR refactors the model pipeline, which includes changes to the model initialization process that may relate to the modifications in the main PR regarding how models are initialized and weights are loaded.
  • LitData Refactor PR1: Get individual functions for data pipelines #90: This PR introduces new parameters in the data_config section, including user_instances_only and chunk_size, which are relevant to the overall training configuration and may interact with the changes in the main PR regarding model training.
  • Add StepLR scheduler #109: This PR adds a new StepLR scheduler, which enhances the training configuration and may be relevant to the changes in the main PR that involve model training and configuration adjustments.

Suggested reviewers

  • talmo

🐇 In the world of code, we hop and play,
New paths for weights, come what may!
With models trained, and tests so bright,
We load them up, ready for flight!
A leap in logic, a jump in fun,
Together we code, till the day is done! 🌟


Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@gitttt-1234 gitttt-1234 changed the base branch from main to divya/add-config-wandb October 31, 2024 13:49
Copy link

codecov bot commented Oct 31, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.37%. Comparing base (f093ce2) to head (751559a).
Report is 23 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #114      +/-   ##
==========================================
+ Coverage   96.64%   97.37%   +0.73%     
==========================================
  Files          23       38      +15     
  Lines        1818     3702    +1884     
==========================================
+ Hits         1757     3605    +1848     
- Misses         61       97      +36     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments.

Actionable comments posted: 15

🧹 Outside diff range and nitpick comments (19)
tests/assets/minimal_instance_centroid/initial_config.yaml (2)

64-64: Document the bin_files_path parameter.

Given that this PR introduces functionality for loading pre-trained weights:

  1. Consider adding a comment explaining the expected format and purpose of bin_files_path
  2. Should there be validation to ensure the path exists when provided?
  3. How does this relate to the backbone_trained_ckpts_path and head_trained_ckpts_path mentioned in the PR description?

78-85: LGTM! Consider adding parameter descriptions.

The ReduceLROnPlateau configuration looks well-structured with reasonable default values. To improve maintainability, consider adding inline comments explaining each parameter's purpose and impact on training.

Example documentation:

reduce_lr_on_plateau:
  # Minimum change in loss to be considered an improvement
  threshold: 1.0e-07
  # How to measure the threshold (abs or rel)
  threshold_mode: abs
  # Number of epochs to wait after a reduction before resuming normal operation
  cooldown: 3
  # Number of epochs with no improvement after which LR will be reduced
  patience: 5
  # Factor by which the learning rate will be reduced
  factor: 0.5
  # Lower bound on the learning rate
  min_lr: 1.0e-08
tests/assets/minimal_instance/initial_config.yaml (1)

13-14: Document the significance of crop_hw values.

The hardcoded values [160, 160] for crop_hw should be documented to explain their significance and any constraints.

    crop_hw:
-      - 160
-      - 160
+      - 160  # Height of the crop window
+      - 160  # Width of the crop window
tests/assets/minimal_instance_bottomup/initial_config.yaml (2)

70-71: Document the purpose and format of bin_files_path.

The PR mentions loading .ckpt format weights, but bin_files_path suggests binary files. Please clarify:

  1. The expected file format
  2. Whether this is related to the pre-trained weights loading feature
  3. Add documentation for this parameter's usage

84-91: LGTM! Consider adding parameter descriptions.

The learning rate scheduler configuration is well-structured with appropriate values for ReduceLROnPlateau. Consider adding inline comments or documentation explaining the purpose of each parameter for future maintainability.

lr_scheduler:
  scheduler: ReduceLROnPlateau
  reduce_lr_on_plateau:
    threshold: 1.0e-07      # Minimum change in loss to be considered an improvement
    threshold_mode: abs     # Use absolute change in loss
    cooldown: 3            # Epochs to wait before resuming normal operation
    patience: 5            # Epochs to wait before reducing LR
    factor: 0.5           # Factor to reduce LR by
    min_lr: 1.0e-08       # Minimum LR threshold
tests/assets/minimal_instance_centroid/training_config.yaml (1)

75-75: Document the bin_files_path parameter.

Since this parameter is related to loading trained weights (PR objective), please:

  1. Document the expected path format and structure
  2. Add validation to handle empty/invalid paths
  3. Consider providing an example path in the config
tests/assets/minimal_instance_bottomup/training_config.yaml (1)

99-106: LGTM: Well-configured learning rate scheduler.

The ReduceLROnPlateau configuration is comprehensive and uses sensible values. The scheduler will:

  1. Reduce learning rate by half when no improvement is seen
  2. Wait 5 epochs before reducing (patience)
  3. Allow 3 epochs of cooldown between reductions
  4. Stop at min_lr of 1e-8

Consider adding comments in the YAML to explain the purpose of less obvious parameters like threshold_mode and cooldown.

docs/config_bottomup.yaml (2)

6-6: Document the purpose of chunk_size parameter.

The newly added chunk_size parameter would benefit from a comment explaining its purpose and impact on data processing. This helps users understand how to tune this value for their specific use cases.


Parameters should be added to trainer_config section

Based on the implementation in model_trainer.py, the parameters backbone_trained_ckpts_path and head_trained_ckpts_path are used during model initialization and training. These parameters should be added to the trainer_config section of the configuration file, not the model_config section, as they are handled by the trainer.

trainer_config:
  # ... existing config ...
  resume_ckpt_path:
+ backbone_trained_ckpts_path:  # Path to trained checkpoint for backbone
+ head_trained_ckpts_path:      # Path to trained checkpoint for head layers
  optimizer_name: Adam
🔗 Analysis chain

Line range hint 1-115: Missing configuration parameters for trained weights loading.

According to the PR objectives, this change should add support for loading pre-trained weights for head layers. However, I don't see the new parameters backbone_trained_ckpts_path and head_trained_ckpts_path in the configuration. These parameters should be added to enable the described functionality.

Let's check if these parameters are defined elsewhere:

Consider adding these parameters to the configuration file:

trainer_config:
  # ... existing config ...
  resume_ckpt_path:
+ backbone_trained_ckpts_path:
+ head_trained_ckpts_path:
  optimizer_name: Adam
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for the new parameters in other configuration files
rg -l "backbone_trained_ckpts_path|head_trained_ckpts_path"

Length of output: 133


Script:

#!/bin/bash
# Let's check the implementation to understand how these parameters are used
rg -A 10 "backbone_trained_ckpts_path|head_trained_ckpts_path" tests/training/test_model_trainer.py sleap_nn/training/model_trainer.py

Length of output: 17142

docs/config_topdown_centered_instance.yaml (1)

Line range hint 1-114: Missing configuration for trained weights loading.

Based on the PR objectives, this configuration file should include parameters for specifying trained weight paths:

  • backbone_trained_ckpts_path
  • head_trained_ckpts_path

These parameters are essential for the PR's core functionality of loading pre-trained weights for head layers.

Consider adding these parameters under the appropriate section (possibly under model_config or trainer_config).

docs/config_centroid.yaml (1)

Line range hint 89-93: Add configuration for head layer pre-trained weights.

The PR aims to add support for loading pre-trained weights for head layers, but the configuration is missing the necessary parameters. Consider adding head_trained_ckpts_path under the head_configs section to align with the PR's objective.

Example addition:

   head_configs:
     single_instance:
     centered_instance:
     bottomup: 
     centroid:
+      trained_weights_path:  # Path to pre-trained weights for centroid head
       confmaps:
         anchor_part: 0
🧰 Tools
🪛 yamllint

[error] 8-8: trailing spaces

(trailing-spaces)


[error] 9-9: trailing spaces

(trailing-spaces)

tests/fixtures/datasets.py (2)

62-62: Document the purpose of min_crop_size parameter.

The new min_crop_size parameter has been added to preprocessing configuration. Please add documentation explaining:

  • The purpose of this parameter
  • The impact when set to None
  • Valid values and their effects

159-167: LGTM! Consider documenting scheduler parameters.

The restructured LR scheduler configuration with ReduceLROnPlateau is well-organized and provides comprehensive control over learning rate adjustments. The parameters are sensibly configured with:

  • Relative threshold mode with small threshold (1e-07)
  • Conservative reduction factor (0.5)
  • Reasonable patience (5) and cooldown (3) periods
  • Safety floor for minimum learning rate (1e-08)

Consider adding inline comments or documentation explaining the purpose of each parameter and their recommended ranges.

tests/training/test_model_trainer.py (2)

108-110: Consider adding edge cases for StepLR parameters.

While the basic StepLR configuration is tested, consider adding test cases for:

  • Edge case values for step_size (e.g., 0, negative)
  • Edge case values for gamma (e.g., 0, negative, >1)
  • Verification of learning rate changes after each step

337-342: Enhance scheduler type validation tests.

Consider improving the error handling tests:

  1. Test multiple invalid scheduler types
  2. Verify the error message content
  3. Test edge cases (empty string, None, etc.)

Example:

invalid_schedulers = ["ReduceLR", "", None, "CustomScheduler"]
for scheduler in invalid_schedulers:
    OmegaConf.update(config, "trainer_config.lr_scheduler.scheduler", scheduler)
    with pytest.raises(ValueError, match=f"Unsupported scheduler type: {scheduler}"):
        trainer = ModelTrainer(config)
        trainer.train()
docs/config.md (1)

178-188: Technical review of learning rate scheduler documentation

The learning rate scheduler documentation is well-structured but could benefit from additional clarity:

  1. For StepLR, consider adding an example to illustrate the decay pattern.
  2. For ReduceLROnPlateau, the threshold_mode parameter's explanation could be clearer.

Consider updating the documentation with this improved explanation:

 - `threshold_mode`: (str) One of "rel", "abs". In rel mode, dynamic_threshold = best * ( 1 + threshold ) in max mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. *Default*: "rel".
+ - `threshold_mode`: (str) One of "rel" (relative) or "abs" (absolute). Determines how to interpret the threshold parameter:
+   - In "rel" mode: Triggers when improvement is less than (best * threshold) for min mode, or (best * (1 + threshold)) for max mode
+   - In "abs" mode: Triggers when improvement is less than (best - threshold) for min mode, or (best + threshold) for max mode
+   *Default*: "rel"
🧰 Tools
🪛 LanguageTool

[uncategorized] ~179-~179: Loose punctuation mark.
Context: ...ReduceLROnPlateau". - step_lr: - step_size`: (int) Period...

(UNLIKELY_OPENING_PUNCTUATION)


[uncategorized] ~182-~182: Loose punctuation mark.
Context: ...*: 0.1. - reduce_lr_on_plateau: - threshold: (float) Thre...

(UNLIKELY_OPENING_PUNCTUATION)

🪛 Markdownlint

178-178: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


179-179: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


180-180: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


181-181: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


182-182: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


183-183: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


184-184: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


185-185: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


186-186: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


187-187: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


188-188: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


184-184: null
Spaces inside emphasis markers

(MD037, no-space-in-emphasis)


184-184: null
Spaces inside emphasis markers

(MD037, no-space-in-emphasis)

sleap_nn/training/model_trainer.py (3)

497-498: Include parameter types in docstrings

In the docstring for TrainingModel, you have added descriptions for backbone_trained_ckpts_path and head_trained_ckpts_path. Consider specifying the parameter types for better clarity.

Apply this diff to enhance the docstring:

             head_trained_ckpts_path: Path to trained ckpts for head layer.
+            
+        Args:
+            backbone_trained_ckpts_path (Optional[str]): Path to trained checkpoints for backbone.
+            head_trained_ckpts_path (Optional[str]): Path to trained checkpoints for head layers.

713-714: Add parameter types in docstrings

For consistency and clarity, include the parameter types in the docstrings for SingleInstanceModel.

Apply this diff to enhance the docstring:

             head_trained_ckpts_path: Path to trained ckpts for head layer.
+
+        Args:
+            backbone_trained_ckpts_path (Optional[str]): Path to trained checkpoints for backbone.
+            head_trained_ckpts_path (Optional[str]): Path to trained checkpoints for head layers.

882-895: Update docstrings and parameter types for CentroidModel

Include the parameter types in the docstrings to improve clarity.

Apply this diff:

             head_trained_ckpts_path: Path to trained ckpts for head layer.
+
+        Args:
+            backbone_trained_ckpts_path (Optional[str]): Path to trained checkpoints for backbone.
+            head_trained_ckpts_path (Optional[str]): Path to trained checkpoints for head layers.
🛑 Comments failed to post (15)
tests/assets/minimal_instance/initial_config.yaml (1)

15-15: 💡 Codebase verification

⚠️ Potential issue

min_crop_size parameter should have an integer value

The min_crop_size parameter is used to specify the minimum crop size when crop_hw is not set. Based on the implementation in sleap_nn/data/instance_cropping.py, it expects an integer value or None. Please either:

  • Set an appropriate integer value for min_crop_size
  • Remove the parameter if you want to use the default behavior (equivalent to setting it to 0)
🔗 Analysis chain

Specify a value for min_crop_size or remove if not needed.

The min_crop_size parameter is added but left empty, which could lead to validation issues or undefined behavior.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check if min_crop_size is used in the codebase
rg -l "min_crop_size" --type python

Length of output: 161


Script:

#!/bin/bash
# Check if min_crop_size is used in the codebase with ripgrep
rg -l "min_crop_size"

# Also search for any YAML files that might reference this parameter
fd -e yaml -x grep -l "min_crop_size" {}

Length of output: 586


Script:

#!/bin/bash
# Check the implementation and usage of min_crop_size in the core code
rg "min_crop_size" sleap_nn/data/instance_cropping.py -C 2

# Check documentation for min_crop_size
rg "min_crop_size" docs/config.md -C 2

Length of output: 2095

tests/assets/minimal_instance_centroid/training_config.yaml (2)

89-96: 🛠️ Refactor suggestion

Review learning rate scheduler parameters.

The ReduceLROnPlateau configuration needs attention:

  1. threshold: 1e-7 is very small - consider a larger value like 1e-4
  2. patience: 5 epochs might be too short for complex models
  3. min_lr: 1e-8 is extremely small - typical values are around 1e-6

Consider these adjustments:

    reduce_lr_on_plateau:
-     threshold: 1.0e-07
+     threshold: 1.0e-04
-     patience: 5
+     patience: 10
-     min_lr: 1.0e-08
+     min_lr: 1.0e-06
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    scheduler: ReduceLROnPlateau
    reduce_lr_on_plateau:
      threshold: 1.0e-04
      threshold_mode: abs
      cooldown: 3
      patience: 10
      factor: 0.5
      min_lr: 1.0e-06

75-77: ⚠️ Potential issue

Implementation incomplete for PR objective.

The configuration is missing the parameters mentioned in the PR objective:

  • backbone_trained_ckpts_path
  • head_trained_ckpts_path

These parameters are essential for loading pre-trained weights for head layers.

Add the missing parameters:

  bin_files_path:
+ backbone_trained_ckpts_path:
+ head_trained_ckpts_path:
  resume_ckpt_path:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

  bin_files_path:
  backbone_trained_ckpts_path:
  head_trained_ckpts_path:
  resume_ckpt_path:
  wandb:
docs/config_topdown_centered_instance.yaml (1)

15-15: ⚠️ Potential issue

Specify a value for min_crop_size or document if it's optional.

The min_crop_size parameter is added without a value. This could lead to:

  1. Configuration validation errors
  2. Unexpected runtime behavior

Please either:

  • Set a default value
  • Document if this is an optional parameter
  • Add validation to handle empty values
initial_config.yaml (4)

1-104: ⚠️ Potential issue

Fix line endings for cross-platform compatibility

The file uses incorrect line endings. Ensure consistent use of Unix-style line endings (\n) across the codebase.

Run the following command to fix line endings:

#!/bin/bash
# Convert CRLF to LF
sed -i 's/\r$//' initial_config.yaml
🧰 Tools
🪛 yamllint

[error] 1-1: wrong new line character: expected \n

(new-lines)


84-84: ⚠️ Potential issue

Security: Remove empty API key field

For security reasons, sensitive information like API keys should not be included in configuration files, even if empty. Consider removing this field and using environment variables instead.

-    api_key: ''
+    # API key should be provided via environment variable WANDB_API_KEY
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    # API key should be provided via environment variable WANDB_API_KEY

3-4: ⚠️ Potential issue

Replace hardcoded paths with relative paths

The configuration contains Windows-specific absolute paths, which will break when used on different machines or operating systems.

Replace absolute paths with relative paths:

-  train_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
-  val_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
+  train_labels_path: tests/assets/minimal_instance.pkg.slp
+  val_labels_path: tests/assets/minimal_instance.pkg.slp
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

  train_labels_path: tests/assets/minimal_instance.pkg.slp
  val_labels_path: tests/assets/minimal_instance.pkg.slp

27-28: ⚠️ Potential issue

Add configuration for backbone and head trained weights

According to the PR objectives, the configuration should support backbone_trained_ckpts_path and head_trained_ckpts_path for loading pre-trained weights. However, these parameters are missing from the configuration.

Add the missing parameters:

 init_weights: default
 pre_trained_weights: null
+backbone_trained_ckpts_path: null
+head_trained_ckpts_path: null
 backbone_type: unet
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

  init_weights: default
  pre_trained_weights: null
  backbone_trained_ckpts_path: null
  head_trained_ckpts_path: null
tests/training/test_model_trainer.py (1)

346-384: 🛠️ Refactor suggestion

Enhance test coverage for weight loading functionality.

The current test verifies basic weight loading, but consider adding tests for:

  1. Loading only backbone weights
  2. Loading only head weights
  3. Invalid checkpoint paths
  4. Incompatible model architectures
  5. Verifying weights are maintained after training

Example additions:

# Test loading only backbone weights
trainer = ModelTrainer(load_weights_config)
trainer._initialize_model(
    backbone_trained_ckpts_path=(Path(minimal_instance_ckpt) / "best.ckpt").as_posix(),
    head_trained_ckpts_path=None
)
model_ckpt = next(trainer.model.parameters())[0, 0, :].detach().numpy()
assert np.all(np.abs(first_layer_ckpt - model_ckpt) < 1e-6)

# Test invalid checkpoint path
with pytest.raises(FileNotFoundError):
    trainer._initialize_model(
        backbone_trained_ckpts_path="invalid_path.ckpt",
        head_trained_ckpts_path=None
    )

# Test incompatible architecture
incompatible_ckpt = {...}  # Create checkpoint with different architecture
with pytest.raises(ValueError, match="Incompatible model architecture"):
    trainer._initialize_model(
        backbone_trained_ckpts_path=incompatible_ckpt,
        head_trained_ckpts_path=None
    )
docs/config.md (1)

178-188: ⚠️ Potential issue

Markdown formatting improvements needed

The indentation of list items is inconsistent with the rest of the document. The static analysis tool suggests using 4-space indentation for unordered lists.

Apply these formatting changes:

  1. Adjust indentation to match the rest of the document (4 spaces for top-level items, 6 spaces for nested items)
  2. Remove extra spaces around emphasis markers in the threshold_mode description
  3. Fix loose punctuation marks after the scheduler options
-        - `scheduler`: (str) Name of the scheduler to use. Valid schedulers: `"StepLR"`, `"ReduceLROnPlateau"`.
-        - `step_lr`:
-            - `step_size`: (int) Period of learning rate decay. If `step_size`=10, then every 10 epochs, learning rate will be reduced by a factor of `gamma`.
+    - `scheduler`: (str) Name of the scheduler to use. Valid schedulers: `"StepLR"`, `"ReduceLROnPlateau"`
+    - `step_lr`:
+      - `step_size`: (int) Period of learning rate decay. If `step_size`=10, then every 10 epochs, learning rate will be reduced by a factor of `gamma`.

Committable suggestion was skipped due to low confidence.

🧰 Tools
🪛 LanguageTool

[uncategorized] ~179-~179: Loose punctuation mark.
Context: ...ReduceLROnPlateau". - step_lr: - step_size`: (int) Period...

(UNLIKELY_OPENING_PUNCTUATION)


[uncategorized] ~182-~182: Loose punctuation mark.
Context: ...*: 0.1. - reduce_lr_on_plateau: - threshold: (float) Thre...

(UNLIKELY_OPENING_PUNCTUATION)

🪛 Markdownlint

178-178: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


179-179: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


180-180: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


181-181: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


182-182: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


183-183: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


184-184: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


185-185: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


186-186: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


187-187: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


188-188: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


184-184: null
Spaces inside emphasis markers

(MD037, no-space-in-emphasis)


184-184: null
Spaces inside emphasis markers

(MD037, no-space-in-emphasis)

sleap_nn/training/model_trainer.py (5)

332-336: 🛠️ Refactor suggestion

Update docstrings to reflect new parameters

The _initialize_model method now accepts two new parameters: backbone_trained_ckpts_path and head_trained_ckpts_path. Please update the method's docstring to include descriptions for these parameters to maintain clarity and consistency in documentation.


455-455: ⚠️ Potential issue

Avoid logging sensitive configuration data

When updating the WandB experiment configuration, ensure that no sensitive information (such as API keys or personal data) is being inadvertently logged.


517-518: 🛠️ Refactor suggestion

Remove redundancy in attribute assignment

You are already passing backbone_trained_ckpts_path and head_trained_ckpts_path to the superclass. Assigning them again here may be unnecessary.

Apply this diff to remove redundant assignments:

-        self.backbone_trained_ckpts_path = backbone_trained_ckpts_path
-        self.head_trained_ckpts_path = head_trained_ckpts_path

Committable suggestion was skipped due to low confidence.


576-585: ⚠️ Potential issue

Simplify dictionary key iteration

As per the static analysis hint, you can simplify the dictionary iteration by removing the .keys() method.

Apply this diff to address the issue:

-            for k in ckpt["state_dict"].keys()
+            for k in ckpt["state_dict"]
             if ".backbone" in k
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

        # Initializing backbone (encoder + decoder) with trained ckpts
        if backbone_trained_ckpts_path is not None:
            print(f"Loading backbone weights from `{backbone_trained_ckpts_path}` ...")
            ckpt = torch.load(backbone_trained_ckpts_path)
            ckpt["state_dict"] = {
                k: ckpt["state_dict"][k]
                for k in ckpt["state_dict"]
                if ".backbone" in k
            }
            self.load_state_dict(ckpt["state_dict"], strict=False)
🧰 Tools
🪛 Ruff

582-582: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


587-598: ⚠️ Potential issue

Remove debugging print statement and simplify key iteration

  • The print statement on line 596 appears to be for debugging and should be removed or replaced with proper logging.
  • Simplify the dictionary key iteration by removing .keys() as per the static analysis hint.

Apply this diff to address both issues:

             }
-            for k in ckpt["state_dict"].keys()
+            for k in ckpt["state_dict"]
             if ".head_layers" in k
             }
-        print(f"from main code: {ckpt['state_dict'].keys()}")

Committable suggestion was skipped due to low confidence.

🧰 Tools
🪛 Ruff

593-593: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

@gitttt-1234 gitttt-1234 requested a review from talmo October 31, 2024 14:54
@gitttt-1234 gitttt-1234 force-pushed the divya/add-config-wandb branch from 91d478c to c2a1c04 Compare December 5, 2024 18:09
@gitttt-1234 gitttt-1234 force-pushed the divya/add-head-ckpt-path branch from 2b7efde to 6c33c82 Compare December 5, 2024 20:10
@gitttt-1234 gitttt-1234 changed the base branch from divya/add-config-wandb to main December 5, 2024 20:11
@gitttt-1234 gitttt-1234 changed the base branch from main to divya/add-config-wandb December 5, 2024 20:11
@gitttt-1234 gitttt-1234 force-pushed the divya/add-head-ckpt-path branch from a473848 to e7486da Compare December 5, 2024 20:22
@gitttt-1234 gitttt-1234 changed the base branch from divya/add-config-wandb to main December 5, 2024 20:23
@gitttt-1234 gitttt-1234 force-pushed the divya/add-head-ckpt-path branch from e7486da to 751559a Compare December 5, 2024 20:24
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (5)
sleap_nn/training/model_trainer.py (5)

497-498: Update docstrings to include new parameters

The __init__ method of TrainingModel now includes the parameters backbone_trained_ckpts_path and head_trained_ckpts_path, but these are not documented in the method's docstring under the Args section. Please update the docstring to include these new parameters and their descriptions to maintain comprehensive documentation.


577-583: Simplify dictionary key iteration by removing .keys()

In line 581, for k in ckpt["state_dict"].keys() can be simplified to for k in ckpt["state_dict"] for more concise and idiomatic Python code.

Apply this diff to simplify the code:

 ckpt["state_dict"] = {
     k: ckpt["state_dict"][k]
-    for k in ckpt["state_dict"].keys()
+    for k in ckpt["state_dict"]
     if ".backbone" in k
 }
🧰 Tools
🪛 Ruff (0.8.0)

582-582: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


592-595: Simplify dictionary key iteration by removing .keys()

Similarly, in line 593, for k in ckpt["state_dict"].keys() can be simplified to for k in ckpt["state_dict"].

Apply this diff to simplify the code:

 ckpt["state_dict"] = {
     k: ckpt["state_dict"][k]
-    for k in ckpt["state_dict"].keys()
+    for k in ckpt["state_dict"]
     if ".head_layers" in k
 }
🧰 Tools
🪛 Ruff (0.8.0)

593-593: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


596-596: Remove debug print statement or replace with logging

The print statement at line 596 appears to be used for debugging purposes:

print(f"from main code: {ckpt['state_dict'].keys()}")

Consider removing this statement or replacing it with a logging statement at an appropriate logging level to avoid cluttering the console output in a production environment.

Apply this diff to remove the print statement:

-        print(f"from main code: {ckpt['state_dict'].keys()}")

497-498: Ensure docstrings in subclass constructors include new parameters

The __init__ methods of subclasses (SingleInstanceModel, TopDownCenteredInstanceModel, CentroidModel, BottomUpModel) now include backbone_trained_ckpts_path and head_trained_ckpts_path as parameters. Please update the docstrings in these methods to include the new parameters and their descriptions in the Args section for consistency and clarity.

Also applies to: 506-507, 713-714, 723-724, 804-805, 885-886, 966-967

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 6c33c82 and 751559a.

📒 Files selected for processing (3)
  • sleap_nn/training/model_trainer.py (14 hunks)
  • tests/fixtures/datasets.py (1 hunks)
  • tests/training/test_model_trainer.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.0)
sleap_nn/training/model_trainer.py

582-582: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


593-593: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

🔇 Additional comments (2)
tests/training/test_model_trainer.py (1)

346-383: Verify the use of the same checkpoint for both backbone and head

In the test_trainer_load_trained_ckpts function, both backbone_trained_ckpts_path and head_trained_ckpts_path are set to the same checkpoint file:

backbone_trained_ckpts_path=(Path(minimal_instance_ckpt) / "best.ckpt").as_posix(),
head_trained_ckpts_path=(Path(minimal_instance_ckpt) / "best.ckpt").as_posix(),

Please confirm that using the same checkpoint for both is intentional. If the backbone and head layers have different trained checkpoints, consider using separate checkpoint files to better simulate real-world scenarios.

tests/fixtures/datasets.py (1)

92-92: Verify the impact of disabling interpolation upsampling

The change from up_interpolate: True to False in the backbone configuration will affect how the UNet performs upsampling. This might impact the model's ability to handle pre-trained weights if they were trained with different upsampling methods.

Let's verify if this change is consistent with the model's expectations:

✅ Verification successful

Based on the search results, I can now generate the final response:

Setting up_interpolate to False is consistent with test configurations

The change to up_interpolate: False in the backbone configuration is consistent with the existing test configurations and implementation. The codebase shows:

  1. All minimal test configurations in tests/assets/ use up_interpolate: false
  2. The encoder-decoder architecture (sleap_nn/architectures/encoder_decoder.py) supports both modes:
    • When False: Uses transposed convolutions for upsampling
    • When True: Uses bilinear interpolation (faster but potentially less detailed)
  3. The documentation in docs/config.md confirms this is a valid configuration option, with the tradeoff being that transposed convolutions may learn richer upsampling patterns to recover details from higher scales.
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any related configuration or validation of up_interpolate
rg -l "up_interpolate" | xargs rg -A 5 "up_interpolate"

Length of output: 27035

@gitttt-1234 gitttt-1234 merged commit 7532130 into main Dec 5, 2024
7 checks passed
@gitttt-1234 gitttt-1234 deleted the divya/add-head-ckpt-path branch December 5, 2024 20:59
This was referenced Dec 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants