Fix efficient ad (#2015)
* Remove batch_size from config and hardcode it in the model (it is the batch size for the ImageNet dataloader, which should not be changed; see the sketch below this list).

* Better description

* Removed ImageNet normalization, added a check for train_batch_size

* Fix train_batch_size for testing EfficientAd

* Updated usage example
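
A minimal sketch of why the value is hardcoded (dummy data standing in for Imagenette; not the repository's code): the EfficientAD recipe draws one ImageNet image per training step for the teacher-student penalty term, so the loader's batch size is fixed rather than user-configurable.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy tensors standing in for the Imagenette penalty images.
imagenette = TensorDataset(torch.randn(8, 3, 256, 256))

# One ImageNet image per training step, as in the EfficientAD paper,
# so batch_size is hardcoded to 1 instead of being a config option.
imagenet_loader = DataLoader(imagenette, batch_size=1, shuffle=True)

(batch,) = next(iter(imagenet_loader))  # shape: (1, 3, 256, 256)
```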
abc-125 authored Apr 28, 2024
1 parent 81cd4c5 commit aee41f2
Showing 4 changed files with 17 additions and 12 deletions.
configs/model/efficient_ad.yaml (3 changes: 1 addition & 2 deletions)

@@ -7,12 +7,11 @@ model:
     weight_decay: 1.0e-05
     padding: false
     pad_maps: true
-    batch_size: 1
 
 metrics:
   pixel:
     - AUROC
 
 trainer:
-  max_epochs: 200
+  max_epochs: 1000
   max_steps: 70000
src/anomalib/models/image/efficient_ad/README.md (2 changes: 1 addition & 1 deletion)

@@ -18,7 +18,7 @@ Anomalies are detected as the difference in output feature maps between the teac
 
 ## Usage
 
-`python tools/train.py --model efficient_ad`
+`anomalib train --model EfficientAd --data anomalib.data.MVTec --data.train_batch_size 1`
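
The same run through the Python API might look like this (a sketch assuming anomalib's v1 `Engine` interface; the category and trainer limits are illustrative):

```python
from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import EfficientAd

# EfficientAd requires a train batch size of 1 (checked in on_train_start).
datamodule = MVTec(category="bottle", train_batch_size=1)
model = EfficientAd()

# Trainer limits mirror configs/model/efficient_ad.yaml.
engine = Engine(max_epochs=1000, max_steps=70000)
engine.fit(model=model, datamodule=datamodule)
```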

 ## Benchmark
 
src/anomalib/models/image/efficient_ad/lightning_model.py (21 changes: 13 additions & 8 deletions)

@@ -58,8 +58,6 @@ class EfficientAd(AnomalyModule):
         pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the
             output anomaly maps so that their size matches the size in the padding = True case.
             Defaults to ``True``.
-        batch_size (int): batch size for imagenet dataloader
-            Defaults to ``1``.
     """
 
     def __init__(
@@ -71,7 +69,6 @@ def __init__(
         weight_decay: float = 0.00001,
         padding: bool = False,
         pad_maps: bool = True,
-        batch_size: int = 1,
     ) -> None:
         super().__init__()
 
@@ -83,7 +80,7 @@ def __init__(
             padding=padding,
             pad_maps=pad_maps,
         )
-        self.batch_size = batch_size
+        self.batch_size = 1  # imagenet dataloader batch_size is 1 according to the paper
         self.lr = lr
         self.weight_decay = weight_decay
 
@@ -237,9 +234,18 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
     def on_train_start(self) -> None:
         """Called before the first training epoch.
 
-        First sets up the pretrained teacher model, then prepares the imagenette data, and finally calculates or
-        loads the channel-wise mean and std of the training dataset and push to the model.
+        First checks that the EfficientAd-specific parameters are set correctly (train_batch_size of 1
+        and no ImageNet normalization in the transforms), then sets up the pretrained teacher model,
+        prepares the imagenette data, and finally calculates or loads
+        the channel-wise mean and std of the training dataset and pushes them to the model.
         """
+        if self.trainer.datamodule.train_batch_size != 1:
+            msg = "train_batch_size for EfficientAd should be 1."
+            raise ValueError(msg)
+        if self._transform and any(isinstance(transform, Normalize) for transform in self._transform.transforms):
+            msg = "Transforms for EfficientAd should not contain Normalize."
+            raise ValueError(msg)
+
         sample = next(iter(self.trainer.train_dataloader))
         image_size = sample["image"].shape[-2:]
         self.prepare_pretrained_model()
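
For reference, a user-supplied transform that passes the Normalize check above could be as simple as this sketch (EfficientAd applies ImageNet normalization itself in the forward pass):

```python
from torchvision.transforms.v2 import Compose, Resize

# Resizing only; no Normalize step, since the model normalizes internally.
transform = Compose([Resize((256, 256), antialias=True)])
```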
@@ -314,11 +320,10 @@ def learning_type(self) -> LearningType:
         return LearningType.ONE_CLASS
 
     def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform:
-        """Default transform for Padim."""
+        """Default transform for EfficientAd. Imagenet normalization applied in forward."""
         image_size = image_size or (256, 256)
         return Compose(
             [
                 Resize(image_size, antialias=True),
-                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
             ],
         )
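
With `Normalize` removed from the default transform, the ImageNet statistics have to be applied inside the model instead. A minimal sketch of that pattern (an illustration of the idea, not the repository's exact forward):

```python
import torch
from torchvision.transforms.v2 import Normalize

imagenet_norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

def forward(batch: torch.Tensor) -> torch.Tensor:
    # Normalize inside the model so dataloader transforms stay normalization-free.
    batch = imagenet_norm(batch)
    # ... teacher/student feature extraction would follow here ...
    return batch
```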
tests/integration/model/test_models.py (3 changes: 2 additions & 1 deletion)

@@ -204,7 +204,8 @@ def _get_objects(
             root=dataset_path / "mvtec",
             category="dummy",
             task=task_type,
-            train_batch_size=2,
+            # EfficientAd requires train batch size 1
+            train_batch_size=1 if model_name == "efficient_ad" else 2,
         )
 
         model = get_model(model_name, **extra_args)
