Skip to content

Commit

Permalink
modified method to initialize hyperspheres centers
Browse files Browse the repository at this point in the history
  • Loading branch information
fvmassoli committed Mar 1, 2021
1 parent b5adb9f commit 8b83e10
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 34 deletions.
43 changes: 22 additions & 21 deletions main_mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ def test_models(test_loader: DataLoader, net_cehckpoint: str, tables: tuple, out


def main(args):
# Set seed
set_seeds(args.seed)

# Get the device
device = "cuda" if torch.cuda.is_available() else "cpu"

if args.disable_logging:
logging.disable(level=logging.INFO)

Expand All @@ -151,7 +157,6 @@ def main(args):
logging.FileHandler('./training.log'),
logging.StreamHandler()
])

logger = logging.getLogger()

if args.train or args.pretrain:
Expand Down Expand Up @@ -196,12 +201,6 @@ def main(args):

else:
args.normal_class = args.model_ckp.split('/')[-3]

# Set seed
set_seeds(args.seed)

# Get the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Init DataHolder class
data_holder = DataManager(
Expand Down Expand Up @@ -243,9 +242,10 @@ def main(args):

### PRETRAIN the full AutoEncoder
ae_net_cehckpoint = None
if args.pretrain:
out_dir, tmp = get_out_dir(args, pretrain=True, aelr=None, dset_name='mvtec')
tb_writer = SummaryWriter(os.path.join(args.output_path, 'mvtec', str(args.normal_class), 'tb_runs/pretrain', tmp))
if args.pretrain:

pretrain_out_dir, tmp = get_out_dir(args, pretrain=True, aelr=None, dset_name='mvtec')
pretrain_tb_writer = SummaryWriter(os.path.join(args.output_path, 'mvtec', str(args.normal_class), 'tb_runs/pretrain', tmp))

# Init AutoEncoder
ae_net = MVTecNet_AutoEncoder(input_shape=input_shape, code_length=args.code_length, use_selectors=args.use_selectors)
Expand All @@ -254,8 +254,8 @@ def main(args):
ae_net_cehckpoint = pretrain(
ae_net=ae_net,
train_loader=train_loader,
out_dir=out_dir,
tb_writer=tb_writer,
out_dir=pretrain_out_dir,
tb_writer=pretrain_tb_writer,
device=device,
ae_learning_rate=args.ae_learning_rate,
ae_weight_decay=args.ae_weight_decay,
Expand All @@ -266,7 +266,7 @@ def main(args):
debug=args.debug
)

tb_writer.close()
pretrain_tb_writer.close()

### TRAIN the Encoder
net_cehckpoint = None
Expand All @@ -280,9 +280,8 @@ def main(args):

aelr = float(ae_net_cehckpoint.split('/')[-2].split('-')[4].split('_')[-1])

out_dir, tmp = get_out_dir(args, pretrain=False, aelr=aelr, dset_name='mvtec')

tb_writer = SummaryWriter(os.path.join(args.output_path, 'mvtec', str(args.normal_class), 'tb_runs/train', tmp))
train_out_dir, tmp = get_out_dir(args, pretrain=False, aelr=aelr, dset_name='mvtec')
train_tb_writer = SummaryWriter(os.path.join(args.output_path, 'mvtec', str(args.normal_class), 'tb_runs/train', tmp))

# Init the Encoder network
encoder_net = load_mvtec_model_from_checkpoint(
Expand All @@ -295,15 +294,17 @@ def main(args):
)

## Eval/Load hyperspheres centers
centers = eval_spheres_centers(train_loader=train_loader, encoder_net=encoder_net, ae_net_cehckpoint=ae_net_cehckpoint, device=device, debug=args.debug)

encoder_net.set_idx_list_enc(range(8))
centers = eval_spheres_centers(train_loader=train_loader, encoder_net=encoder_net, ae_net_cehckpoint=ae_net_cehckpoint, use_selectors=args.use_selectors, device=device, debug=args.debug)
encoder_net.set_idx_list_enc(args.idx_list_enc)

# Start training
net_cehckpoint = train(
net=encoder_net,
train_loader=train_loader,
centers=centers,
out_dir=out_dir,
tb_writer=tb_writer,
out_dir=train_out_dir,
tb_writer=train_tb_writer,
device=device,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
Expand All @@ -317,7 +318,7 @@ def main(args):
debug=args.debug
)

tb_writer.close()
train_tb_writer.close()

### TEST the Encoder
if args.test:
Expand Down
22 changes: 18 additions & 4 deletions models/mvtec_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,20 @@ def get_depths_info(self) -> [int, int]:
"""
return self.last_depth, self.deepest_shape

def set_idx_list_enc(self, idx_list_enc: list) -> None:
    """Set the list of layer indices from which to extract features.

    Used when initializing the hypersphere centers: the centroids are
    first created for all layers (indices 0..7), and this setter then
    restores the user-chosen subset, independently of which layers are
    actually considered during training.

    Parameters
    ----------
    idx_list_enc : list
        Indices of the encoder layers whose outputs are returned by
        ``forward`` (an empty list means full-AutoEncoder pretraining).
    """
    # Plain attribute update; forward() reads self.idx_list_enc on each call.
    self.idx_list_enc = idx_list_enc

def forward(self, x: torch.Tensor) -> torch.Tensor:
o1 = self.conv(x)
o2 = self.res(self.activation_fn(o1))
Expand All @@ -159,10 +173,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
outputs = [o1, o2, o3, o4, o5, o7, o8, z]

if len(self.idx_list_enc) != 0:
# If we are pretraining the full AutoEncoder we don't need any of this and we set self.idx_list_enc = []

if self.use_selectors:
tuple_o = [self.selectors[idx](tt) for idx, tt in enumerate(outputs) if idx in self.idx_list_enc]

else:
# If we don't use selector, apply simple transformations to reduce the size of the feature maps
tuple_o = []

for idx, tt in enumerate(outputs):
Expand All @@ -174,10 +191,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
else:
tuple_o.append(tt.squeeze())

names = [f'0{idx}' for idx in self.idx_list_enc]
zipped = list(zip(names, tuple_o))

return zipped
return list(zip([f'0{idx}' for idx in self.idx_list_enc], tuple_o))

else: # It means that we are pretraining the full AutoEncoder
return z
Expand Down
8 changes: 4 additions & 4 deletions trainers/trainer_mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def pretrain(ae_net: nn.Module, train_loader: DataLoader, out_dir: str, tb_write
optimizer.zero_grad()

for idx, (data, _) in enumerate(tqdm(train_loader, total=len(train_loader), leave=False)):
if debug and idx == 10: break
if debug and idx == 5: break

data = data.to(device)

Expand Down Expand Up @@ -184,7 +184,7 @@ def train(net: torch.nn.Module, train_loader: DataLoader, centers: dict, out_dir
optimizer.zero_grad()

for idx, (data, _) in enumerate(tqdm(train_loader, total=len(train_loader), leave=False)):
if debug and idx == 10 : break
if debug and idx == 5: break

data = data.to(device)

Expand Down Expand Up @@ -301,7 +301,7 @@ def test(normal_class: str, is_texture: bool, net: nn.Module, test_loader: DataL

with torch.no_grad():
for idx, (data, labels) in enumerate(tqdm(test_loader, total=len(test_loader), desc=f"Testing class: {normal_class}", leave=False)):
if debug and idx == 3: break
if debug and idx == 5: break

data = data.to(device)

Expand Down Expand Up @@ -381,7 +381,7 @@ def eval_ad_loss(zipped: dict, c: dict, R: dict, nu: float, boundary: str) -> [d
dist = {}

loss = 1

for (k, v) in zipped:
dist[k] = torch.sum((v - c[k].unsqueeze(0)) ** 2, dim=1)

Expand Down
16 changes: 11 additions & 5 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def load_mvtec_model_from_checkpoint(input_shape: tuple, code_length: int, idx_l
return encoder_net


def eval_spheres_centers(train_loader: DataLoader, encoder_net: torch.nn.Module, ae_net_cehckpoint: str, device:str, debug: bool) -> dict:
def eval_spheres_centers(train_loader: DataLoader, encoder_net: torch.nn.Module, ae_net_cehckpoint: str, use_selectors: bool, device:str, debug: bool) -> dict:
"""Eval the centers of the hyperspheres at each chosen layer.
Parameters
Expand All @@ -158,6 +158,10 @@ def eval_spheres_centers(train_loader: DataLoader, encoder_net: torch.nn.Module,
Encoder network
ae_net_cehckpoint : str
Checkpoint of the full AutoEncoder
use_selectors : bool
True if we want to use selector models
device : str
Device on which to run the computations
debug : bool
Activate debug mode
Expand All @@ -169,11 +173,13 @@ def eval_spheres_centers(train_loader: DataLoader, encoder_net: torch.nn.Module,
"""
logger = logging.getLogger()

centers_files = ae_net_cehckpoint[:-4]+f'_w_centers_{use_selectors}.pth'

# If centers are found, then load and return
if os.path.exists(ae_net_cehckpoint[:-4]+'_w_centers.pth'):
if os.path.exists(centers_files):

logger.info("Found hyperspheres centers")
ae_net_ckp = torch.load(ae_net_cehckpoint[:-4]+'_w_centers.pth', map_location=lambda storage, loc: storage)
ae_net_ckp = torch.load(centers_files, map_location=lambda storage, loc: storage)

centers = {k: v.to(device) for k, v in ae_net_ckp['centers'].items()}
else:
Expand All @@ -182,7 +188,7 @@ def eval_spheres_centers(train_loader: DataLoader, encoder_net: torch.nn.Module,
centers_ = init_center_c(train_loader=train_loader, encoder_net=encoder_net, device=device, debug=debug)

logger.info("Hyperspheres centers evaluated!!!")
new_ckp = ae_net_cehckpoint.split('.pth')[0]+'_w_centers.pth'
new_ckp = ae_net_cehckpoint.split('.pth')[0]+f'_w_centers_{use_selectors}.pth'

logger.info(f"New AE dict saved at: {new_ckp}!!!")
centers = {k: v for k, v in centers_.items()}
Expand Down Expand Up @@ -217,7 +223,7 @@ def init_center_c(train_loader: DataLoader, encoder_net: torch.nn.Module, device
encoder_net.eval().to(device)

for idx, (data, _) in enumerate(tqdm(train_loader, desc='Init hyperspheres centeres', total=len(train_loader), leave=False)):
if debug and idx == 10: break
if debug and idx == 5: break

data = data.to(device)
n_samples += data.shape[0]
Expand Down

0 comments on commit 8b83e10

Please sign in to comment.