bug fix mvtec
fvmassoli committed Feb 26, 2021
1 parent ee1b68c commit c1ec0a7
Showing 9 changed files with 70 additions and 50 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -2,4 +2,5 @@ bin/*
 *.sh
 output
 *.pyc
-*.log
+*.log
+*__pycache__*
10 changes: 8 additions & 2 deletions README.md
@@ -20,10 +20,16 @@ It reports a new technique to detect anomalies...
 The current version of the code requires python 3.6 and pytorch ...
 
 
-Minimal usage:
+Minimal usage (CIFAR10):
 
 ```
-python -W ignore ...
+python3 main_cifar10.py -ptr -tr -tt -zl 128 -nc <normal class> -dp <path to CIFAR10 dataset>
 ```
+
+Minimal usage (MVTec):
+
+```
+python3 main_mvtec.py -ptr -tr -tt -zl 128 -nc <normal class> -dp <path to MVTec dataset> --use-selectors
+```
 
 
3 changes: 2 additions & 1 deletion datasets/data_manager.py
@@ -80,9 +80,10 @@ def get_data_holder(self):
         rotation_range = (-45, 45) if self.normal_class in object_classes else (0, 0)
 
         return MVTec_DataHolder(
+            data_path=self.data_path,
             category=self.normal_class,
             image_size=image_size,
             patch_size=patch_size,
             rotation_range=rotation_range,
-            texture=is_texture
+            is_texture=is_texture
         )
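
Why the keyword rename above is a fix: `MVTec_DataHolder.__init__` (see datasets/mvtec.py below) declares `is_texture`, so the old `texture=` keyword raised a `TypeError` as soon as `get_data_holder()` ran. A minimal sketch of the failure mode, using a hypothetical stub that only mirrors the constructor signature shown in this diff:

```
# Hypothetical stub mirroring the MVTec_DataHolder signature from datasets/mvtec.py.
class MVTec_DataHolder:
    def __init__(self, data_path, category, image_size, patch_size,
                 rotation_range, is_texture):
        self.is_texture = is_texture

kwargs = dict(data_path='.', category='cable', image_size=256,
              patch_size=64, rotation_range=(0, 0))

try:
    MVTec_DataHolder(texture=False, **kwargs)      # old call site
except TypeError as e:
    print(e)  # __init__() got an unexpected keyword argument 'texture'

holder = MVTec_DataHolder(is_texture=False, **kwargs)  # fixed call site
```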
12 changes: 6 additions & 6 deletions datasets/mvtec.py
@@ -90,7 +90,7 @@ def __init__(self, data_path: str, category: str, image_size: int, patch_size: i
         """
         self.data_path = data_path
-        self.category = normal_class
+        self.category = category
         self.image_size = image_size
         self.patch_size = patch_size
         self.rotation_range = rotation_range
@@ -106,9 +106,9 @@ def get_test_data(self) -> Dataset:
         """
         return MVtecDataset(
-            root=join(self.data_path, f'MVTec_Anomaly/{category}/test'),
+            root=join(self.data_path, f'{self.category}/test'),
             transform=T.Compose([
-                T.Resize(image_size, interpolation=Image.BILINEAR),
+                T.Resize(self.image_size, interpolation=Image.BILINEAR),
                 T.ToTensor(),
                 T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
             ])
@@ -123,10 +123,10 @@ def get_train_data(self, return_dataset: bool=True):
             False for preprocessing purpose only
         """
-        train_data_dir = join(self.data_path, f'MVTec_Anomaly/{category}/train/')
+        train_data_dir = join(self.data_path, f'{self.category}/train/')
 
         # Preprocessed output data path
-        cache_main_dir = join(self.data_path, f'MVTec_Anomaly/processed/{category}')
+        cache_main_dir = join(self.data_path, f'processed/{self.category}')
         os.makedirs(cache_main_dir, exist_ok=True)
         cache_file = f'{cache_main_dir}/{self.category}_train_dataset_i-{self.image_size}_p-{self.patch_size}_r-{self.rotation_range[0]}--{self.rotation_range[1]}.npy'
 
@@ -166,7 +166,7 @@ def augmentation():
         nb_epochs = 50000 // len(train_dataset.imgs)
         data_loader = DataLoader(dataset=train_dataset, batch_size=1024, pin_memory=True)
 
-        for epoch in tqdm(range(nb_epochs), total=nb_epochs, desc=f"Creating cache for: {category}"):
+        for epoch in tqdm(range(nb_epochs), total=nb_epochs, desc=f"Creating cache for: {self.category}"):
             if epoch == 0:
                 cache_np = [x.numpy() for x, _ in tqdm(data_loader, total=len(data_loader), desc=f'Caching epoch: {epoch+1}/{nb_epochs+1}', leave=False)]
             else:
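
All four fixes in this file are the same scoping bug: inside a method, a bare `{category}` in an f-string is resolved by normal name lookup (local, then global), not against the instance, so each old line raised `NameError: name 'category' is not defined` at runtime. A minimal sketch, assuming nothing beyond plain Python:

```
class Holder:
    def __init__(self, category):
        self.category = category

    def test_dir_broken(self):
        # Raises NameError when called: f-strings use ordinary scoping,
        # so 'category' must be a local/global name, not an attribute.
        return f'{category}/test'

    def test_dir_fixed(self):
        return f'{self.category}/test'

print(Holder('cable').test_dir_fixed())  # cable/test
```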
36 changes: 19 additions & 17 deletions main_mvtec.py
@@ -16,8 +16,8 @@
 
 from datasets.data_manager import DataManager
 from models.mvtec_model import MVTecNet_AutoEncoder
-from trainers.train_mvtec import pretrain, train, test
-from utils import set_seeds, get_out_dir, purge_ae_params, eval_spheres_centers, load_mvtec_model_from_checkpoint
+from trainers.trainer_mvtec import pretrain, train, test
+from utils import set_seeds, get_out_dir, eval_spheres_centers, load_mvtec_model_from_checkpoint
 
 
 def test_models(test_loader: DataLoader, net_cehckpoint: str, tables: tuple, out_df: pd.DataFrame, is_texture: bool, input_shape: tuple, idx_list_enc: list, boundary: str, normal_class: str, use_selectors: bool, device: str):
@@ -156,7 +156,6 @@ def main(args):
             f"\n\t\t\t\tTrain model : {args.train}"
             f"\n\t\t\t\tTest model : {args.test}"
             f"\n\t\t\t\tBoundary : {args.boundary}"
-            f"\n\t\t\t\tOptimizer : {args.optimizer}"
             f"\n\t\t\t\tPretrain epochs : {args.ae_epochs}"
             f"\n\t\t\t\tAE-Learning rate : {args.ae_learning_rate}"
             f"\n\t\t\t\tAE-milestones : {args.ae_lr_milestones}"
@@ -198,15 +197,15 @@
 
     # Init DataHolder class
     data_holder = DataManager(
-        dataset_name=args.dataset_name,
+        dataset_name='MVTec_Anomaly',
         data_path=args.data_path,
         normal_class=args.normal_class,
         only_test=args.test
     ).get_data_holder()
 
     # Load data
     train_loader, test_loader = data_holder.get_loaders(
-        batch_size=args.batch_szie,
+        batch_size=args.batch_size,
         shuffle_train=True,
         pin_memory=device=="cuda",
         num_workers=args.n_workers
@@ -216,7 +215,7 @@
     only_test = args.test and not args.train and not args.pretrain
     logger.info("Dataset info:")
     logger.info(
-        f"Dataset : {args.dataset_name}"
+        "\n"
         f"\n\t\t\t\tNormal class : {args.normal_class}"
         f"\n\t\t\t\tBatch size : {args.batch_size}"
     )
@@ -237,8 +236,8 @@
     ### PRETRAIN the full AutoEncoder
     ae_net_cehckpoint = None
     if args.pretrain:
-        out_dir, tmp = get_out_dir(args, pretrain=True, aelr=None, net_name='mvtec')
-        tb_writer = SummaryWriter(os.path.join(args.output_path, args.dataset_name, str(args.normal_class), 'svdd/tb_runs_pretrain', tmp))
+        out_dir, tmp = get_out_dir(args, pretrain=True, aelr=None, dset_name='mvtec')
+        tb_writer = SummaryWriter(os.path.join(args.output_path, 'mvtec', str(args.normal_class), 'tb_runs/pretrain', tmp))
 
         # Init AutoEncoder
         ae_net = MVTecNet_AutoEncoder(input_shape=input_shape, code_length=args.code_length, use_selectors=args.use_selectors)
@@ -248,7 +247,7 @@
             ae_net=ae_net,
             train_loader=train_loader,
             out_dir=out_dir,
-            tb_writer=tb_writer
+            tb_writer=tb_writer,
             device=device,
             ae_learning_rate=args.ae_learning_rate,
             ae_weight_decay=args.ae_weight_decay,
@@ -267,11 +266,14 @@
         if args.model_ckp is None:
             logger.info("CANNOT TRAIN MODEL WITHOUT A VALID CHECKPOINT")
             sys.exit(0)
+
         ae_net_cehckpoint = args.model_ckp
+
         aelr = float(ae_net_cehckpoint.split('/')[-2].split('-')[4].split('_')[-1])
-        out_dir, tmp = get_out_dir(args, pretrain=False, aelr=aelr, net_name='mvtec')
-        tb_writer = SummaryWriter(os.path.join(args.output_path, args.dataset_name, str(args.normal_class), 'tb_runs_train', tmp))
+        out_dir, tmp = get_out_dir(args, pretrain=False, aelr=aelr, dset_name='mvtec')
+
+        tb_writer = SummaryWriter(os.path.join(args.output_path, 'mvtec', str(args.normal_class), 'tb_runs/train', tmp))
 
         # Init the Encoder network
         encoder_net = load_mvtec_model_from_checkpoint(
@@ -284,7 +286,7 @@
         )
 
         ## Eval/Load hyperspeheres centers
-        centers = eval_spheres_centers(train_loader=train_loader, encoder_net=encoder_net, ae_net_cehckpoint=ae_net_cehckpoint, debug=args.debug)
+        centers = eval_spheres_centers(train_loader=train_loader, encoder_net=encoder_net, ae_net_cehckpoint=ae_net_cehckpoint, device=device, debug=args.debug)
 
         # If we do not select any layer, then use only the last one
         # Remove all hyperspheres' center but the last one
@@ -294,7 +296,7 @@
 
         # Start training
         net_cehckpoint = train(
-            net=net,
+            net=encoder_net,
             train_loader=train_loader,
             centers=centers_,
             out_dir=out_dir,
@@ -377,7 +379,7 @@
     ## General config
     parser.add_argument('-s', '--seed', type=int, default=-1, help='Random seed (default: -1)')
     parser.add_argument('--n_workers', type=int, default=8, help='Number of workers for data loading. 0 means that the data will be loaded in the main process. (default: 8)')
-    parser.add_argument('--output_path', default='./output/mvtec_ad')
+    parser.add_argument('--output_path', default='./output')
     parser.add_argument('-lf', '--log-frequency', type=int, default=5, help='Log frequency (default: 5)')
     parser.add_argument('-dl', '--disable-logging', action="store_true", help='Disabel logging (default: False)')
@@ -390,16 +392,16 @@
     parser.add_argument('-wd', '--weight-decay', type=float, default=0.5e-6, help='Learning rate (default: 0.5e-6)')
     parser.add_argument('-aml', '--ae-lr-milestones', type=int, nargs='+', default=[], help='Pretrain milestone')
     parser.add_argument('-ml', '--lr-milestones', type=int, nargs='+', default=[], help='Training milestone')
+    ## Data
+    parser.add_argument('-dp', '--data-path', default='./MVTec_Anomaly', help='Dataset main path')
+    parser.add_argument('-nc', '--normal-class', choices=('bottle', 'capsule', 'grid', 'leather', 'metal_nut', 'screw', 'toothbrush', 'wood', 'cable', 'carpet', 'hazelnut', 'pill', 'tile', 'transistor', 'zipper'), default='cable', help='Category (default: cable)')
     ## Training config
     parser.add_argument('-we', '--warm_up_n_epochs', type=int, default=5, help='Warm up epochs (default: 5)')
     parser.add_argument('--use-selectors', action="store_true", help='Use features selector (default: False)')
     parser.add_argument('-ba', '--batch-accumulation', type=int, default=-1, help='Batch accumulation (default: -1, i.e., None)')
     parser.add_argument('-ptr', '--pretrain', action="store_true", help='Pretrain model (default: False)')
     parser.add_argument('-tr', '--train', action="store_true", help='Train model (default: False)')
     parser.add_argument('-tt', '--test', action="store_true", help='Test model (default: False)')
-    parser.add_argument('-dn', '--dataset-name', default='MVTec_Anomaly')
-    parser.add_argument('-ul', '--unlabelled-data', action="store_true", help='Use unlabelled data (default: False)')
-    parser.add_argument('-nc', '--normal-class', choices=('bottle', 'capsule', 'grid', 'leather', 'metal_nut', 'screw', 'toothbrush', 'wood', 'cable', 'carpet', 'hazelnut', 'pill', 'tile', 'transistor', 'zipper'), default='cable', help='Category (default: cable)')
     parser.add_argument('-tbc', '--train-best-conf', action="store_true", help='Train best configurations (default: False)')
     parser.add_argument('-db', '--debug', action="store_true", help='Debug (default: False)')
     parser.add_argument('-bs', '--batch-size', type=int, default=128, help='Batch size (default: 128)')
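
A note on the `batch_szie` fix above: argparse converts option dashes to underscores, so `--batch-size` is stored as `args.batch_size`, and the misspelled `args.batch_szie` raised `AttributeError` when the loaders were built. A small self-contained sketch of that behaviour:

```
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-bs', '--batch-size', type=int, default=128)
args = parser.parse_args([])

print(args.batch_size)               # 128 -- dashes become underscores
print(hasattr(args, 'batch_szie'))   # False: the old code raised AttributeError here
```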
20 changes: 10 additions & 10 deletions models/mvtec_model.py
@@ -24,7 +24,7 @@ def init_conv_blocks(channel_in: int, channel_out: int, activation_fn: nn) -> nn
         Output features size
     """
-    return DownsampleBlock(channel_in=channel_in, channel_out=channel_out, activation_fn=self.activation_fn)
+    return DownsampleBlock(channel_in=channel_in, channel_out=channel_out, activation_fn=activation_fn)
 
 
 class Selector(nn.Module):
@@ -77,7 +77,7 @@ def forward(self, *input: torch.Tensor) -> torch.Tensor:
         return self.fc(input)
 
 
-class MVtec_Encoder(BaseModule):
+class MVTec_Encoder(BaseModule):
     """MVtec Encoder network
     """
@@ -140,8 +140,8 @@ def get_depths_info(self) -> [int, int]:
"""
return self.last_depth, self.deepest_shape

def forward(self, *input: torch.Tensor) -> torch.Tensor:
o1 = self.conv(input)
def forward(self, x: torch.Tensor) -> torch.Tensor:
o1 = self.conv(x)
o2 = self.res(self.activation_fn(o1))
o3 = self.dwn1(o2)
o4 = self.dwn2(o3)
Expand Down Expand Up @@ -220,8 +220,8 @@ def __init__(self, code_length: int, deepest_shape: int, last_depth: int, output
             nn.Conv2d(in_channels=CHANNELS[0], out_channels=3, kernel_size=1, bias=False)
         )
 
-    def forward(self, *input: torch.Tensor) -> torch.Tensor:
-        h = self.fc(input)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        h = self.fc(x)
         h = h.view(len(h), *self.deepest_shape)
         return self.conv(h)
 
@@ -249,7 +249,7 @@ def __init__(self, input_shape: int, code_length: int, use_selectors: bool):
         self.input_shape = input_shape
 
         # Build Encoder
-        self.encoder = MVtecEncoder(
+        self.encoder = MVTec_Encoder(
             input_shape=input_shape,
             code_length=code_length,
             idx_list_enc=[],
@@ -259,15 +259,15 @@
         last_depth, deepest_shape = self.encoder.get_depths_info()
 
         # Build Decoder
-        self.decoder = MVtecDecoder(
+        self.decoder = MVTec_Decoder(
             code_length=code_length,
             deepest_shape=deepest_shape,
             last_depth=last_depth,
             output_shape=input_shape
         )
 
-    def forward(self, *input: torch.Tensor) -> torch.Tensor:
-        z = self.encoder(input)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        z = self.encoder(x)
         x_r = self.decoder(z)
         x_r = x_r.view(-1, *self.input_shape)
         return x_r
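
The `forward` changes in this file all fix one pattern: `def forward(self, *input)` packs the argument into a tuple, so `self.conv(input)` hands a tuple to the first layer and PyTorch raises a `TypeError`. A minimal sketch contrasting the two signatures (standalone modules, not the repo's classes):

```
import torch
import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, *input):     # 'input' is a tuple of tensors here
        return self.conv(input)    # TypeError at call time: conv2d gets a tuple, not a Tensor

class Fixed(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

x = torch.randn(1, 3, 32, 32)
print(Fixed()(x).shape)  # torch.Size([1, 8, 32, 32])
```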
Empty file added trainers/__init__.py
11 changes: 9 additions & 2 deletions trainers/trainer_mvtec.py
@@ -12,6 +12,7 @@
 from torch.utils.data import DataLoader
 from torch.optim.lr_scheduler import MultiStepLR
 
+from tensorboardX import SummaryWriter
 from sklearn.metrics import roc_curve, roc_auc_score, auc
 
 
@@ -68,6 +69,8 @@ def pretrain(ae_net: nn.Module, train_loader: DataLoader, out_dir: str, tb_write
     n_batches = 0
     optimizer.zero_grad()
     for idx, (data, _) in enumerate(tqdm(train_loader, total=len(train_loader), leave=False)):
+        if idx == 3 : break
+
         if isinstance(data, list): data = data[0]
 
         data = data.to(device)
@@ -112,7 +115,7 @@
     return ae_net_cehckpoint
 
 
-def train(net: torch.nn.Module, train_loader: DataLoader, centers: dict, out_dir: str, tb_writer: SummaryWriter, device: str, learning_rate: float, weight_decay: float, lr_milestones: list, epochs: int, nu: float, boundary: str) -> :
+def train(net: torch.nn.Module, train_loader: DataLoader, centers: dict, out_dir: str, tb_writer: SummaryWriter, device: str, learning_rate: float, weight_decay: float, lr_milestones: list, epochs: int, nu: float, boundary: str) -> str:
     """Train the Encoder network on the one class task.
     Parameters
@@ -170,6 +173,8 @@ def train(net: torch.nn.Module, train_loader: DataLoader, centers: dict, out_dir
     optimizer.zero_grad()
 
     for idx, (data, _) in enumerate(tqdm(train_loader, total=len(train_loader), leave=False)):
+        if idx == 3 : break
+
         data = data.to(device)
 
         zipped = net(data)
@@ -280,7 +285,9 @@ def test(category: str, is_texture: bool, net: nn.Module, test_loader: DataLoade
     logger.info('Start testing...')
 
     idx_label_score = []
+
     net.eval().to(device)
+
     with torch.no_grad():
         for idx, (data, labels) in enumerate(tqdm(test_loader, total=len(test_loader), desc=f"Testing class: {category}", leave=False)):
             data = data.to(device)
@@ -349,7 +356,7 @@ def eval_ad_loss(zipped: dict, c: dict, R: dict, nu: float, boundary: str) -> [d
     loss : torch.Tensor
         Trainign loss
-    """"
+    """
     dist = {}
 
     loss = 1
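
Why the `SummaryWriter` import was needed: the `tb_writer: SummaryWriter` annotations in this module are evaluated when each `def` statement executes, so importing the module without the name in scope would fail with `NameError` before any function was ever called. A minimal sketch of that failure mode, with a stand-in class instead of tensorboardX:

```
try:
    def broken(tb_writer: SummaryWriter) -> str:  # annotation evaluated at def time
        return ''
except NameError as e:
    print(e)  # name 'SummaryWriter' is not defined

class SummaryWriter:  # stand-in for tensorboardX.SummaryWriter
    pass

def train(tb_writer: SummaryWriter) -> str:  # defines fine once the name exists
    return ''
```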