Add option to load trained weights for Head layers #114
Conversation
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #114      +/-   ##
==========================================
+ Coverage   96.64%   97.37%   +0.73%
==========================================
  Files          23       38      +15
  Lines        1818     3702    +1884
==========================================
+ Hits         1757     3605    +1848
- Misses         61       97      +36
```

☔ View full report in Codecov by Sentry.
Caution
Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments.
Actionable comments posted: 15
🧹 Outside diff range and nitpick comments (19)
tests/assets/minimal_instance_centroid/initial_config.yaml (2)
64-64: Document the `bin_files_path` parameter. Given that this PR introduces functionality for loading pre-trained weights:
- Consider adding a comment explaining the expected format and purpose of bin_files_path
- Should there be validation to ensure the path exists when provided?
- How does this relate to the backbone_trained_ckpts_path and head_trained_ckpts_path mentioned in the PR description?
78-85: LGTM! Consider adding parameter descriptions. The ReduceLROnPlateau configuration looks well-structured with reasonable default values. To improve maintainability, consider adding inline comments explaining each parameter's purpose and impact on training.
Example documentation:
```yaml
reduce_lr_on_plateau:
  # Minimum change in loss to be considered an improvement
  threshold: 1.0e-07
  # How to measure the threshold (abs or rel)
  threshold_mode: abs
  # Number of epochs to wait after a reduction before resuming normal operation
  cooldown: 3
  # Number of epochs with no improvement after which LR will be reduced
  patience: 5
  # Factor by which the learning rate will be reduced
  factor: 0.5
  # Lower bound on the learning rate
  min_lr: 1.0e-08
```
tests/assets/minimal_instance/initial_config.yaml (1)
13-14: Document the significance of crop_hw values. The hardcoded values [160, 160] for `crop_hw` should be documented to explain their significance and any constraints.
```diff
  crop_hw:
- - 160
- - 160
+ - 160  # Height of the crop window
+ - 160  # Width of the crop window
```
tests/assets/minimal_instance_bottomup/initial_config.yaml (2)
70-71: Document the purpose and format of `bin_files_path`. The PR mentions loading `.ckpt` format weights, but `bin_files_path` suggests binary files. Please clarify:
- The expected file format
- Whether this is related to the pre-trained weights loading feature
- Add documentation for this parameter's usage
84-91: LGTM! Consider adding parameter descriptions. The learning rate scheduler configuration is well-structured with appropriate values for ReduceLROnPlateau. Consider adding inline comments or documentation explaining the purpose of each parameter for future maintainability.
```yaml
lr_scheduler:
  scheduler: ReduceLROnPlateau
  reduce_lr_on_plateau:
    threshold: 1.0e-07       # Minimum change in loss to be considered an improvement
    threshold_mode: abs      # Use absolute change in loss
    cooldown: 3              # Epochs to wait before resuming normal operation
    patience: 5              # Epochs to wait before reducing LR
    factor: 0.5              # Factor to reduce LR by
    min_lr: 1.0e-08          # Minimum LR threshold
```
tests/assets/minimal_instance_centroid/training_config.yaml (1)
75-75: Document the `bin_files_path` parameter. Since this parameter is related to loading trained weights (PR objective), please:
- Document the expected path format and structure
- Add validation to handle empty/invalid paths
- Consider providing an example path in the config
tests/assets/minimal_instance_bottomup/training_config.yaml (1)
99-106: LGTM: Well-configured learning rate scheduler. The ReduceLROnPlateau configuration is comprehensive and uses sensible values. The scheduler will:
- Reduce learning rate by half when no improvement is seen
- Wait 5 epochs before reducing (patience)
- Allow 3 epochs of cooldown between reductions
- Stop at min_lr of 1e-8
Consider adding comments in the YAML to explain the purpose of less obvious parameters like `threshold_mode` and `cooldown` (see the sketch below).
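For reference, a minimal standalone PyTorch sketch (not sleap-nn code; the model, optimizer, and loss values are placeholders) showing how the values above map onto `torch.optim.lr_scheduler.ReduceLROnPlateau`:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",            # monitor a loss that should decrease
    factor=0.5,            # halve the LR on plateau
    patience=5,            # epochs with no improvement before reducing
    threshold=1.0e-07,     # minimum change counted as an improvement
    threshold_mode="abs",  # interpret the threshold as an absolute change
    cooldown=3,            # epochs to wait after a reduction
    min_lr=1.0e-08,        # lower bound on the LR
)

for epoch in range(20):
    val_loss = 1.0  # placeholder: a stagnating validation loss
    scheduler.step(val_loss)
    print(epoch, optimizer.param_groups[0]["lr"])  # LR is halved once patience runs out
```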
docs/config_bottomup.yaml (2)
6-6: Document the purpose of chunk_size parameter. The newly added `chunk_size` parameter would benefit from a comment explaining its purpose and impact on data processing. This helps users understand how to tune this value for their specific use cases.
Parameters should be added to trainer_config section
Based on the implementation in `model_trainer.py`, the parameters `backbone_trained_ckpts_path` and `head_trained_ckpts_path` are used during model initialization and training. These parameters should be added to the `trainer_config` section of the configuration file, not the `model_config` section, as they are handled by the trainer.
```diff
  trainer_config:
    # ... existing config ...
    resume_ckpt_path:
+   backbone_trained_ckpts_path:  # Path to trained checkpoint for backbone
+   head_trained_ckpts_path:  # Path to trained checkpoint for head layers
    optimizer_name: Adam
```
🔗 Analysis chain
Line range hint 1-115: Missing configuration parameters for trained weights loading. According to the PR objectives, this change should add support for loading pre-trained weights for head layers. However, I don't see the new parameters `backbone_trained_ckpts_path` and `head_trained_ckpts_path` in the configuration. These parameters should be added to enable the described functionality. Let's check if these parameters are defined elsewhere:
Consider adding these parameters to the configuration file:
```diff
  trainer_config:
    # ... existing config ...
    resume_ckpt_path:
+   backbone_trained_ckpts_path:
+   head_trained_ckpts_path:
    optimizer_name: Adam
```
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Search for the new parameters in other configuration files
rg -l "backbone_trained_ckpts_path|head_trained_ckpts_path"
```
Length of output: 133
Script:
```bash
#!/bin/bash
# Let's check the implementation to understand how these parameters are used
rg -A 10 "backbone_trained_ckpts_path|head_trained_ckpts_path" tests/training/test_model_trainer.py sleap_nn/training/model_trainer.py
```
Length of output: 17142
docs/config_topdown_centered_instance.yaml (1)
Line range hint 1-114: Missing configuration for trained weights loading. Based on the PR objectives, this configuration file should include parameters for specifying trained weight paths:
- `backbone_trained_ckpts_path`
- `head_trained_ckpts_path`
These parameters are essential for the PR's core functionality of loading pre-trained weights for head layers.
Consider adding these parameters under the appropriate section (possibly under `model_config` or `trainer_config`).
docs/config_centroid.yaml (1)
Line range hint 89-93: Add configuration for head layer pre-trained weights. The PR aims to add support for loading pre-trained weights for head layers, but the configuration is missing the necessary parameters. Consider adding `head_trained_ckpts_path` under the `head_configs` section to align with the PR's objective.
Example addition:
```diff
  head_configs:
    single_instance:
    centered_instance:
    bottomup:
    centroid:
+     trained_weights_path:  # Path to pre-trained weights for centroid head
      confmaps:
        anchor_part: 0
```
🧰 Tools
🪛 yamllint
[error] 8-8: trailing spaces (trailing-spaces)
[error] 9-9: trailing spaces (trailing-spaces)
tests/fixtures/datasets.py (2)
62-62: Document the purpose of min_crop_size parameter. The new `min_crop_size` parameter has been added to the preprocessing configuration. Please add documentation explaining:
- The purpose of this parameter
- The impact when set to None
- Valid values and their effects
159-167: LGTM! Consider documenting scheduler parameters. The restructured LR scheduler configuration with ReduceLROnPlateau is well-organized and provides comprehensive control over learning rate adjustments. The parameters are sensibly configured with:
- Relative threshold mode with small threshold (1e-07)
- Conservative reduction factor (0.5)
- Reasonable patience (5) and cooldown (3) periods
- Safety floor for minimum learning rate (1e-08)
Consider adding inline comments or documentation explaining the purpose of each parameter and their recommended ranges.
tests/training/test_model_trainer.py (2)
108-110: Consider adding edge cases for StepLR parameters. While the basic StepLR configuration is tested, consider adding test cases for:
- Edge case values for `step_size` (e.g., 0, negative)
- Edge case values for `gamma` (e.g., 0, negative, >1)
- Verification of learning rate changes after each step (see the sketch below)
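For the last point, a standalone PyTorch sketch (illustrative only, not tied to this test file or its fixtures) of asserting the StepLR decay pattern:

```python
import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

lrs = []
for epoch in range(30):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])

# LR drops by a factor of gamma every `step_size` epochs: 1.0 -> 0.1 -> 0.01 -> 0.001
assert abs(lrs[9] - 0.1) < 1e-12
assert abs(lrs[19] - 0.01) < 1e-12
assert abs(lrs[29] - 0.001) < 1e-12
```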
337-342: Enhance scheduler type validation tests. Consider improving the error handling tests:
- Test multiple invalid scheduler types
- Verify the error message content
- Test edge cases (empty string, None, etc.)
Example:
```python
invalid_schedulers = ["ReduceLR", "", None, "CustomScheduler"]
for scheduler in invalid_schedulers:
    OmegaConf.update(config, "trainer_config.lr_scheduler.scheduler", scheduler)
    with pytest.raises(ValueError, match=f"Unsupported scheduler type: {scheduler}"):
        trainer = ModelTrainer(config)
        trainer.train()
```
docs/config.md (1)
178-188: Technical review of learning rate scheduler documentation. The learning rate scheduler documentation is well-structured but could benefit from additional clarity:
- For `StepLR`, consider adding an example to illustrate the decay pattern.
- For `ReduceLROnPlateau`, the `threshold_mode` parameter's explanation could be clearer.

Consider updating the documentation with this improved explanation:
```diff
- - `threshold_mode`: (str) One of "rel", "abs". In rel mode, dynamic_threshold = best * ( 1 + threshold ) in max mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. *Default*: "rel".
+ - `threshold_mode`: (str) One of "rel" (relative) or "abs" (absolute). Determines how to interpret the threshold parameter:
+   - In "rel" mode: Triggers when improvement is less than (best * threshold) for min mode, or (best * (1 + threshold)) for max mode
+   - In "abs" mode: Triggers when improvement is less than (best - threshold) for min mode, or (best + threshold) for max mode
+   *Default*: "rel"
```
🧰 Tools
🪛 LanguageTool
[uncategorized] ~179: Loose punctuation mark. Context: "...ReduceLROnPlateau". - `step_lr`: - `step_size`: (int) Period..." (UNLIKELY_OPENING_PUNCTUATION)
[uncategorized] ~182: Loose punctuation mark. Context: "...*: 0.1. - `reduce_lr_on_plateau`: - `threshold`: (float) Thre..." (UNLIKELY_OPENING_PUNCTUATION)
🪛 Markdownlint
Lines 178-179, 182: Expected: 4; Actual: 8 (MD007, ul-indent: unordered list indentation)
Lines 180-181, 183-188: Expected: 6; Actual: 12 (MD007, ul-indent: unordered list indentation)
Line 184: Spaces inside emphasis markers (MD037, no-space-in-emphasis), 2 occurrences
sleap_nn/training/model_trainer.py (3)
497-498: Include parameter types in docstrings. In the docstring for `TrainingModel`, you have added descriptions for `backbone_trained_ckpts_path` and `head_trained_ckpts_path`. Consider specifying the parameter types for better clarity. Apply this diff to enhance the docstring:
```diff
  head_trained_ckpts_path: Path to trained ckpts for head layer.
+
+ Args:
+     backbone_trained_ckpts_path (Optional[str]): Path to trained checkpoints for backbone.
+     head_trained_ckpts_path (Optional[str]): Path to trained checkpoints for head layers.
```
713-714: Add parameter types in docstrings. For consistency and clarity, include the parameter types in the docstrings for `SingleInstanceModel`. Apply this diff to enhance the docstring:
```diff
  head_trained_ckpts_path: Path to trained ckpts for head layer.
+
+ Args:
+     backbone_trained_ckpts_path (Optional[str]): Path to trained checkpoints for backbone.
+     head_trained_ckpts_path (Optional[str]): Path to trained checkpoints for head layers.
```
882-895: Update docstrings and parameter types for `CentroidModel`. Include the parameter types in the docstrings to improve clarity. Apply this diff:
```diff
  head_trained_ckpts_path: Path to trained ckpts for head layer.
+
+ Args:
+     backbone_trained_ckpts_path (Optional[str]): Path to trained checkpoints for backbone.
+     head_trained_ckpts_path (Optional[str]): Path to trained checkpoints for head layers.
```
🛑 Comments failed to post (15)
tests/assets/minimal_instance/initial_config.yaml (1)
15-15: 💡 Codebase verification
⚠️ Potential issue: `min_crop_size` parameter should have an integer value
The `min_crop_size` parameter is used to specify the minimum crop size when `crop_hw` is not set. Based on the implementation in `sleap_nn/data/instance_cropping.py`, it expects an integer value or None. Please either:
- Set an appropriate integer value for `min_crop_size`
- Remove the parameter if you want to use the default behavior (equivalent to setting it to 0)
🔗 Analysis chain
Specify a value for min_crop_size or remove if not needed.
The `min_crop_size` parameter is added but left empty, which could lead to validation issues or undefined behavior.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Check if min_crop_size is used in the codebase
rg -l "min_crop_size" --type python
```
Length of output: 161
Script:
```bash
#!/bin/bash
# Check if min_crop_size is used in the codebase with ripgrep
rg -l "min_crop_size"
# Also search for any YAML files that might reference this parameter
fd -e yaml -x grep -l "min_crop_size" {}
```
Length of output: 586
Script:
```bash
#!/bin/bash
# Check the implementation and usage of min_crop_size in the core code
rg "min_crop_size" sleap_nn/data/instance_cropping.py -C 2
# Check documentation for min_crop_size
rg "min_crop_size" docs/config.md -C 2
```
Length of output: 2095
tests/assets/minimal_instance_centroid/training_config.yaml (2)
89-96: 🛠️ Refactor suggestion
Review learning rate scheduler parameters.
The ReduceLROnPlateau configuration needs attention:
- `threshold: 1e-7` is very small - consider a larger value like 1e-4
- `patience: 5` epochs might be too short for complex models
- `min_lr: 1e-8` is extremely small - typical values are around 1e-6

Consider these adjustments:
```diff
  reduce_lr_on_plateau:
-   threshold: 1.0e-07
+   threshold: 1.0e-04
-   patience: 5
+   patience: 10
-   min_lr: 1.0e-08
+   min_lr: 1.0e-06
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```yaml
scheduler: ReduceLROnPlateau
reduce_lr_on_plateau:
  threshold: 1.0e-04
  threshold_mode: abs
  cooldown: 3
  patience: 10
  factor: 0.5
  min_lr: 1.0e-06
```
75-77: ⚠️ Potential issue: Implementation incomplete for PR objective.
The configuration is missing the parameters mentioned in the PR objective:
- `backbone_trained_ckpts_path`
- `head_trained_ckpts_path`
These parameters are essential for loading pre-trained weights for head layers.
Add the missing parameters:
```diff
  bin_files_path:
+ backbone_trained_ckpts_path:
+ head_trained_ckpts_path:
  resume_ckpt_path:
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```yaml
bin_files_path:
backbone_trained_ckpts_path:
head_trained_ckpts_path:
resume_ckpt_path:
wandb:
```
docs/config_topdown_centered_instance.yaml (1)
15-15: ⚠️ Potential issue: Specify a value for min_crop_size or document if it's optional.
The `min_crop_size` parameter is added without a value. This could lead to:
- Configuration validation errors
- Unexpected runtime behavior
Please either:
- Set a default value
- Document if this is an optional parameter
- Add validation to handle empty values (see the sketch below)
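A possible validation sketch (illustrative only; the function name and placement are assumptions, not existing sleap-nn code):

```python
from typing import Optional

def validate_min_crop_size(min_crop_size: Optional[int]) -> int:
    """Illustrative check: accept None or a non-negative integer (0 keeps the default behavior)."""
    if min_crop_size is None:
        return 0
    if not isinstance(min_crop_size, int) or min_crop_size < 0:
        raise ValueError(
            f"min_crop_size must be a non-negative integer or None, got {min_crop_size!r}"
        )
    return min_crop_size
```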
initial_config.yaml (4)
1-104: ⚠️ Potential issue: Fix line endings for cross-platform compatibility
The file uses incorrect line endings. Ensure consistent use of Unix-style line endings (\n) across the codebase.
Run the following command to fix line endings:
```bash
#!/bin/bash
# Convert CRLF to LF
sed -i 's/\r$//' initial_config.yaml
```
🧰 Tools
🪛 yamllint
[error] 1-1: wrong new line character: expected \n (new-lines)
84-84: ⚠️ Potential issue: Security: Remove empty API key field
For security reasons, sensitive information like API keys should not be included in configuration files, even if empty. Consider removing this field and using environment variables instead.
```diff
- api_key: ''
+ # API key should be provided via environment variable WANDB_API_KEY
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```yaml
# API key should be provided via environment variable WANDB_API_KEY
```
3-4: ⚠️ Potential issue: Replace hardcoded paths with relative paths
The configuration contains Windows-specific absolute paths, which will break when used on different machines or operating systems.
Replace absolute paths with relative paths:
```diff
- train_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
- val_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
+ train_labels_path: tests/assets/minimal_instance.pkg.slp
+ val_labels_path: tests/assets/minimal_instance.pkg.slp
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```yaml
train_labels_path: tests/assets/minimal_instance.pkg.slp
val_labels_path: tests/assets/minimal_instance.pkg.slp
```
27-28: ⚠️ Potential issue: Add configuration for backbone and head trained weights
According to the PR objectives, the configuration should support `backbone_trained_ckpts_path` and `head_trained_ckpts_path` for loading pre-trained weights. However, these parameters are missing from the configuration.
Add the missing parameters:
```diff
  init_weights: default
  pre_trained_weights: null
+ backbone_trained_ckpts_path: null
+ head_trained_ckpts_path: null
  backbone_type: unet
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```yaml
init_weights: default
pre_trained_weights: null
backbone_trained_ckpts_path: null
head_trained_ckpts_path: null
```
tests/training/test_model_trainer.py (1)
346-384: 🛠️ Refactor suggestion
Enhance test coverage for weight loading functionality.
The current test verifies basic weight loading, but consider adding tests for:
- Loading only backbone weights
- Loading only head weights
- Invalid checkpoint paths
- Incompatible model architectures
- Verifying weights are maintained after training
Example additions:
```python
# Test loading only backbone weights
trainer = ModelTrainer(load_weights_config)
trainer._initialize_model(
    backbone_trained_ckpts_path=(Path(minimal_instance_ckpt) / "best.ckpt").as_posix(),
    head_trained_ckpts_path=None,
)
model_ckpt = next(trainer.model.parameters())[0, 0, :].detach().numpy()
assert np.all(np.abs(first_layer_ckpt - model_ckpt) < 1e-6)

# Test invalid checkpoint path
with pytest.raises(FileNotFoundError):
    trainer._initialize_model(
        backbone_trained_ckpts_path="invalid_path.ckpt", head_trained_ckpts_path=None
    )

# Test incompatible architecture
incompatible_ckpt = {...}  # Create checkpoint with different architecture
with pytest.raises(ValueError, match="Incompatible model architecture"):
    trainer._initialize_model(
        backbone_trained_ckpts_path=incompatible_ckpt, head_trained_ckpts_path=None
    )
```
docs/config.md (1)
178-188: ⚠️ Potential issue: Markdown formatting improvements needed
The indentation of list items is inconsistent with the rest of the document. The static analysis tool suggests using 4-space indentation for unordered lists.
Apply these formatting changes:
- Adjust indentation to match the rest of the document (4 spaces for top-level items, 6 spaces for nested items)
- Remove extra spaces around emphasis markers in the `threshold_mode` description
- Fix loose punctuation marks after the scheduler options
```diff
- - `scheduler`: (str) Name of the scheduler to use. Valid schedulers: `"StepLR"`, `"ReduceLROnPlateau"`.
- - `step_lr`:
- - `step_size`: (int) Period of learning rate decay. If `step_size`=10, then every 10 epochs, learning rate will be reduced by a factor of `gamma`.
+ - `scheduler`: (str) Name of the scheduler to use. Valid schedulers: `"StepLR"`, `"ReduceLROnPlateau"`
+ - `step_lr`:
+ - `step_size`: (int) Period of learning rate decay. If `step_size`=10, then every 10 epochs, learning rate will be reduced by a factor of `gamma`.
```
Committable suggestion was skipped due to low confidence.
sleap_nn/training/model_trainer.py (5)
332-336: 🛠️ Refactor suggestion
Update docstrings to reflect new parameters
The `_initialize_model` method now accepts two new parameters: `backbone_trained_ckpts_path` and `head_trained_ckpts_path`. Please update the method's docstring to include descriptions for these parameters to maintain clarity and consistency in documentation.
455-455: ⚠️ Potential issue: Avoid logging sensitive configuration data
When updating the WandB experiment configuration, ensure that no sensitive information (such as API keys or personal data) is being inadvertently logged.
517-518: 🛠️ Refactor suggestion
Remove redundancy in attribute assignment
You are already passing `backbone_trained_ckpts_path` and `head_trained_ckpts_path` to the superclass. Assigning them again here may be unnecessary.
Apply this diff to remove redundant assignments:
```diff
- self.backbone_trained_ckpts_path = backbone_trained_ckpts_path
- self.head_trained_ckpts_path = head_trained_ckpts_path
```
Committable suggestion was skipped due to low confidence.
576-585: ⚠️ Potential issue: Simplify dictionary key iteration
As per the static analysis hint, you can simplify the dictionary iteration by removing the `.keys()` method. Apply this diff to address the issue:
```diff
- for k in ckpt["state_dict"].keys()
+ for k in ckpt["state_dict"]
  if ".backbone" in k
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# Initializing backbone (encoder + decoder) with trained ckpts
if backbone_trained_ckpts_path is not None:
    print(f"Loading backbone weights from `{backbone_trained_ckpts_path}` ...")
    ckpt = torch.load(backbone_trained_ckpts_path)
    ckpt["state_dict"] = {
        k: ckpt["state_dict"][k]
        for k in ckpt["state_dict"]
        if ".backbone" in k
    }
    self.load_state_dict(ckpt["state_dict"], strict=False)
```
🧰 Tools
🪛 Ruff
582-582: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
587-598: ⚠️ Potential issue: Remove debugging print statement and simplify key iteration
- The debug `print` statement should be removed.
- Simplify the dictionary key iteration by removing `.keys()` as per the static analysis hint.

Apply this diff to address both issues:
```diff
  }
- for k in ckpt["state_dict"].keys()
+ for k in ckpt["state_dict"]
  if ".head_layers" in k
  }
- print(f"from main code: {ckpt['state_dict'].keys()}")
```
Committable suggestion was skipped due to low confidence.
🧰 Tools
🪛 Ruff
593-593: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
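To illustrate the key-filtering and non-strict loading pattern in isolation, here is a generic PyTorch toy example (the class and attribute names are illustrative, not the sleap-nn model classes):

```python
import torch.nn as nn

class ToyModel(nn.Module):
    """Toy module whose parameter names mimic the backbone/head_layers split."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head_layers = nn.Linear(8, 2)

ckpt = {"state_dict": ToyModel().state_dict()}

# Keep only the head-layer weights, then load non-strictly so missing backbone keys are tolerated.
head_only = {k: v for k, v in ckpt["state_dict"].items() if "head_layers" in k}
missing, unexpected = ToyModel().load_state_dict(head_only, strict=False)
print(missing)     # backbone.weight and backbone.bias are reported as missing
print(unexpected)  # []
```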
Compare: 91d478c to c2a1c04, 2b7efde to 6c33c82, a473848 to e7486da, e7486da to 751559a
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (5)
sleap_nn/training/model_trainer.py (5)
497-498: Update docstrings to include new parameters. The `__init__` method of `TrainingModel` now includes the parameters `backbone_trained_ckpts_path` and `head_trained_ckpts_path`, but these are not documented in the method's docstring under the `Args` section. Please update the docstring to include these new parameters and their descriptions to maintain comprehensive documentation.
577-583: Simplify dictionary key iteration by removing `.keys()`. In line 581, `for k in ckpt["state_dict"].keys()` can be simplified to `for k in ckpt["state_dict"]` for more concise and idiomatic Python code. Apply this diff to simplify the code:
ckpt["state_dict"] = { k: ckpt["state_dict"][k] - for k in ckpt["state_dict"].keys() + for k in ckpt["state_dict"] if ".backbone" in k }🧰 Tools
🪛 Ruff (0.8.0)
582-582: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
592-595: Simplify dictionary key iteration by removing `.keys()`. Similarly, in line 593, `for k in ckpt["state_dict"].keys()` can be simplified to `for k in ckpt["state_dict"]`. Apply this diff to simplify the code:
ckpt["state_dict"] = { k: ckpt["state_dict"][k] - for k in ckpt["state_dict"].keys() + for k in ckpt["state_dict"] if ".head_layers" in k }🧰 Tools
🪛 Ruff (0.8.0)
593-593: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
596-596: Remove debug print statement or replace with logging. The line `print(f"from main code: {ckpt['state_dict'].keys()}")` appears to be debugging output. Consider removing this statement or replacing it with a logging statement at an appropriate logging level to avoid cluttering the console output in a production environment.
Apply this diff to remove the print statement:
- print(f"from main code: {ckpt['state_dict'].keys()}")
497-498: Ensure docstrings in subclass constructors include new parameters. The `__init__` methods of subclasses (`SingleInstanceModel`, `TopDownCenteredInstanceModel`, `CentroidModel`, `BottomUpModel`) now include `backbone_trained_ckpts_path` and `head_trained_ckpts_path` as parameters. Please update the docstrings in these methods to include the new parameters and their descriptions in the `Args` section for consistency and clarity.
Also applies to: 506-507, 713-714, 723-724, 804-805, 885-886, 966-967
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- sleap_nn/training/model_trainer.py (14 hunks)
- tests/fixtures/datasets.py (1 hunks)
- tests/training/test_model_trainer.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.0)
sleap_nn/training/model_trainer.py
582-582: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
593-593: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
🔇 Additional comments (2)
tests/training/test_model_trainer.py (1)
346-383: Verify the use of the same checkpoint for both backbone and head
In the `test_trainer_load_trained_ckpts` function, both `backbone_trained_ckpts_path` and `head_trained_ckpts_path` are set to the same checkpoint file:
```python
backbone_trained_ckpts_path=(Path(minimal_instance_ckpt) / "best.ckpt").as_posix(),
head_trained_ckpts_path=(Path(minimal_instance_ckpt) / "best.ckpt").as_posix(),
```
Please confirm that using the same checkpoint for both is intentional. If the backbone and head layers have different trained checkpoints, consider using separate checkpoint files to better simulate real-world scenarios.
tests/fixtures/datasets.py (1)
92-92: Verify the impact of disabling interpolation upsampling
The change from `up_interpolate: True` to `False` in the backbone configuration will affect how the UNet performs upsampling. This might impact the model's ability to handle pre-trained weights if they were trained with different upsampling methods.
Let's verify if this change is consistent with the model's expectations:
✅ Verification successful
Based on the search results, I can now generate the final response:
Setting up_interpolate to False is consistent with test configurations
The change to `up_interpolate: False` in the backbone configuration is consistent with the existing test configurations and implementation. The codebase shows:
- All minimal test configurations in `tests/assets/` use `up_interpolate: false`
- The encoder-decoder architecture (`sleap_nn/architectures/encoder_decoder.py`) supports both modes (illustrated in the sketch below):
  - When False: Uses transposed convolutions for upsampling
  - When True: Uses bilinear interpolation (faster but potentially less detailed)
- The documentation in `docs/config.md` confirms this is a valid configuration option, with the tradeoff being that transposed convolutions may learn richer upsampling patterns to recover details from higher scales.
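For intuition, a generic PyTorch sketch (not the sleap-nn `encoder_decoder` implementation; layer shapes are arbitrary) contrasting the two upsampling modes:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)

# up_interpolate: false -> learned upsampling via a transposed convolution
deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
print(deconv(x).shape)  # torch.Size([1, 32, 64, 64])

# up_interpolate: true -> fixed bilinear interpolation followed by a 1x1 convolution
upsample = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
    nn.Conv2d(64, 32, kernel_size=1),
)
print(upsample(x).shape)  # torch.Size([1, 32, 64, 64])
```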
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Search for any related configuration or validation of up_interpolate
rg -l "up_interpolate" | xargs rg -A 5 "up_interpolate"
```
Length of output: 27035
This PR adds an option to load trained weights for head layers right before training (instead of random initialization) for transfer learning. The `.ckpt` paths can be passed to `ModelTrainer.train(backbone_trained_ckpts_path=<your_path>, head_trained_ckpts_path=<your_path>)`.
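A minimal usage sketch based on the call described above (the config-loading step and checkpoint paths are illustrative assumptions):

```python
from omegaconf import OmegaConf
from sleap_nn.training.model_trainer import ModelTrainer

config = OmegaConf.load("training_config.yaml")  # assumed training configuration file
trainer = ModelTrainer(config)
trainer.train(
    backbone_trained_ckpts_path="previous_run/best.ckpt",  # placeholder path
    head_trained_ckpts_path="previous_run/best.ckpt",      # placeholder path
)
```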
Summary by CodeRabbit
New Features
Bug Fixes
Tests