From 86af2a39165861104e829f5c7efc1f188686622e Mon Sep 17 00:00:00 2001 From: AlbertDominguez Date: Fri, 22 Nov 2024 18:03:16 +0100 Subject: [PATCH] Fix explicit augmentations passing in the fit method --- spotiflow/model/spotiflow.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/spotiflow/model/spotiflow.py b/spotiflow/model/spotiflow.py index 502925d..8fd4500 100644 --- a/spotiflow/model/spotiflow.py +++ b/spotiflow/model/spotiflow.py @@ -415,11 +415,9 @@ def fit( transforms.Crop if not self.config.is_3d else transforms3d.Crop3D ) assert any( - isinstance(p, _crop_cls) for p in augment_train.transforms + isinstance(p, _crop_cls) for p in augment_train.augmentations ), "Custom augmenter must contain a cropping transform!" - tr_augmenter = self.build_image_augmenter( - crop_size, point_priority=point_priority - ) + tr_augmenter = augment_train elif augment_train: tr_augmenter = self.build_image_augmenter( crop_size, point_priority=point_priority