Releases: marsggbo/hyperbox

hyperbox_proxyless_supernet_and_subnets

20 Jun 13:30

1. Load weights from the pretrained supernet

import torch
from hyperbox.networks.proxylessnas.network import ProxylessNAS
supernet = ProxylessNAS(
    width_stages=[30, 40, 80, 96, 182, 320],
    n_cell_stages=[4, 4, 4, 4, 4, 1],
    stride_stages=[2, 2, 2, 1, 2, 1],
    width_mult=1.4,
    num_classes=1000,
    dropout_rate=0,
    bn_param=[0.1, 0.001],
)
ckpt = torch.load('/path/to/hyperbox_proxy_mobile_w1.4.pth', map_location='cpu')
supernet.load_state_dict(ckpt)
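
Once the supernet weights are loaded, candidate architectures can be sampled and run directly. The sketch below is a quick sanity check, not part of the release; it assumes hyperbox's RandomMutator and a standard 3x224x224 ImageNet input.

import torch
from hyperbox.mutator import RandomMutator

mutator = RandomMutator(supernet)  # attach a mutator to the supernet defined above
mutator.reset()                    # sample a random candidate architecture
supernet.eval()
with torch.no_grad():
    logits = supernet(torch.randn(1, 3, 224, 224))  # dummy ImageNet-sized input (assumption)
print(logits.shape)  # expected: torch.Size([1, 1000])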

2. Load weights from a pretrained subnet (a quick inference check follows the two examples below)

  • Subnet weight 1: test accuracy on ImageNet is 77.15%
import torch
from hyperbox.networks.proxylessnas.network import ProxylessNAS
subnet = ProxylessNAS(
    width_stages=[30, 40, 80, 96, 182, 320],
    n_cell_stages=[4, 4, 4, 4, 4, 1],
    stride_stages=[2, 2, 2, 1, 2, 1],
    width_mult=1.4,
    num_classes=1000,
    dropout_rate=0,
    bn_param=[0.1, 0.001],
    mask='subnet_acc77.15.json'
)
ckpt = torch.load('/path/to/hyperbox_proxylessnas_w1.4_acc77.15_subnet.pth', map_location='cpu')
subnet.load_state_dict(ckpt)
  • Subnet weight 2: test accuracy on ImageNet is 77.21%
import torch
from hyperbox.networks.proxylessnas.network import ProxylessNAS
subnet = ProxylessNAS(
    width_stages=[30, 40, 80, 96, 182, 320],
    n_cell_stages=[4, 4, 4, 4, 4, 1],
    stride_stages=[2, 2, 2, 1, 2, 1],
    width_mult=1.4,
    num_classes=1000,
    dropout_rate=0,
    bn_param=[0.1, 0.001],
    mask='subnet_acc77.21.json'
)
ckpt = torch.load('/path/to/hyperbox_proxylessnas_w1.4_acc77.21_subnet.pth', map_location='cpu')
subnet.load_state_dict(ckpt)
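
Both subnets behave as ordinary PyTorch modules once the mask and weights are loaded. A minimal inference check, assuming a standard 3x224x224 input (not part of the release):

import torch

subnet.eval()
with torch.no_grad():
    logits = subnet(torch.randn(1, 3, 224, 224))  # dummy ImageNet-sized input
print(logits.argmax(dim=1))  # predicted class index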

hyperbox_OFA_MBV3_k357_d234_e346_w1.2.pth

20 Jun 12:38
Load weights from the pretrained OFA-MobileNetV3 supernet:

import torch
from hyperbox.networks.ofa.ofa_mbv3 import OFAMobileNetV3
supernet = OFAMobileNetV3(
    first_stride=2,
    kernel_size_list=[3, 5, 7],
    expand_ratio_list=[4, 6],
    depth_list=[3, 4],
    base_stage_width=[16, 16, 24, 40, 80, 112, 160, 960, 1280],
    stride_stages=[1, 2, 2, 2, 1, 2],
    act_stages=['relu', 'relu', 'relu', 'h_swish', 'h_swish', 'h_swish'],
    se_stages=[False, False, True, False, True, True],
    width_mult=1.2,
    num_classes=1000,
)
ckpt = torch.load('/path/to/hyperbox_OFA_MBV3_k357_d234_e346_w1.2.pth', map_location='cpu')
supernet.load_state_dict(ckpt)
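
A subnet can then be derived from the loaded supernet with the same gen_mask/build_subnet calls used for the OFA_MBV3_k357_d234_e46_w1.pth release below; the depth/expand_ratio/kernel_size values here are only an example choice from this supernet's search space:

mask = supernet.gen_mask(depth=4, expand_ratio=6, kernel_size=7)
subnet = supernet.build_subnet(mask)
print(sum(p.numel() for p in subnet.parameters()))  # the subnet has fewer parameters than the supernet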

ViT pre-trained weights

11 Apr 13:35

1. Convert pre-trained weights from https://github.com/jeonsworld/ViT-pytorch

  • Clone the repo
git clone https://github.com/jeonsworld/ViT-pytorch
  • Download the pretrained weights
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16.npz
  • Convert the weights
import numpy as np
import torch

from hyperbox.networks.vit import ViT_B

# `models.modeling` comes from the cloned ViT-pytorch repo
from models.modeling import VisionTransformer, CONFIGS
from functools import partial
ViT_B_16 = partial(VisionTransformer, config=CONFIGS['ViT-B_16'])

def sync_params(net1, net2):
    """
    Args:
        net1: src net
        net2: tgt net
    """        
    count_size = lambda model: sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_params1 = count_size(net1)
    num_params2 = count_size(net2)
    print(f"ViT-B_16: {num_params1} | ViT-B: {num_params2}")

    num_sync_params = 0
    net2.vit_embed.load_state_dict(net1.transformer.embeddings.state_dict())
    num_sync_params += count_size(net2.vit_embed)
    
    assert len(net1.transformer.encoder.layer)==len(net2.vit_blocks)
    for i in range(len(net1.transformer.encoder.layer)):
        layer1 = net1.transformer.encoder.layer[i]
        layer2 = net2.vit_blocks[i]
        
        # attn
        ## norm
        attention_norm1 = layer1.attention_norm
        attention_norm2 = layer2.attn.block.norm
        attention_norm2.load_state_dict(attention_norm1.state_dict())
        num_sync_params += count_size(attention_norm2)
        
        ## qkv
        q = layer1.attn.query
        k = layer1.attn.key
        v = layer1.attn.value
        qkv = layer2.attn.block.fn.to_qkv
        qkv.weight.data.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0).data)
        if qkv.bias is not None:
            qkv.bias.data.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
        num_sync_params += count_size(qkv)

        ## fc
        out1 = layer1.attn.out
        out2 = layer2.attn.block.fn.to_out[0]
        out2.load_state_dict(out1.state_dict())
        num_sync_params += count_size(out2)

        # ff
        ## norm
        ffn_norm1 = layer1.ffn_norm
        ffn_norm2 = layer2.ff.block.norm
        ffn_norm2.load_state_dict(ffn_norm1.state_dict())
        num_sync_params += count_size(ffn_norm2)
        
        ## fc
        mlp_fc11 = layer1.ffn.fc1
        mlp_fc12 = layer1.ffn.fc2
        mlp_fc21 = layer2.ff.block.fn.net[0]
        mlp_fc22 = layer2.ff.block.fn.net[3]
        mlp_fc21.load_state_dict(mlp_fc11.state_dict())
        mlp_fc22.load_state_dict(mlp_fc12.state_dict())
        num_sync_params += count_size(mlp_fc21)
        num_sync_params += count_size(mlp_fc22)
    
    # head
    ## norm
    norm1 = net1.transformer.encoder.encoder_norm
    norm2 = net2.vit_cls_head.mlp_head[0]
    norm2.load_state_dict(norm1.state_dict())
    num_sync_params += count_size(norm2)
    
    ## fc
    fc1 = net1.head
    fc2 = net2.vit_cls_head.mlp_head[1]
    try:
        fc2.load_state_dict(fc1.state_dict())
        num_sync_params += count_size(fc2)
    except Exception:
        # skip the classification head if its shape does not match (e.g., a different num_classes)
        pass
    print(f"sync params: {num_sync_params}")  

net1 = ViT_B_16()
net1.load_from(np.load('/path/to/ViT-B_16.npz'))

net2 = ViT_B()
sync_params(net1, net2)
torch.save(net2.state_dict(), 'vit_b.pth')
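
Before running the full ImageNet validation in the next step, a quick smoke test can confirm that the converted checkpoint loads back into hyperbox and produces 1000-way logits. A minimal sketch, assuming a 3x224x224 input:

import torch
from hyperbox.networks.vit import ViT_B

net = ViT_B()
net.load_state_dict(torch.load('vit_b.pth', map_location='cpu'))
net.eval()
with torch.no_grad():
    logits = net(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000])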

2. Validate the pretrained weights

import torch
import torchvision
import torchvision.transforms as transforms

from hyperbox.networks.vit import ViT_B, ViT_L

def testloader(data_path='/path/to/imagenet2012/val', batch_size=400, num_workers=4):
    """Create test dataloader for ImageNet dataset."""
    # Define data transforms
    data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Load the test dataset
    test_dataset = torchvision.datasets.ImageFolder(root=data_path, transform=data_transforms)

    # Create a test dataloader
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return test_loader

def validate(loader, model, criterion, device, verbose=True):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(loader):
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            output = model(x)
            loss = criterion(output, y)
            test_loss += loss.item()

            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(y.view_as(pred)).sum().item()
            if batch_idx % 10 == 0 and verbose:
                print('Test: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    batch_idx * len(x), len(loader.dataset),
                    100. * batch_idx / len(loader), loss.item()))
    test_loss /= len(loader.dataset)
    test_acc = 100. * correct / len(loader.dataset)

    if verbose:
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(loader.dataset), test_acc))
    return test_loss, test_acc

net = ViT_B()
net.load_state_dict(torch.load('vit_b.pth', map_location='cpu'))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')
loader = testloader()
criterion = torch.nn.CrossEntropyLoss()
validate(loader, net.to(device), criterion, device)
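
The results below report top-1 accuracy only. If top-5 accuracy is also of interest, a small helper such as the hypothetical topk_correct below can be accumulated inside the batch loop of validate alongside the top-1 count:

import torch

def topk_correct(output, target, k=5):
    """Count samples whose target label appears among the k largest logits (hypothetical helper)."""
    topk = output.topk(k, dim=1).indices  # shape: [batch, k]
    return topk.eq(target.unsqueeze(1)).any(dim=1).sum().item()

# e.g., inside the loop of validate():
# correct_top5 += topk_correct(output, y, k=5)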

3. Results

Model                                      pth            Dataset      Acc@top1
ViT_B                                      vit_b.pth      ImageNet1k   75.66%
ViT_B (patch_size=32)                      vit_b_32.pth   ImageNet1k   64.44%
ViT_L                                      vit_L.pth      ImageNet1k   79.25%
ViT_L (patch_size=32)                      vit_L_32.pth   ImageNet1k   65.19%
ViT_H (patch_size=14, num_classes=21843)   vit_H_14.pth   ImageNet21k  -

The pretrained weights for ViT_H (patch_size=14) are too large (~2.45 GB), so we split them into multiple smaller chunks, i.e., vit_H_14.pth.parta*. To restore the full weights, concatenate the chunks into a single pth file after downloading them:

cat vit_H_14.pth.part* > vit_H_14.pth

By default, the ViT models provided by hyperbox use a patch size of 16. To use vit_H_14.pth, build the model with:

  • patch_size=14
  • num_classes=21843 (the model is pretrained on ImageNet21k)
import torch
from hyperbox.networks.vit import ViT_H
vit_h_14 = ViT_H(patch_size=14, num_classes=21843)
vit_h_14.load_state_dict(torch.load('vit_H_14.pth'))
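
Since 224 is divisible by 14, the standard 224x224 input still tiles cleanly into patches; a quick forward-shape check (a sketch with an assumed input size, not from the release):

with torch.no_grad():
    logits = vit_h_14.eval()(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 21843])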

OFA_MBV3_k357_d234_e46_w1.pth

24 Dec 11:38
  • hyperbox-based OFA-MobileNetV3 (a quick sanity check on the extracted subnet follows both snippets)
import torch
from hyperbox.networks.ofa import OFAMobileNetV3

device = 'cuda' if torch.cuda.is_available() else 'cpu'

weight = torch.load('path/to/hyperbox_OFA_MBV3_k357_d234_e46_w1.pth', map_location='cpu')
supernet = OFAMobileNetV3()
supernet.load_state_dict(weight)
mask = supernet.gen_mask(depth=4, expand_ratio=6, kernel_size=7)

net = supernet.build_subnet(mask).to(device)
  • Official OFA-MobileNetV3
# https://github.com/mit-han-lab/once-for-all
import torch
from ofa.imagenet_classification.elastic_nn.networks.ofa_mbv3 import OFAMobileNetV3

device = 'cuda' if torch.cuda.is_available() else 'cpu'

weight = torch.load('path/to/official_OFA_MBV3_k357_d234_e46_w1.pth', map_location='cpu')
supernet = OFAMobileNetV3(dropout_rate=0, width_mult=1.0, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4])
supernet.load_state_dict(weight)
supernet.set_active_subnet(ks=7, e=6, d=4)

net = supernet.get_active_subnet(preserve_weight=True).to(device)
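
Whichever route is taken, the resulting net is a standalone module. A quick sanity check on a dummy batch, assuming a 3x224x224 input and that the subnet keeps the 1000-class ImageNet head:

x = torch.randn(1, 3, 224, 224).to(device)
net.eval()
with torch.no_grad():
    logits = net(x)
print(logits.shape)                              # expected: torch.Size([1, 1000])
print(sum(p.numel() for p in net.parameters()))  # parameter count of the extracted subnet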