From 837db818da289552de432fbbcac5822136f6c682 Mon Sep 17 00:00:00 2001 From: Anastasiais-ml Date: Wed, 16 Oct 2024 22:38:09 +0200 Subject: [PATCH 1/4] Add GLIP (Centered Masking for Language-Image Pre-Training) --- src/open_clip/factory.py | 14 ++++++++++++++ src/open_clip/model.py | 4 ++++ src/open_clip/transformer.py | 29 +++++++++++++++++++++++++++-- src/open_clip_train/main.py | 2 ++ src/open_clip_train/params.py | 12 ++++++++++++ 5 files changed, 59 insertions(+), 2 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 82ebe2bb9..c29f59610 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -201,6 +201,8 @@ def create_model( force_quick_gelu: bool = False, force_custom_text: bool = False, force_patch_dropout: Optional[float] = None, + gaussian_masking: Optional[bool] = None, + gaussian_masking_std: Optional[float] = None, force_image_size: Optional[Union[int, Tuple[int, int]]] = None, force_preprocess_cfg: Optional[Dict[str, Any]] = None, pretrained_image: bool = False, @@ -251,6 +253,14 @@ def create_model( if force_patch_dropout is not None: # override the default patch dropout value model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if gaussian_masking is not None: + # override the default gaussian masking value + model_cfg["vision_cfg"]["gaussian_masking"] = gaussian_masking + + if gaussian_masking_std is not None: + # override the default gaussian masking std value + model_cfg["vision_cfg"]["gaussian_masking_std"] = gaussian_masking_std if force_image_size is not None: # override model config's image size @@ -396,6 +406,8 @@ def create_model_and_transforms( force_quick_gelu: bool = False, force_custom_text: bool = False, force_patch_dropout: Optional[float] = None, + gaussian_masking: Optional[bool] = None, + gaussian_masking_std: Optional[float] = None, force_image_size: Optional[Union[int, Tuple[int, int]]] = None, image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, @@ -420,6 +432,8 @@ def create_model_and_transforms( force_quick_gelu=force_quick_gelu, force_custom_text=force_custom_text, force_patch_dropout=force_patch_dropout, + gaussian_masking=gaussian_masking, + gaussian_masking_std=gaussian_masking_std, force_image_size=force_image_size, force_preprocess_cfg=force_preprocess_cfg, pretrained_image=pretrained_image, diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 989662ebb..cc8ef74ff 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -34,6 +34,8 @@ class CLIPVisionCfg: ls_init_value: Optional[float] = None # layer scale initial value patch_dropout: float = 0. 
# what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + gaussian_masking: bool = False + gaussian_masking_std: float = 0.20 attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type) attn_pooler_queries: int = 256 # n_queries for attentional pooler attn_pooler_heads: int = 8 # n heads for attentional_pooling @@ -155,6 +157,8 @@ def _build_vision_tower( mlp_ratio=vision_cfg.mlp_ratio, ls_init_value=vision_cfg.ls_init_value, patch_dropout=vision_cfg.patch_dropout, + gaussian_masking=vision_cfg.gaussian_masking, + gaussian_masking_std=vision_cfg.gaussian_masking_std, attentional_pool=vision_cfg.attentional_pool, attn_pooler_queries=vision_cfg.attn_pooler_queries, attn_pooler_heads=vision_cfg.attn_pooler_heads, diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index bf85dc8e7..0b7054f89 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -7,6 +7,8 @@ from torch import nn from torch.nn import functional as F from torch.utils.checkpoint import checkpoint +from scipy.stats import multivariate_normal +import numpy as np from .utils import to_2tuple from .pos_embed import get_2d_sincos_pos_embed @@ -51,11 +53,29 @@ class PatchDropout(nn.Module): https://arxiv.org/abs/2212.00794 """ - def __init__(self, prob, exclude_first_token=True): + def __init__(self, prob, grid_size, gaussian_masking=False, std=0.20, exclude_first_token=True): super().__init__() assert 0 <= prob < 1. self.prob = prob self.exclude_first_token = exclude_first_token # exclude CLS token + self.gaussian_masking = gaussian_masking + self.pdf_values = self.normal_distribution(grid_size, std=std) + + + def normal_distribution(self, grid_size=14, mean=0, std=0.20): + """ + https://arxiv.org/abs/2403.15837 + """ + + x = np.linspace(-1, 1, grid_size) + x, y = np.meshgrid(x, x) + mean, cov_matrix = [mean, mean], [[std, 0], [0, std]] + bivariate_dist = multivariate_normal(mean=mean, cov=cov_matrix) + points = np.column_stack([x.flatten(), y.flatten()]) + # Evaluate the PDF at the given points + pdf_values = bivariate_dist.pdf(points) + + return 1 - pdf_values def forward(self, x): if not self.training or self.prob == 0.: @@ -76,6 +96,8 @@ def forward(self, x): num_patches_keep = max(1, int(num_tokens * keep_prob)) rand = torch.randn(batch, num_tokens) + if self.gaussian_masking: + rand = rand - self.pdf_values patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices x = x[batch_indices, patch_indices_keep] @@ -448,6 +470,8 @@ def __init__( attn_pooler_heads: int = 8, output_dim: int = 512, patch_dropout: float = 0., + gaussian_masking: bool = False, + gaussian_masking_std: float = 0.20, no_ln_pre: bool = False, pos_embed_type: str = 'learnable', pool_type: str = 'tok', @@ -485,7 +509,8 @@ def __init__( raise ValueError # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn - self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + self.patch_dropout = PatchDropout(patch_dropout, self.grid_size[0], + gaussian_masking, gaussian_masking_std) if patch_dropout > 0. 
else nn.Identity() self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width) self.transformer = Transformer( diff --git a/src/open_clip_train/main.py b/src/open_clip_train/main.py index 591ea1d32..cafecab30 100644 --- a/src/open_clip_train/main.py +++ b/src/open_clip_train/main.py @@ -229,6 +229,8 @@ def main(args): force_quick_gelu=args.force_quick_gelu, force_custom_text=args.force_custom_text, force_patch_dropout=args.force_patch_dropout, + gaussian_masking=args.gaussian_masking, + gaussian_masking_std=args.gaussian_masking_std, force_image_size=args.force_image_size, image_mean=args.image_mean, image_std=args.image_std, diff --git a/src/open_clip_train/params.py b/src/open_clip_train/params.py index c3d19302d..480f0afc8 100644 --- a/src/open_clip_train/params.py +++ b/src/open_clip_train/params.py @@ -279,6 +279,18 @@ def parse_args(args): type=float, help="Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper", ) + parser.add_argument( + "--gaussian-masking", + default=False, + action='store_true', + help="Use gaussian masking for patch dropout.", + ) + parser.add_argument( + "--gaussian-masking-std", + default=0.20, + type=float, + help="Gaussian masking std for patch dropout.", + ) parser.add_argument( "--force-custom-text", default=False, From b30bba1eb6a700658ab0803e73ef80db4c2032fa Mon Sep 17 00:00:00 2001 From: Anastasiais-ml Date: Thu, 16 Jan 2025 12:39:20 +0100 Subject: [PATCH 2/4] Add GLIP --- src/open_clip/factory.py | 314 +++++++++++++++++++++++++-------------- 1 file changed, 201 insertions(+), 113 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index c29f59610..e9d078334 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -2,6 +2,7 @@ import logging import os import re +import warnings from copy import deepcopy from dataclasses import asdict from pathlib import Path @@ -9,13 +10,11 @@ import torch -from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .convert import convert_state_dict from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg from .coca_model import CoCa from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss -from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\ list_pretrained_tags_by_model, download_pretrained_from_hf from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs @@ -68,14 +67,25 @@ def add_model_config(path): def get_model_config(model_name): + """ Fetch model config from builtin (local library) configs. + """ if model_name in _MODEL_CONFIGS: return deepcopy(_MODEL_CONFIGS[model_name]) else: return None -def _get_hf_config(model_id, cache_dir=None): - config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) +def _get_hf_config( + model_id: str, + cache_dir: Optional[str] = None, +): + """ Fetch model config from HuggingFace Hub. 
+ """ + config_path = download_pretrained_from_hf( + model_id, + filename='open_clip_config.json', + cache_dir=cache_dir, + ) with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) return config @@ -84,16 +94,18 @@ def _get_hf_config(model_id, cache_dir=None): def get_tokenizer( model_name: str = '', context_length: Optional[int] = None, + cache_dir: Optional[str] = None, **kwargs, ): if model_name.startswith(HF_HUB_PREFIX): model_name = model_name[len(HF_HUB_PREFIX):] try: - config = _get_hf_config(model_name)['model_cfg'] + config = _get_hf_config(model_name, cache_dir=cache_dir)['model_cfg'] except Exception: tokenizer = HFTokenizer( model_name, context_length=context_length or DEFAULT_CONTEXT_LENGTH, + cache_dir=cache_dir, **kwargs, ) return tokenizer @@ -114,6 +126,7 @@ def get_tokenizer( tokenizer = HFTokenizer( text_config['hf_tokenizer_name'], context_length=context_length, + cache_dir=cache_dir, **tokenizer_kwargs, ) else: @@ -175,6 +188,14 @@ def load_checkpoint( if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): state_dict = convert_to_custom_text_state_dict(state_dict) + # correct if logit_scale differs in being scaler vs 1d param + if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim: + state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape) + + # correct if logit_bias differs in being scaler vs 1d param + if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim: + state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape) + # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712 if 'logit_bias' not in state_dict and model.logit_bias is not None: state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"]) @@ -210,15 +231,65 @@ def create_model( cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, require_pretrained: bool = False, + load_weights_only: bool = True, **model_kwargs, ): + """Creates and configures a contrastive vision-language model. + + Args: + model_name: Name of the model architecture to create. Can be a local model name + or a Hugging Face model ID prefixed with 'hf-hub:'. + pretrained: Tag/path for pretrained model weights. Can be: + - A pretrained tag name (e.g., 'openai') + - A path to local weights + - None to initialize with random weights + precision: Model precision/AMP configuration. 
Options: + - 'fp32': 32-bit floating point + - 'fp16'/'bf16': Mixed precision with FP32 for certain layers + - 'pure_fp16'/'pure_bf16': Pure 16-bit precision + device: Device to load the model on ('cpu', 'cuda', or torch.device object) + jit: If True, JIT compile the model + force_quick_gelu: Force use of QuickGELU activation + force_custom_text: Force use of custom text encoder + force_patch_dropout: Override default patch dropout value + force_image_size: Override default image size for vision encoder + force_preprocess_cfg: Override default preprocessing configuration + pretrained_image: Load pretrained weights for timm vision models + pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights + cache_dir: Override default cache directory for downloaded model files + output_dict: If True and model supports it, return dictionary of features + require_pretrained: Raise error if pretrained weights cannot be loaded + load_weights_only: Only deserialize model weights and unpickling torch checkpoints (for safety) + **model_kwargs: Additional keyword arguments passed to model constructor + + Returns: + Created and configured model instance + + Raises: + RuntimeError: If model config is not found or required pretrained weights + cannot be loaded + + Examples: + # Create basic CLIP model + model = create_model('ViT-B/32') + + # Create CLIP model with mixed precision on GPU + model = create_model('ViT-B/32', precision='fp16', device='cuda') + + # Load pretrained OpenAI weights + model = create_model('ViT-B/32', pretrained='openai') + + # Load Hugging Face model + model = create_model('hf-hub:organization/model-name') + """ + force_preprocess_cfg = force_preprocess_cfg or {} preprocess_cfg = asdict(PreprocessCfg()) has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) if has_hf_hub_prefix: model_id = model_name[len(HF_HUB_PREFIX):] checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) - config = _get_hf_config(model_id, cache_dir) + config = _get_hf_config(model_id, cache_dir=cache_dir) preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg']) model_cfg = config['model_cfg'] pretrained_hf = False # override, no need to load original HF text weights @@ -230,120 +301,121 @@ def create_model( if isinstance(device, str): device = torch.device(device) - if pretrained and pretrained.lower() == 'openai': - logging.info(f'Loading pretrained {model_name} from OpenAI.') - model = load_openai_model( - model_name, - precision=precision, - device=device, - cache_dir=cache_dir, - ) + model_cfg = model_cfg or get_model_config(model_name) + if model_cfg is not None: + logging.info(f'Loaded {model_name} model config.') else: - model_cfg = model_cfg or get_model_config(model_name) - if model_cfg is not None: - logging.info(f'Loaded {model_name} model config.') - else: - logging.error(f'Model config for {model_name} not found; available models {list_models()}.') - raise RuntimeError(f'Model config for {model_name} not found.') - - if force_quick_gelu: - # override for use of QuickGELU on non-OpenAI transformer models - model_cfg["quick_gelu"] = True - - if force_patch_dropout is not None: - # override the default patch dropout value - model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout - - if gaussian_masking is not None: - # override the default gaussian masking value + logging.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + 
+ if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if gaussian_masking is not None: + # override the default masking model_cfg["vision_cfg"]["gaussian_masking"] = gaussian_masking - if gaussian_masking_std is not None: + if gaussian_masking_std is not None: # override the default gaussian masking std value model_cfg["vision_cfg"]["gaussian_masking_std"] = gaussian_masking_std - if force_image_size is not None: - # override model config's image size - model_cfg["vision_cfg"]["image_size"] = force_image_size - - is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) - if pretrained_image: - if is_timm_model: - # pretrained weight loading for timm models set via vision_cfg - model_cfg['vision_cfg']['timm_model_pretrained'] = True - else: - assert False, 'pretrained image towers currently only supported for timm models' - - # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes - cast_dtype = get_cast_dtype(precision) - is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) - if is_hf_model: - # load pretrained weights for HF text model IFF no CLIP weights being loaded - model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained - custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model - - model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg) - if custom_text: - if "multimodal_cfg" in model_cfg: - model = CoCa(**model_cfg, cast_dtype=cast_dtype) - else: - model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size + + is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) + if pretrained_image: + if is_timm_model: + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True else: - model = CLIP(**model_cfg, cast_dtype=cast_dtype) - - if precision in ("fp16", "bf16"): - dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 - # manual mixed precision that matches original OpenAI behaviour - if is_timm_model: - # FIXME this is a bit janky, create timm based model in low-precision and - # then cast only LayerNormFp32 instances back to float32 so they don't break. - # Why? The convert_weights_to_lp fn only works with native models. 
- model.to(device=device, dtype=dtype) - from .transformer import LayerNormFp32 - - def _convert_ln(m): - if isinstance(m, LayerNormFp32): - m.weight.data = m.weight.data.to(torch.float32) - m.bias.data = m.bias.data.to(torch.float32) - model.apply(_convert_ln) - else: - model.to(device=device) - convert_weights_to_lp(model, dtype=dtype) - elif precision in ("pure_fp16", "pure_bf16"): - dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + assert False, 'pretrained image towers currently only supported for timm models' + + # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes + cast_dtype = get_cast_dtype(precision) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + if is_hf_model: + # load pretrained weights for HF text model IFF no CLIP weights being loaded + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + + model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg) + if custom_text: + if "multimodal_cfg" in model_cfg: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + if precision in ("fp16", "bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + # manual mixed precision that matches original OpenAI behaviour + if is_timm_model: + # FIXME this is a bit janky, create timm based model in low-precision and + # then cast only LayerNormFp32 instances back to float32 so they don't break. + # Why? The convert_weights_to_lp fn only works with native models. model.to(device=device, dtype=dtype) + from .transformer import LayerNormFp32 + + def _convert_ln(m): + if isinstance(m, LayerNormFp32): + m.weight.data = m.weight.data.to(torch.float32) + m.bias.data = m.bias.data.to(torch.float32) + model.apply(_convert_ln) else: model.to(device=device) - - pretrained_loaded = False - if pretrained: - checkpoint_path = '' - pretrained_cfg = get_pretrained_cfg(model_name, pretrained) - if pretrained_cfg: - checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) - preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg) - elif os.path.exists(pretrained): - checkpoint_path = pretrained - - if checkpoint_path: - logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') - load_checkpoint(model, checkpoint_path) - else: - error_str = ( - f'Pretrained weights ({pretrained}) not found for model {model_name}.' 
- f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') - logging.warning(error_str) - raise RuntimeError(error_str) - pretrained_loaded = True - elif has_hf_hub_prefix: - logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).') - load_checkpoint(model, checkpoint_path) - pretrained_loaded = True - - if require_pretrained and not pretrained_loaded: - # callers of create_model_from_pretrained always expect pretrained weights - raise RuntimeError( - f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + convert_weights_to_lp(model, dtype=dtype) + elif precision in ("pure_fp16", "pure_bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + model.to(device=device, dtype=dtype) + else: + model.to(device=device) + + pretrained_loaded = False + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg) + pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False) + model_quick_gelu = model_cfg.get('quick_gelu', False) + if pretrained_quick_gelu and not model_quick_gelu: + warnings.warn( + f'These pretrained weights were trained with QuickGELU activation but the model config does ' + f'not have that enabled. Consider using a model config with a "-quickgelu" suffix or enable with a flag.') + elif not pretrained_quick_gelu and model_quick_gelu: + warnings.warn( + f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the ' + f'model config, consider using a model config without QuickGELU or disable override flags.') + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path, weights_only=load_weights_only) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' 
+ f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logging.warning(error_str) + raise RuntimeError(error_str) + pretrained_loaded = True + elif has_hf_hub_prefix: + logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).') + load_checkpoint(model, checkpoint_path, weights_only=load_weights_only) + pretrained_loaded = True + + if require_pretrained and not pretrained_loaded: + # callers of create_model_from_pretrained always expect pretrained weights + raise RuntimeError( + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') if output_dict and hasattr(model, "output_dict"): model.output_dict = True @@ -386,7 +458,9 @@ def create_loss(args): return SigLipLoss( rank=args.rank, world_size=args.world_size, + dist_impl=args.loss_dist_impl, # siglip has multiple distributed implementations to choose from ) + return ClipLoss( local_loss=args.local_loss, gather_with_grad=args.gather_with_grad, @@ -418,10 +492,16 @@ def create_model_and_transforms( pretrained_hf: bool = True, cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, + load_weights_only: bool = True, **model_kwargs, ): force_preprocess_cfg = merge_preprocess_kwargs( - {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) + {}, + mean=image_mean, + std=image_std, + interpolation=image_interpolation, + resize_mode=image_resize_mode, + ) model = create_model( model_name, @@ -440,6 +520,7 @@ def create_model_and_transforms( pretrained_hf=pretrained_hf, cache_dir=cache_dir, output_dict=output_dict, + load_weights_only=load_weights_only, **model_kwargs, ) @@ -473,10 +554,16 @@ def create_model_from_pretrained( image_resize_mode: Optional[str] = None, # only effective for inference return_transform: bool = True, cache_dir: Optional[str] = None, + load_weights_only: bool = True, **model_kwargs, ): force_preprocess_cfg = merge_preprocess_kwargs( - {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) + {}, + mean=image_mean, + std=image_std, + interpolation=image_interpolation, + resize_mode=image_resize_mode, + ) model = create_model( model_name, @@ -490,6 +577,7 @@ def create_model_from_pretrained( force_preprocess_cfg=force_preprocess_cfg, cache_dir=cache_dir, require_pretrained=True, + load_weights_only=load_weights_only, **model_kwargs, ) @@ -501,4 +589,4 @@ def create_model_from_pretrained( is_train=False, ) - return model, preprocess + return model, preprocess \ No newline at end of file From c5ec898137c55c1733d0365dba85bc59c25b2bc4 Mon Sep 17 00:00:00 2001 From: Anastasiais-ml Date: Fri, 31 Jan 2025 16:21:55 +0100 Subject: [PATCH 3/4] "update GLIP" --- src/open_clip/factory.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index e9d078334..445171fb2 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -315,6 +315,10 @@ def create_model( if force_patch_dropout is not None: # override the default patch dropout value model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size if gaussian_masking is not None: # override the default masking @@ -324,10 +328,6 @@ def create_model( # override the default gaussian masking std value model_cfg["vision_cfg"]["gaussian_masking_std"] = gaussian_masking_std - if 
force_image_size is not None: - # override model config's image size - model_cfg["vision_cfg"]["image_size"] = force_image_size - is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) if pretrained_image: if is_timm_model: From 7b719dd7c6a76ed64741e5bf3d8b5e42cadb36d0 Mon Sep 17 00:00:00 2001 From: Anastasiais-ml Date: Fri, 31 Jan 2025 16:27:10 +0100 Subject: [PATCH 4/4] update GLIP --- src/open_clip/factory.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 445171fb2..dbc5462bd 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -252,6 +252,8 @@ def create_model( force_quick_gelu: Force use of QuickGELU activation force_custom_text: Force use of custom text encoder force_patch_dropout: Override default patch dropout value + gaussian_masking: Enable Gaussian masking for patch dropout + gaussian_masking_std: Set Gaussian masking standard deviation force_image_size: Override default image size for vision encoder force_preprocess_cfg: Override default preprocessing configuration pretrained_image: Load pretrained weights for timm vision models @@ -321,12 +323,12 @@ def create_model( model_cfg["vision_cfg"]["image_size"] = force_image_size if gaussian_masking is not None: - # override the default masking - model_cfg["vision_cfg"]["gaussian_masking"] = gaussian_masking + # override the default masking + model_cfg["vision_cfg"]["gaussian_masking"] = gaussian_masking if gaussian_masking_std is not None: - # override the default gaussian masking std value - model_cfg["vision_cfg"]["gaussian_masking_std"] = gaussian_masking_std + # override the default gaussian masking std value + model_cfg["vision_cfg"]["gaussian_masking_std"] = gaussian_masking_std is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) if pretrained_image:
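
For reference, the core of this series is the centred (Gaussian) masking added to PatchDropout in PATCH 1/4: a bivariate Gaussian PDF is evaluated over the patch grid, and 1 - pdf is subtracted from the random per-patch scores so that the top-k selection favours patches near the image centre. Below is a minimal standalone sketch of that scoring, assuming the same grid_size/std semantics as the patch (note the patch passes std directly as the covariance diagonal); the function name centered_keep_indices is illustrative only and not part of the series.

import numpy as np
import torch
from scipy.stats import multivariate_normal

def centered_keep_indices(batch, grid_size=14, keep_prob=0.5, std=0.20):
    # Evaluate a bivariate Gaussian PDF over the patch grid (https://arxiv.org/abs/2403.15837).
    # As in the patch, `std` is used directly as the diagonal of the covariance matrix.
    x = np.linspace(-1, 1, grid_size)
    xx, yy = np.meshgrid(x, x)
    dist = multivariate_normal(mean=[0.0, 0.0], cov=[[std, 0.0], [0.0, std]])
    pdf = dist.pdf(np.column_stack([xx.flatten(), yy.flatten()]))  # shape (grid_size**2,)

    # 1 - pdf is smallest at the image centre; subtracting it from random scores
    # makes central patches win the top-k more often, i.e. centred masking.
    bias = torch.from_numpy(1.0 - pdf).float()

    num_tokens = grid_size ** 2
    num_keep = max(1, int(num_tokens * keep_prob))
    scores = torch.randn(batch, num_tokens) - bias
    return scores.topk(num_keep, dim=-1).indices

# e.g. keep 50% of a 14x14 grid for a batch of 8 images -> indices of shape (8, 98)
keep = centered_keep_indices(batch=8)
print(keep.shape)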
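
Once the series is applied, the new options can be exercised either through the factory kwargs or the training flags added in params.py. The sketch below assumes a recent open_clip checkout with these patches on top; the model name and data path are placeholders, and the options only take effect while patch_dropout > 0 and the model is in training mode.

import open_clip

# Factory-level use: the new kwargs are forwarded into vision_cfg and from there into PatchDropout.
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    'ViT-B-16',                 # placeholder architecture; any ViT config with patch dropout works
    force_patch_dropout=0.5,    # drop half of the patch tokens during training
    gaussian_masking=True,      # bias the retained patches toward the image centre
    gaussian_masking_std=0.20,  # matches the new --gaussian-masking-std default
)

# Training CLI equivalent (flags added in src/open_clip_train/params.py):
#   python -m open_clip_train.main --model ViT-B-16 --train-data <...> \
#       --force-patch-dropout 0.5 --gaussian-masking --gaussian-masking-std 0.20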