Skip to content

Commit

Permalink
Remove eval
Browse files Browse the repository at this point in the history
  • Loading branch information
gitttt-1234 committed Jan 6, 2025
1 parent 3080a9a commit 0487374
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions sleap_nn/training/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,19 @@
)
from sleap_nn.training.utils import check_memory, xavier_init_weights

MODEL_WEIGHTS = {
"Swin_T_Weights": Swin_T_Weights,
"Swin_S_Weights": Swin_S_Weights,
"Swin_B_Weights": Swin_B_Weights,
"Swin_V2_T_Weights": Swin_V2_T_Weights,
"Swin_V2_S_Weights": Swin_V2_S_Weights,
"Swin_V2_B_Weights": Swin_V2_B_Weights,
"ConvNeXt_Base_Weights": ConvNeXt_Base_Weights,
"ConvNeXt_Tiny_Weights": ConvNeXt_Tiny_Weights,
"ConvNeXt_Small_Weights": ConvNeXt_Small_Weights,
"ConvNeXt_Large_Weights": ConvNeXt_Large_Weights,
}


class ModelTrainer:
"""Train sleap-nn model using PyTorch Lightning.
Expand Down Expand Up @@ -788,9 +801,9 @@ def __init__(
self.head_trained_ckpts_path = head_trained_ckpts_path
self.input_expand_channels = self.model_config.backbone_config.in_channels
if self.model_config.pre_trained_weights: # only for swint and convnext
ckpt = eval(self.model_config.pre_trained_weights).DEFAULT.get_state_dict(
progress=True, check_hash=True
)
ckpt = MODEL_WEIGHTS[
self.model_config.pre_trained_weights
].DEFAULT.get_state_dict(progress=True, check_hash=True)
input_channels = ckpt["features.0.0.weight"].shape[-3]
if self.model_config.backbone_config.in_channels != input_channels:
self.input_expand_channels = input_channels
Expand Down

0 comments on commit 0487374

Please sign in to comment.