Skip to content

Commit

Permalink
Merge branch 'master' into region-catalog-script
Browse files Browse the repository at this point in the history
  • Loading branch information
aakashdp6548 authored Sep 6, 2023
2 parents 6e801a7 + 2590a63 commit 4cb710d
Show file tree
Hide file tree
Showing 4 changed files with 307 additions and 73 deletions.
1 change: 1 addition & 0 deletions bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ encoder:
log_transform: false
rolling_z_score: true
do_data_augmentation: false
compile_model: false # if true, compile model for potential performance
architecture:
# this architecture is based on yolov5l.yaml, see
# https://github.com/ultralytics/yolov5/blob/master/models/yolov5l.yaml
Expand Down
4 changes: 4 additions & 0 deletions bliss/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
scheduler_params: Optional[dict] = None,
input_transform_params: Optional[dict] = None,
do_data_augmentation: bool = False,
compile_model: bool = False,
):
"""Initializes DetectionEncoder.
Expand All @@ -66,6 +67,7 @@ def __init__(
input_transform_params: used for determining what channels to use as input (e.g.
deconvolution, concatenate PSF parameters, z-score inputs, etc.)
do_data_augmentation: used for determining whether or not do data augmentation
compile_model: compile model for potential performance improvements
"""
super().__init__()
self.save_hyperparameters()
Expand Down Expand Up @@ -104,6 +106,8 @@ def __init__(
num_channels = len(self.bands)
num_features_per_band = self._get_num_features()
self.model = Backbone(cfg=arch_dict, ch=num_channels, n_imgs=num_features_per_band)
if compile_model:
self.model = torch.compile(self.model)
self.tiles_to_crop = tiles_to_crop

# metrics
Expand Down
Loading

0 comments on commit 4cb710d

Please sign in to comment.