[TinyCLIP] inference auto weight inheritance #223

Open · wants to merge 6 commits into base: main
10 changes: 5 additions & 5 deletions TinyCLIP/README.md
@@ -31,11 +31,11 @@
[TinyCLIP ResNet-19M Text-19M](./src/open_clip/model_configs/TinyCLIP-ResNet-19M-Text-19M.json) | manual | LAION-400M | 56.4 | 4.4 | 3,024| [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ResNet-19M-Text-19M-LAION400M.pt)
[TinyCLIP ViT-61M/32 Text-29M](./src/open_clip/model_configs/TinyCLIP-ViT-61M-32-Text-29M.json) | manual | LAION-400M | 62.4 | 5.3 | 3,191|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-61M-32-Text-29M-LAION400M.pt)
[TinyCLIP ViT-40M/32 Text-19M](./src/open_clip/model_configs/TinyCLIP-ViT-40M-32-Text-19M.json) | manual | LAION-400M | 59.8 | 3.5 | 4,641|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-40M-32-Text-19M-LAION400M.pt)
TinyCLIP ViT-63M/32 Text-31M | auto | LAION-400M | 63.9 | 5.6 | 2,905|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAION400M.pt)
TinyCLIP ViT-45M/32 Text-18M | auto | LAION-400M | 61.4 | 3.7 | 3,682|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAION400M.pt)
TinyCLIP ViT-22M/32 Text-10M | auto | LAION-400M | 53.7 | 1.9 | 5,504|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-22M-32-Text-10M-LAION400M.pt)
TinyCLIP ViT-63M/32 Text-31M | auto | LAION+YFCC-400M | 64.5 | 5.6| 2,909 | [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAIONYFCC400M.pt)
TinyCLIP ViT-45M/32 Text-18M | auto | LAION+YFCC-400M | 62.7 | 1.9 | 3,685 | [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAIONYFCC400M.pt)
[TinyCLIP ViT-63M/32 Text-31M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json) | auto | LAION-400M | 63.9 | 5.6 | 2,905|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAION400M.pt)
[TinyCLIP ViT-45M/32 Text-18M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json) | auto | LAION-400M | 61.4 | 3.7 | 3,682|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAION400M.pt)
[TinyCLIP ViT-22M/32 Text-10M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-22M-32-Text-10M.json) | auto | LAION-400M | 53.7 | 1.9 | 5,504|[Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-22M-32-Text-10M-LAION400M.pt)
[TinyCLIP ViT-63M/32 Text-31M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-63M-32-Text-31M.json) | auto | LAION+YFCC-400M | 64.5 | 5.6| 2,909 | [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAIONYFCC400M.pt)
[TinyCLIP ViT-45M/32 Text-18M](./src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json) | auto | LAION+YFCC-400M | 62.7 | 1.9 | 3,685 | [Model](https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAIONYFCC400M.pt)

Note: The configs of models with auto inheritance are generated automatically.

20 changes: 18 additions & 2 deletions TinyCLIP/inference.py
@@ -18,8 +18,24 @@
# arch = 'TinyCLIP-ViT-61M-32-Text-29M'
# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M')

arch = 'TinyCLIP-ViT-40M-32-Text-19M'
model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M')
# arch = 'TinyCLIP-ViT-40M-32-Text-19M'
# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M')

# auto inheritance
# arch = 'TinyCLIP-auto-ViT-63M-32-Text-31M'
# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M')

# arch = 'TinyCLIP-auto-ViT-45M-32-Text-18M'
# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M')

# arch = 'TinyCLIP-auto-ViT-22M-32-Text-10M'
# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAION400M')

# arch = 'TinyCLIP-auto-ViT-63M-32-Text-31M'
# model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAIONYFCC400M')

arch = 'TinyCLIP-auto-ViT-45M-32-Text-18M'
model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAIONYFCC400M')

tokenizer = open_clip.get_tokenizer(arch)

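The remainder of inference.py is collapsed in this diff. For context, here is a minimal standalone sketch of how one of the new auto-inheritance checkpoints could be used end to end, assuming the usual open_clip `encode_image`/`encode_text` API; `cat.jpg` is a placeholder image path:

```python
import torch
from PIL import Image
import open_clip

# Load one of the auto weight inheritance checkpoints registered by this PR.
arch = 'TinyCLIP-auto-ViT-45M-32-Text-18M'
model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='LAIONYFCC400M')
tokenizer = open_clip.get_tokenizer(arch)

image = preprocess(Image.open('cat.jpg')).unsqueeze(0)  # 'cat.jpg' is a placeholder path
text = tokenizer(['a photo of a cat', 'a photo of a dog'])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print('Label probs:', probs)
```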
17 changes: 16 additions & 1 deletion TinyCLIP/src/open_clip/factory.py
@@ -11,6 +11,7 @@

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, convert_weights_to_fp16, resize_pos_embed
from .model import load_pruned_model, prune_model
from .openai import load_openai_model
from .pretrained import get_pretrained_cfg, download_pretrained
from .transform import image_transform
@@ -86,6 +87,13 @@ def load_checkpoint(model, checkpoint_path, strict=True):
return incompatible_keys


def load_pruned_checkpoint(model, checkpoint_path, strict=True):
state_dict = load_state_dict(checkpoint_path)
resize_pos_embed(state_dict, model)
incompatible_keys = load_pruned_model(model, state_dict, strict=strict)
return incompatible_keys


def create_model(
model_name: str,
pretrained: str = '',
@@ -138,6 +146,9 @@ def create_model(
f'model sparsity varies from {model_cfg["start_sparsity"]} to {model_cfg["sparsity"]}, sparsity warmup steps: {model_cfg["sparsity_warmup"]}')

logging.info(str(model_cfg))
auto_weight_inheritance = model_cfg.get('mask_image', False) or \
model_cfg.get('mask_text', False)

model = CLIP(**model_cfg)

pretrained_cfg = {}
@@ -153,7 +164,11 @@
if checkpoint_path:
logging.info(
f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model, checkpoint_path)
if not auto_weight_inheritance:
load_checkpoint(model, checkpoint_path)
else:
load_pruned_checkpoint(model, checkpoint_path)
model = prune_model(model)
else:
logging.warning(
f'Pretrained weights ({pretrained}) not found for model {model_name}.')
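The new branch keys off the `mask_image`/`mask_text` flags in the model config. A small self-contained check of that detection logic, using one of the config paths linked from the README (the exact path relative to the repository root is an assumption about the layout):

```python
import json

# Path taken from the README links in this PR (relative to the repository root).
cfg_path = 'TinyCLIP/src/open_clip/model_configs/TinyCLIP-auto-ViT-45M-32-Text-18M.json'
with open(cfg_path) as f:
    model_cfg = json.load(f)

# Same test create_model now performs: either mask flag selects the pruned-checkpoint path.
auto_weight_inheritance = model_cfg.get('mask_image', False) or model_cfg.get('mask_text', False)
print(auto_weight_inheritance)  # True for the auto configs, False for the manual ones
```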
34 changes: 32 additions & 2 deletions TinyCLIP/src/open_clip/model.py
@@ -1297,7 +1297,7 @@ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=


@torch.no_grad()
def load_pruned_model(model, pruned_state_dict):
def load_pruned_model(model, pruned_state_dict, strict=True):
'''
A full model loads the pruned state dict.

@@ -1315,6 +1315,10 @@ def _copy_to_full_weight(dst, src):
slices = [slice(0, d) for d in dims]
dst[slices].copy_(src)

for _ in range(2):
pruned_state_dict = {
k.replace('module.', ''): v for k, v in pruned_state_dict.items()}

lambda_init_value = 10.0
model_state_dict = model.state_dict()
head_dim = model.transformer.head_dim
@@ -1405,4 +1409,30 @@ def _get_layer_id(name):
model_state_dict[f'{ename}.l0_module.intermediate_loga'][d,
:].fill_(-lambda_init_value)

model.load_state_dict(model_state_dict, strict=True)
return model.load_state_dict(model_state_dict, strict=strict)


def prune_model(model):
device = next(model.parameters()).device

with torch.no_grad():
model.image_encoder_without_ddp.eval()
image_size = (1, 3) + model.image_encoder_without_ddp.visual.image_size
image = torch.randn(image_size, device=device)
model.image_encoder_without_ddp(image)
model.image_encoder_without_ddp = model.image_encoder_without_ddp.prune()

assert hasattr(model.image_encoder_without_ddp, 'l0_module')
model.image_encoder_without_ddp.l0_module = None

with torch.no_grad():
model.text_encoder_without_ddp.eval()
context_length = model.text_encoder_without_ddp.context_length
text = torch.zeros((1, context_length), dtype=torch.long, device=device)
model.text_encoder_without_ddp(text)
model.text_encoder_without_ddp = model.text_encoder_without_ddp.prune()

assert hasattr(model.text_encoder_without_ddp, 'l0_module')
model.text_encoder_without_ddp.l0_module = None

return model
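The heart of `load_pruned_model` is copying each pruned tensor into the leading slice of the corresponding full-size tensor (the `_copy_to_full_weight` helper above). An illustrative standalone sketch of that idea, not the PR's exact code:

```python
import torch

def copy_to_full_weight(dst: torch.Tensor, src: torch.Tensor) -> None:
    # Write the pruned tensor into the top-left corner of the full tensor along every dim.
    assert dst.dim() == src.dim()
    slices = tuple(slice(0, d) for d in src.shape)
    dst[slices].copy_(src)

# Example: a pruned 3x5 weight inherited into a full 4x8 weight; the rest stays untouched.
full = torch.zeros(4, 8)
pruned = torch.ones(3, 5)
copy_to_full_weight(full, pruned)
print(full[:3, :5].sum())  # tensor(15.)
```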
@@ -0,0 +1,19 @@
{
"mask_image": true,
"mask_text": true,

"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
@@ -0,0 +1,19 @@
{
"mask_image": true,
"mask_text": true,

"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
@@ -0,0 +1,19 @@
{
"mask_image": true,
"mask_text": true,

"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
33 changes: 33 additions & 0 deletions TinyCLIP/src/open_clip/pretrained.py
@@ -146,6 +146,9 @@ def _pcfg(url='', hf_hub='', mean=None, std=None):
)

# TinyCLIP

# manual weight inheritance

_TINYCLIP_VIT_39M_16_TEXT_19M = {
"YFCC15M": _pcfg(
"https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-39M-16-Text-19M-YFCC15M.pt",
@@ -182,6 +185,32 @@ def _pcfg(url='', hf_hub='', mean=None, std=None):
),
}

# auto weight inheritance

_TINYCLIP_AUTO_VIT_63M_32_TEXT_31M = {
"LAION400M": _pcfg(
"https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAION400M.pt",
),
"LAIONYFCC400M": _pcfg(
"https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-63M-32-Text-31M-LAIONYFCC400M.pt",
),
}

_TINYCLIP_AUTO_VIT_45M_32_TEXT_18M = {
"LAION400M": _pcfg(
"https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAION400M.pt",
),
"LAIONYFCC400M": _pcfg(
"https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-45M-32-Text-18M-LAIONYFCC400M.pt",
),
}

_TINYCLIP_AUTO_VIT_22M_32_TEXT_10M = {
"LAION400M": _pcfg(
"https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-auto-ViT-22M-32-Text-10M-LAION400M.pt",
),
}

_PRETRAINED = {
"RN50": _RN50,
"RN50-quickgelu": _RN50_quickgelu,
@@ -205,6 +234,10 @@ def _pcfg(url='', hf_hub='', mean=None, std=None):
"TinyCLIP-ResNet-19M-Text-19M": _TINYCLIP_RESNET_19M_TEXT_19M,
"TinyCLIP-ViT-61M-32-Text-29M": _TINYCLIP_VIT_61M_32_TEXT_29M,
"TinyCLIP-ViT-40M-32-Text-19M": _TINYCLIP_VIT_40M_32_TEXT_19M,

"TinyCLIP-auto-ViT-63M-32-Text-31M": _TINYCLIP_AUTO_VIT_63M_32_TEXT_31M,
"TinyCLIP-auto-ViT-45M-32-Text-18M": _TINYCLIP_AUTO_VIT_45M_32_TEXT_18M,
"TinyCLIP-auto-ViT-22M-32-Text-10M": _TINYCLIP_AUTO_VIT_22M_32_TEXT_10M,
}


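To sanity-check the new registry entries, assuming `get_pretrained_cfg(model, tag)` keeps the signature that factory.py already imports, something like the following sketch should resolve the checkpoint URLs:

```python
from open_clip.pretrained import get_pretrained_cfg

# Each new _TINYCLIP_AUTO_* entry maps a pretrained tag to a checkpoint URL via _pcfg.
for tag in ('LAION400M', 'LAIONYFCC400M'):
    cfg = get_pretrained_cfg('TinyCLIP-auto-ViT-45M-32-Text-18M', tag)
    print(tag, cfg.get('url', '<no entry>'))
```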