Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

End-to-end training of ViT #5

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
8ce2265
supervised training
piotrmwojcik Sep 11, 2023
f22ea94
supervised training
piotrmwojcik Sep 11, 2023
86d0377
supervised training
piotrmwojcik Sep 11, 2023
990828b
supervised training
piotrmwojcik Sep 11, 2023
d4e5b73
supervised training
piotrmwojcik Sep 11, 2023
598c028
supervised training
piotrmwojcik Sep 11, 2023
7f03483
random M and K
piotrmwojcik Sep 12, 2023
bbf6f57
random M and K
piotrmwojcik Sep 12, 2023
6b33b95
random M and K
piotrmwojcik Sep 12, 2023
e1e611f
random M and K
piotrmwojcik Sep 12, 2023
ab683f8
random M and K
piotrmwojcik Sep 12, 2023
33f709d
random M and K
piotrmwojcik Sep 12, 2023
0da1203
random M and K
piotrmwojcik Sep 12, 2023
8ba74bc
random M and K
piotrmwojcik Sep 12, 2023
73b7f89
random M and K
piotrmwojcik Sep 12, 2023
7971b75
random M and K
piotrmwojcik Sep 12, 2023
481a5fd
random M and K
piotrmwojcik Sep 12, 2023
062ea4b
random M and K
piotrmwojcik Sep 12, 2023
7727fa8
random M and K
piotrmwojcik Sep 12, 2023
777ac3c
random M and K
piotrmwojcik Sep 12, 2023
1ae4d68
random M and K
piotrmwojcik Sep 12, 2023
ccd51fb
random M and K
piotrmwojcik Sep 12, 2023
d952107
random M and K
piotrmwojcik Sep 12, 2023
1c7f3f1
random M and K
piotrmwojcik Sep 12, 2023
34c5451
random M and K
piotrmwojcik Sep 12, 2023
8cf03a5
random M and K
piotrmwojcik Sep 12, 2023
c1cf255
random M and K
piotrmwojcik Sep 12, 2023
0973951
random M and K
piotrmwojcik Sep 12, 2023
5c5e214
random M and K
piotrmwojcik Sep 12, 2023
e4981f4
random M and K
piotrmwojcik Sep 12, 2023
a8184fa
random M and K
piotrmwojcik Sep 12, 2023
a4a359b
random M and K
piotrmwojcik Sep 12, 2023
953f3b2
random M and K
piotrmwojcik Sep 12, 2023
1dd70b1
random M and K
piotrmwojcik Sep 12, 2023
0df9c05
random M and K
piotrmwojcik Sep 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 286 additions & 0 deletions compvits/main_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import json
import random
import sys
from pathlib import Path

import torch
from torch import nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torchvision import datasets
from torchvision import transforms as pth_transforms
from torchvision import models as torchvision_models

import utils
import vision_transformer as vits


def _build_backbone(args):
    """Instantiate the feature extractor selected by ``args.arch``.

    The lookup order is: a constructor exposed by the local
    ``vision_transformer`` module, then an XCiT model fetched through
    ``torch.hub``, then a torchvision model (whose ``fc`` head is replaced by
    an identity so that it outputs features).

    Args:
        args: parsed command-line namespace; reads ``arch``, ``patch_size``,
            ``n_last_blocks`` and ``avgpool_patchtokens``.

    Returns:
        A ``(model, embed_dim)`` tuple where ``embed_dim`` is the input width
        used for the linear classifier.

    Exits the process with status 1 when the architecture name is not found.
    """
    if args.arch in vits.__dict__:
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
        # One [CLS] token per selected block, plus one extra slot when the
        # average-pooled patch tokens are concatenated as well.
        embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
    elif "xcit" in args.arch:
        model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
        embed_dim = model.embed_dim
    elif args.arch in torchvision_models.__dict__:
        model = torchvision_models.__dict__[args.arch]()
        embed_dim = model.fc.weight.shape[1]
        model.fc = nn.Identity()
    else:
        print(f"Unknown architecture: {args.arch}")
        sys.exit(1)
    return model, embed_dim


def eval_linear(args):
    """Train and/or evaluate a linear classifier on top of a backbone.

    Steps: initialise distributed mode, build the backbone and the
    DDP-wrapped linear classifier, build the validation loader, and either
    (a) with ``args.evaluate`` load linear weights, validate once and return,
    or (b) run the training loop with periodic validation, appending one JSON
    line per epoch to ``<output_dir>/log.txt`` and overwriting
    ``<output_dir>/checkpoint.pth.tar`` on the main process.

    Args:
        args: parsed command-line namespace (see the argument parser at the
            bottom of this file).
    """
    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ building network ... ============
    model, embed_dim = _build_backbone(args)
    model.cuda()
    model.eval()
    # Loading of pretrained backbone weights is currently disabled:
    # utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
    # NOTE(review): the backbone stays in eval mode and only the linear
    # classifier's parameters are handed to the optimizer below - confirm this
    # matches the intended training setup.
    print(f"Model {args.arch} built.")

    linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels)
    linear_classifier = linear_classifier.cuda()
    linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])

    # ============ preparing data ... ============
    normalize = pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    val_transform = pth_transforms.Compose([
        pth_transforms.Resize(256, interpolation=3),
        pth_transforms.CenterCrop(224),
        pth_transforms.ToTensor(),
        normalize,
    ])
    dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
    val_loader = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    if args.evaluate:
        utils.load_pretrained_linear_weights(linear_classifier, args.arch, args.patch_size)
        test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    train_transform = pth_transforms.Compose([
        pth_transforms.RandomResizedCrop(224),
        pth_transforms.RandomHorizontalFlip(),
        pth_transforms.ToTensor(),
        normalize,
    ])
    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")

    # set optimizer
    optimizer = torch.optim.SGD(
        linear_classifier.parameters(),
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,  # linear scaling rule
        momentum=0.9,
        weight_decay=0,  # we do not apply weight decay
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)

    # The log file and checkpoint are written into output_dir; create it up
    # front so a missing directory does not crash the run after the first
    # epoch has already been computed. exist_ok makes this safe on every rank.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    checkpoint_path = os.path.join(args.output_dir, "checkpoint.pth.tar")

    # Optionally resume from a checkpoint
    to_restore = {"epoch": 0, "best_acc": 0.}
    utils.restart_from_checkpoint(
        checkpoint_path,
        run_variables=to_restore,
        state_dict=linear_classifier,
        optimizer=optimizer,
        scheduler=scheduler,
    )
    start_epoch = to_restore["epoch"]
    best_acc = to_restore["best_acc"]

    for epoch in range(start_epoch, args.epochs):
        # Re-seed the distributed sampler so each epoch uses a new shuffle.
        train_loader.sampler.set_epoch(epoch)

        train_stats = train(model, linear_classifier, optimizer, train_loader, epoch, args.n_last_blocks, args.avgpool_patchtokens)
        scheduler.step()

        log_stats = {f'train_{k}': v for k, v in train_stats.items()}
        log_stats['epoch'] = epoch
        if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
            test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
            print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            best_acc = max(best_acc, test_stats["acc1"])
            print(f'Max accuracy so far: {best_acc:.2f}%')
            log_stats.update({f'test_{k}': v for k, v in test_stats.items()})
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
            save_dict = {
                "epoch": epoch + 1,  # epoch to resume from
                "state_dict": linear_classifier.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "best_acc": best_acc,
            }
            torch.save(save_dict, checkpoint_path)
    print("Training of the supervised linear classifier on frozen features completed.\n"
          "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))


def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool, arch=None):
    """Run one training epoch and return the averaged meters.

    Args:
        model: backbone producing the features fed to ``linear_classifier``.
        linear_classifier: classification head; switched to train mode here.
        optimizer: optimizer stepped once per batch.
        loader: iterable yielding ``(input, target)`` batches.
        epoch: epoch index, used only for the log header.
        n: number of last blocks requested from
            ``model.get_intermediate_layers`` (ViT branch only).
        avgpool: when truthy, the mean of the last block's tokens after the
            first one is concatenated to the [CLS] features (ViT branch only).
        arch: architecture name used to pick the forward path. When ``None``
            the module-level ``args.arch`` is used, which only exists when
            this file is executed as a script.

    Returns:
        dict mapping each meter name to its global average.
    """
    if arch is None:
        # Backward-compatible fallback for callers that do not pass `arch`:
        # `args` is bound at module level by the __main__ section.
        arch = args.arch
    use_vit_path = "vit" in arch
    criterion = nn.CrossEntropyLoss()

    linear_classifier.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    for (inp, target) in metric_logger.log_every(loader, 20, header):
        # move to gpu
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward
        # NOTE(review): the forward pass runs with autograd enabled; which
        # parameters actually get updated depends on what `optimizer` holds.
        if use_vit_path:
            intermediate_output = model.get_intermediate_layers(inp, n)
            output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
            if avgpool:
                output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
                output = output.reshape(output.shape[0], -1)
        else:
            # K is drawn uniformly from 0..12 (inclusive) and M from a fixed
            # set for every batch; their meaning is defined by the model's
            # forward signature - TODO confirm against the model code.
            K = random.randint(0, 12)
            M = random.choice([2, 3, 4, 6, 8, 9, 12, 16])
            output = model(inp, K, M)
        output = linear_classifier(output)

        # compute cross entropy loss
        loss = criterion(output, target)

        # compute the gradients
        optimizer.zero_grad()
        loss.backward()

        # step
        optimizer.step()

        # log
        torch.cuda.synchronize()
        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def validate_network(val_loader, model, linear_classifier, n, avgpool, arch=None):
    """Evaluate the backbone + linear classifier on ``val_loader``.

    The whole function runs with gradients disabled (decorator), so no inner
    ``torch.no_grad()`` context is needed.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches.
        model: backbone producing the features fed to ``linear_classifier``.
        linear_classifier: wrapped classification head; its ``.module``
            attribute must expose ``num_labels``. Switched to eval mode here.
        n: number of last blocks requested from
            ``model.get_intermediate_layers`` (ViT branch only).
        avgpool: when truthy, the mean of the last block's tokens after the
            first one is concatenated to the [CLS] features (ViT branch only).
        arch: architecture name used to pick the forward path. When ``None``
            the module-level ``args.arch`` is used, which only exists when
            this file is executed as a script.

    Returns:
        dict mapping each meter name (``loss``, ``acc1`` and, with at least
        five labels, ``acc5``) to its global average.
    """
    if arch is None:
        # Backward-compatible fallback for callers that do not pass `arch`:
        # `args` is bound at module level by the __main__ section.
        arch = args.arch
    use_vit_path = "vit" in arch
    criterion = nn.CrossEntropyLoss()
    # Top-5 accuracy is only reported when there are at least five classes.
    report_top5 = linear_classifier.module.num_labels >= 5

    linear_classifier.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'
    for inp, target in metric_logger.log_every(val_loader, 20, header):
        # move to gpu
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward
        if use_vit_path:
            intermediate_output = model.get_intermediate_layers(inp, n)
            output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
            if avgpool:
                output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
                output = output.reshape(output.shape[0], -1)
        else:
            # NOTE(review): K is fixed to 0 but M is still sampled per batch,
            # so results on this path may vary between runs - confirm this is
            # intended.
            K = 0
            M = random.choice([2, 3, 4, 6, 8, 9, 12, 16])
            output = model(inp, K, M)
        output = linear_classifier(output)
        loss = criterion(output, target)

        if report_top5:
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        else:
            acc1, = utils.accuracy(output, target, topk=(1,))

        batch_size = inp.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        if report_top5:
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    if report_top5:
        print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
              .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
    else:
        print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
              .format(top1=metric_logger.acc1, losses=metric_logger.loss))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


class LinearClassifier(nn.Module):
    """Linear layer to train on top of frozen features.

    Args:
        dim: number of input features per sample after flattening.
        num_labels: number of output classes (also stored as
            ``self.num_labels``).
    """

    def __init__(self, dim, num_labels=1000):
        super().__init__()
        self.num_labels = num_labels
        self.linear = nn.Linear(dim, num_labels)
        # Small-variance Gaussian weights and zero bias.
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        """Flatten every sample to a vector and apply the linear layer.

        Args:
            x: tensor whose first dimension is the batch dimension.

        Returns:
            Tensor of shape ``(x.size(0), num_labels)``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as permuted or sliced feature tensors; for contiguous inputs it
        # returns the same view as before.
        x = x.reshape(x.size(0), -1)

        # linear layer
        return self.linear(x)


if __name__ == '__main__':
    # Command-line entry point: parse the options and run eval_linear.
    parser = argparse.ArgumentParser('Evaluation with linear classification on ImageNet')

    # Feature-extraction options
    parser.add_argument('--n_last_blocks', default=1, type=int, help="""Concatenate [CLS] tokens
for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
    parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
                        help="""Whether or not to concatenate the global average pooled features to the [CLS] token.
We typically set this to False for ViT-Small and to True with ViT-Base.""")

    # Model options
    parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')

    # Optimisation options
    parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
    parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate at the beginning of
training (highest LR used during training). The learning rate is linearly scaled
with the batch size, and specified here for a reference batch size of 256.
We recommend tweaking the LR depending on the checkpoint evaluated.""")
    parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')

    # Distributed / runtime options
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")

    # Data, logging and evaluation options
    parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
    parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
    parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
    parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')

    # NOTE: `args` must remain a module-level name; train() and
    # validate_network() read `args.arch` from it.
    args = parser.parse_args()
    eval_linear(args)
2 changes: 1 addition & 1 deletion compvits/scripts/extract_train_features.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ python tools/run_distributed_engines.py \
engine_name=extract_features \
config.CHECKPOINT.DIR=$dir \
config.TEST_MODEL=False \
config.MODEL.WEIGHTS_INIT.PARAMS_FILE=/home/jan.olszewski/git/vissl/checkpoints/${model}.pth \
config.MODEL.WEIGHTS_INIT.PARAMS_FILE=/data/pwojcik/vissl/checkpoints/${model}.pth \
config.MODEL.WEIGHTS_INIT.STATE_DICT_KEY_NAME=model \
8 changes: 4 additions & 4 deletions compvits/scripts/nearest_neighbor.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ else
fi

mkdir --parents $dir
mv ${feats_all}/rank0_chunk0_train*.npy ${dir}
mv ${feats_K}/rank0_chunk0_test*.npy ${dir}
cp ${feats_all}/rank0_chunk0_train*.npy ${dir}
cp ${feats_K}/rank0_chunk0_test*.npy ${dir}

python tools/nearest_neighbor_test.py \
config=compvits/base \
Expand All @@ -27,5 +27,5 @@ python tools/nearest_neighbor_test.py \
config.CHECKPOINT.DIR=$dir \
config.NEAREST_NEIGHBOR.FEATURES.PATH=$dir \

mv ${dir}/rank0_chunk0_train*.npy ${feats_all}
mv ${dir}/rank0_chunk0_test*.npy ${feats_K}
#mv ${dir}/rank0_chunk0_train*.npy ${feats_all}
#mv ${dir}/rank0_chunk0_test*.npy ${feats_K}
2 changes: 1 addition & 1 deletion compvits/scripts/run_sweep.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

source compvits/scripts/extract_train_features.sh

scripts=(compvits/scripts/test_linear.sh compvits/scripts/extract_features.sh compvits/scripts/nearest_neighbor.sh)
scripts=(compvits/scripts/nearest_neighbor.sh)
models=(deitb)
Ms=(2 3 4 6 8 9 12 16)
for script in ${scripts[@]}; do
Expand Down
Loading