From bdf36dcd48106a4a0278ed7f3cc26cd65ab7b066 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 6 Jun 2024 22:02:38 +0100 Subject: [PATCH] Enable HF pretrained backbones (#31145) * Enable load HF or tim backbone checkpoints * Fix up * Fix test - pass in proper out_indices * Update docs * Fix tvp tests * Fix doc examples * Fix doc examples * Try to resolve DPT backbone param init * Don't conditionally set to None * Add condition based on whether backbone is defined * Address review comments --- docs/source/en/create_a_model.md | 54 ++++++++++++------- .../modeling_conditional_detr.py | 9 +++- .../modeling_deformable_detr.py | 9 +++- .../configuration_depth_anything.py | 2 + .../depth_anything/modeling_depth_anything.py | 6 +-- src/transformers/models/detr/modeling_detr.py | 9 +++- .../models/dpt/configuration_dpt.py | 52 +++++++++--------- src/transformers/models/dpt/modeling_dpt.py | 6 +-- .../grounding_dino/modeling_grounding_dino.py | 9 +++- .../modeling_table_transformer.py | 9 +++- .../timm_backbone/modeling_timm_backbone.py | 6 ++- src/transformers/models/tvp/modeling_tvp.py | 12 ++++- .../models/vitmatte/modeling_vitmatte.py | 7 ++- src/transformers/utils/backbone_utils.py | 5 -- .../test_modeling_conditional_detr.py | 36 +++++++++++++ .../test_modeling_deformable_detr.py | 33 +++++++++++- .../test_modeling_depth_anything.py | 29 ++++++++++ tests/models/detr/test_modeling_detr.py | 35 ++++++++++++ tests/models/dpt/test_modeling_dpt.py | 28 ++++++++++ .../test_modeling_grounding_dino.py | 28 ++++++++++ .../mask2former/test_modeling_mask2former.py | 32 +++++++++++ .../maskformer/test_modeling_maskformer.py | 32 +++++++++++ .../oneformer/test_modeling_oneformer.py | 32 +++++++++++ .../test_modeling_table_transformer.py | 32 +++++++++++ tests/models/tvp/test_modeling_tvp.py | 37 ++++++++++++- tests/models/upernet/test_modeling_upernet.py | 36 ++++++++++++- .../models/vitmatte/test_modeling_vitmatte.py | 30 
+++++++++++ 27 files changed, 546 insertions(+), 69 deletions(-) diff --git a/docs/source/en/create_a_model.md b/docs/source/en/create_a_model.md index 29f26c59984aa3..0ecc503df61533 100644 --- a/docs/source/en/create_a_model.md +++ b/docs/source/en/create_a_model.md @@ -327,31 +327,21 @@ For example, to load a [ResNet](../model_doc/resnet) backbone into a [MaskFormer Set `use_pretrained_backbone=True` to load pretrained ResNet weights for the backbone. ```py -from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation, ResNetConfig +from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation -config = MaskFormerConfig(backbone="microsoft/resnet50", use_pretrained_backbone=True) # backbone and neck config +config = MaskFormerConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=True) # backbone and neck config model = MaskFormerForInstanceSegmentation(config) # head ``` -You could also load the backbone config separately and then pass it to the model config. - -```py -from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation, ResNetConfig - -backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-50") -config = MaskFormerConfig(backbone_config=backbone_config) -model = MaskFormerForInstanceSegmentation(config) -``` - Set `use_pretrained_backbone=False` to randomly initialize a ResNet backbone. 
```py -from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation, ResNetConfig +from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation -config = MaskFormerConfig(backbone="microsoft/resnet50", use_pretrained_backbone=False) # backbone and neck config +config = MaskFormerConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=False) # backbone and neck config model = MaskFormerForInstanceSegmentation(config) # head ``` @@ -366,15 +356,43 @@ model = MaskFormerForInstanceSegmentation(config) ``` - + + +[timm](https://hf.co/docs/timm/index) models are loaded within a model with `use_timm_backbone=True` or with [`TimmBackbone`] and [`TimmBackboneConfig`]. -[timm](https://hf.co/docs/timm/index) models are loaded with [`TimmBackbone`] and [`TimmBackboneConfig`]. +Use `use_timm_backbone=True` and `use_pretrained_backbone=True` to load pretrained timm weights for the backbone. + +```python +from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation + +config = MaskFormerConfig(backbone="resnet50", use_pretrained_backbone=True, use_timm_backbone=True) # backbone and neck config +model = MaskFormerForInstanceSegmentation(config) # head +``` + +Set `use_timm_backbone=True` and `use_pretrained_backbone=False` to load a randomly initialized timm backbone. + +```python +from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation + +config = MaskFormerConfig(backbone="resnet50", use_pretrained_backbone=False, use_timm_backbone=True) # backbone and neck config +model = MaskFormerForInstanceSegmentation(config) # head +``` + +You could also load the backbone config and use it to create a `TimmBackbone` or pass it to the model config. Timm backbones will load pretrained weights by default. Set `use_pretrained_backbone=False` to load randomly initialized weights. 
```python from transformers import TimmBackboneConfig, TimmBackbone -backbone_config = TimmBackboneConfig("resnet50") -model = TimmBackbone(config=backbone_config) +backbone_config = TimmBackboneConfig("resnet50", use_pretrained_backbone=False) + +# Create a backbone class +backbone = TimmBackbone(config=backbone_config) + +# Create a model with a timm backbone +from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation + +config = MaskFormerConfig(backbone_config=backbone_config) +model = MaskFormerForInstanceSegmentation(config) ``` ## Feature extractor diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index aa905d9e960ae9..e72daa64713e8c 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -378,7 +378,14 @@ def __init__(self, config): self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels ) - backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type + backbone_model_type = None + if config.backbone is not None: + backbone_model_type = config.backbone + elif config.backbone_config is not None: + backbone_model_type = config.backbone_config.model_type + else: + raise ValueError("Either `backbone` or `backbone_config` should be provided in the config") + if "resnet" in backbone_model_type: for name, parameter in self.model.named_parameters(): if config.use_timm_backbone: diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index f619575bd81452..4920262443035d 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -449,7 +449,14 @@ def __init__(self, config): 
self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels ) - backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type + backbone_model_type = None + if config.backbone is not None: + backbone_model_type = config.backbone + elif config.backbone_config is not None: + backbone_model_type = config.backbone_config.model_type + else: + raise ValueError("Either `backbone` or `backbone_config` should be provided in the config") + if "resnet" in backbone_model_type: for name, parameter in self.model.named_parameters(): if config.use_timm_backbone: diff --git a/src/transformers/models/depth_anything/configuration_depth_anything.py b/src/transformers/models/depth_anything/configuration_depth_anything.py index 77727e65a0bfe3..78ccbc381dc21d 100644 --- a/src/transformers/models/depth_anything/configuration_depth_anything.py +++ b/src/transformers/models/depth_anything/configuration_depth_anything.py @@ -129,6 +129,8 @@ def __init__( self.backbone_config = backbone_config self.backbone = backbone self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs self.reassemble_hidden_size = reassemble_hidden_size self.patch_size = patch_size self.initializer_range = initializer_range diff --git a/src/transformers/models/depth_anything/modeling_depth_anything.py b/src/transformers/models/depth_anything/modeling_depth_anything.py index 493d59bd4439e2..0b1ef77c6a732a 100644 --- a/src/transformers/models/depth_anything/modeling_depth_anything.py +++ b/src/transformers/models/depth_anything/modeling_depth_anything.py @@ -28,7 +28,7 @@ from ...modeling_outputs import DepthEstimatorOutput from ...modeling_utils import PreTrainedModel from ...utils import logging -from ..auto import AutoBackbone +from ...utils.backbone_utils import load_backbone from .configuration_depth_anything import DepthAnythingConfig @@ -365,9 +365,7 @@ class 
DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel): def __init__(self, config): super().__init__(config) - self.backbone = AutoBackbone.from_config( - config.backbone_config, attn_implementation=config._attn_implementation - ) + self.backbone = load_backbone(config) self.neck = DepthAnythingNeck(config) self.head = DepthAnythingDepthEstimationHead(config) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index ff8b1416b06770..447f8a807fcb66 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -373,7 +373,14 @@ def __init__(self, config): self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels ) - backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type + backbone_model_type = None + if config.backbone is not None: + backbone_model_type = config.backbone + elif config.backbone_config is not None: + backbone_model_type = config.backbone_config.model_type + else: + raise ValueError("Either `backbone` or `backbone_config` should be provided in the config") + if "resnet" in backbone_model_type: for name, parameter in self.model.named_parameters(): if config.use_timm_backbone: diff --git a/src/transformers/models/dpt/configuration_dpt.py b/src/transformers/models/dpt/configuration_dpt.py index 308f2647ee68d8..869f384f56985e 100644 --- a/src/transformers/models/dpt/configuration_dpt.py +++ b/src/transformers/models/dpt/configuration_dpt.py @@ -182,8 +182,7 @@ def __init__( use_autobackbone = False if self.is_hybrid: - if backbone_config is None and backbone is None: - logger.info("Initializing the config with a `BiT` backbone.") + if backbone_config is None: backbone_config = { "global_padding": "same", "layer_type": "bottleneck", @@ -191,8 +190,8 @@ def __init__( "out_features": ["stage1", "stage2", "stage3"], "embedding_dynamic_padding": True, } - backbone_config = 
BitConfig(**backbone_config) - elif isinstance(backbone_config, dict): + + if isinstance(backbone_config, dict): logger.info("Initializing the config with a `BiT` backbone.") backbone_config = BitConfig(**backbone_config) elif isinstance(backbone_config, PretrainedConfig): @@ -208,9 +207,8 @@ def __init__( if readout_type != "project": raise ValueError("Readout type must be 'project' when using `DPT-hybrid` mode.") - elif backbone_config is not None: + elif backbone is not None or backbone_config is not None: use_autobackbone = True - if isinstance(backbone_config, dict): backbone_model_type = backbone_config.get("model_type") config_class = CONFIG_MAPPING[backbone_model_type] @@ -219,33 +217,37 @@ def __init__( self.backbone_config = backbone_config self.backbone_featmap_shape = None self.neck_ignore_stages = [] + + # We only use load_backbone when config.is_hybrid is False + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) else: - self.backbone_config = backbone_config + self.backbone_config = None self.backbone_featmap_shape = None self.neck_ignore_stages = [] - verify_backbone_config_arguments( - use_timm_backbone=use_timm_backbone, - use_pretrained_backbone=use_pretrained_backbone, - backbone=backbone, - backbone_config=backbone_config, - backbone_kwargs=backbone_kwargs, - ) - self.backbone = backbone self.use_pretrained_backbone = use_pretrained_backbone self.use_timm_backbone = use_timm_backbone self.backbone_kwargs = backbone_kwargs - self.num_hidden_layers = None if use_autobackbone else num_hidden_layers - self.num_attention_heads = None if use_autobackbone else num_attention_heads - self.intermediate_size = None if use_autobackbone else intermediate_size - self.hidden_dropout_prob = None if use_autobackbone else hidden_dropout_prob - self.attention_probs_dropout_prob = None if 
use_autobackbone else attention_probs_dropout_prob - self.layer_norm_eps = None if use_autobackbone else layer_norm_eps - self.image_size = None if use_autobackbone else image_size - self.patch_size = None if use_autobackbone else patch_size - self.num_channels = None if use_autobackbone else num_channels - self.qkv_bias = None if use_autobackbone else qkv_bias + + # ViT parameters used if not using a hybrid backbone + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.use_autobackbone = use_autobackbone self.backbone_out_indices = None if use_autobackbone else backbone_out_indices if readout_type not in ["ignore", "add", "project"]: diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index fce944f7a7dea3..a7e554742f2de2 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -1071,10 +1071,10 @@ def __init__(self, config): super().__init__(config) self.backbone = None - if config.is_hybrid or config.backbone_config is None: - self.dpt = DPTModel(config, add_pooling_layer=False) - else: + if config.is_hybrid is False and (config.backbone_config is not None or config.backbone is not None): self.backbone = load_backbone(config) + else: + self.dpt = DPTModel(config, add_pooling_layer=False) # Neck self.neck = DPTNeck(config) diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 1afe3ad44c4ace..dcdccc50cc116d 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ 
b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -463,7 +463,14 @@ def __init__(self, config): self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels ) - backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type + backbone_model_type = None + if config.backbone is not None: + backbone_model_type = config.backbone + elif config.backbone_config is not None: + backbone_model_type = config.backbone_config.model_type + else: + raise ValueError("Either `backbone` or `backbone_config` should be provided in the config") + if "resnet" in backbone_model_type: for name, parameter in self.model.named_parameters(): if config.use_timm_backbone: diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index a8c2593752d36e..1ebb6cd53bdc87 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -305,7 +305,14 @@ def __init__(self, config): self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels ) - backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type + backbone_model_type = None + if config.backbone is not None: + backbone_model_type = config.backbone + elif config.backbone_config is not None: + backbone_model_type = config.backbone_config.model_type + else: + raise ValueError("Either `backbone` or `backbone_config` should be provided in the config") + if "resnet" in backbone_model_type: for name, parameter in self.model.named_parameters(): if config.use_timm_backbone: diff --git a/src/transformers/models/timm_backbone/modeling_timm_backbone.py b/src/transformers/models/timm_backbone/modeling_timm_backbone.py index e8e0b28e042d6f..74e7388b7dcab5 100644 --- 
a/src/transformers/models/timm_backbone/modeling_timm_backbone.py +++ b/src/transformers/models/timm_backbone/modeling_timm_backbone.py @@ -50,8 +50,10 @@ def __init__(self, config, **kwargs): if config.backbone is None: raise ValueError("backbone is not set in the config. Please set it to a timm model name.") - if config.backbone not in timm.list_models(): - raise ValueError(f"backbone {config.backbone} is not supported by timm.") + # Certain timm models have the structure `model_name.version` e.g. vit_large_patch14_dinov2.lvd142m + base_backbone_model = config.backbone.split(".")[0] + if base_backbone_model not in timm.list_models(): + raise ValueError(f"backbone {base_backbone_model} is not supported by timm.") if hasattr(config, "out_features") and config.out_features is not None: raise ValueError("out_features is not supported by TimmBackbone. Please use out_indices instead.") diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index 121cd7b5f3f375..ba9acdbbcf93f7 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -143,8 +143,18 @@ class TvpVisionModel(nn.Module): def __init__(self, config): super().__init__() self.backbone = load_backbone(config) + + if config.backbone_config is not None: + in_channels = config.backbone_config.hidden_sizes[-1] + elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_sizes"): + in_channels = self.backbone.config.hidden_sizes[-1] + elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_size"): + in_channels = self.backbone.config.hidden_size + else: + raise ValueError("Backbone config not found") + self.grid_encoder_conv = nn.Conv2d( - config.backbone_config.hidden_sizes[-1], + in_channels, config.hidden_size, kernel_size=3, stride=1, diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index 
7d47601b667d34..fb18ed6e789c2e 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -115,7 +115,12 @@ class VitMatteConvStream(nn.Module): def __init__(self, config): super().__init__() - in_channels = config.backbone_config.num_channels + # We use a default in case there isn't a backbone config set. This is for backwards compatibility and + # to enable loading HF backbone models. + in_channels = 4 + if config.backbone_config is not None: + in_channels = config.backbone_config.num_channels + out_channels = config.convstream_hidden_sizes self.convs = nn.ModuleList() diff --git a/src/transformers/utils/backbone_utils.py b/src/transformers/utils/backbone_utils.py index 5dbd8f3a10dec4..484c4e56c111d0 100644 --- a/src/transformers/utils/backbone_utils.py +++ b/src/transformers/utils/backbone_utils.py @@ -368,11 +368,6 @@ def verify_backbone_config_arguments( """ Verify that the config arguments to be passed to load_backbone are valid """ - if not use_timm_backbone and use_pretrained_backbone: - raise ValueError( - "Loading pretrained backbone weights from the transformers library is not supported yet. 
`use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`" - ) - if backbone_config is not None and backbone is not None: raise ValueError("You can't specify both `backbone` and `backbone_config`.") diff --git a/tests/models/conditional_detr/test_modeling_conditional_detr.py b/tests/models/conditional_detr/test_modeling_conditional_detr.py index 6fea53fc667a95..9efde402b0760f 100644 --- a/tests/models/conditional_detr/test_modeling_conditional_detr.py +++ b/tests/models/conditional_detr/test_modeling_conditional_detr.py @@ -476,6 +476,42 @@ def test_different_timm_backbone(self): self.assertTrue(outputs) + @require_timm + def test_hf_backbone(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Load a pretrained HF checkpoint as backbone + config.backbone = "microsoft/resnet-18" + config.backbone_config = None + config.use_timm_backbone = False + config.use_pretrained_backbone = True + config.backbone_kwargs = {"out_indices": [2, 3, 4]} + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if model_class.__name__ == "ConditionalDetrForObjectDetection": + expected_shape = ( + self.model_tester.batch_size, + self.model_tester.num_queries, + self.model_tester.num_labels, + ) + self.assertEqual(outputs.logits.shape, expected_shape) + # Confirm out_indices was propogated to backbone + self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3) + elif model_class.__name__ == "ConditionalDetrForSegmentation": + # Confirm out_indices was propogated to backbone + self.assertEqual(len(model.conditional_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3) + else: + # Confirm out_indices was propogated to backbone + self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3) + + 
self.assertTrue(outputs) + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/deformable_detr/test_modeling_deformable_detr.py b/tests/models/deformable_detr/test_modeling_deformable_detr.py index 2ae3e3f088c29e..b1fc9b23b3c20c 100644 --- a/tests/models/deformable_detr/test_modeling_deformable_detr.py +++ b/tests/models/deformable_detr/test_modeling_deformable_detr.py @@ -544,9 +544,38 @@ def test_different_timm_backbone(self): self.assertEqual(outputs.logits.shape, expected_shape) # Confirm out_indices was propogated to backbone self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 4) - elif model_class.__name__ == "ConditionalDetrForSegmentation": + else: + # Confirm out_indices was propogated to backbone + self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 4) + + self.assertTrue(outputs) + + def test_hf_backbone(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Load a pretrained HF checkpoint as backbone + config.backbone = "microsoft/resnet-18" + config.backbone_config = None + config.use_timm_backbone = False + config.use_pretrained_backbone = True + config.backbone_kwargs = {"out_indices": [1, 2, 3, 4]} + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if model_class.__name__ == "DeformableDetrForObjectDetection": + expected_shape = ( + self.model_tester.batch_size, + self.model_tester.num_queries, + self.model_tester.num_labels, + ) + self.assertEqual(outputs.logits.shape, expected_shape) # Confirm out_indices was propogated to backbone - self.assertEqual(len(model.deformable_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 4) + 
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 4) else: # Confirm out_indices was propogated to backbone self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 4) diff --git a/tests/models/depth_anything/test_modeling_depth_anything.py b/tests/models/depth_anything/test_modeling_depth_anything.py index ef1326520aedab..7a08ecb85e7552 100644 --- a/tests/models/depth_anything/test_modeling_depth_anything.py +++ b/tests/models/depth_anything/test_modeling_depth_anything.py @@ -207,6 +207,35 @@ def test_model_from_pretrained(self): model = DepthAnythingForDepthEstimation.from_pretrained(model_name) self.assertIsNotNone(model) + def test_backbone_selection(self): + def _validate_backbone_init(): + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + # Confirm out_indices propogated to backbone + self.assertEqual(len(model.backbone.out_indices), 2) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Load a timm backbone + config.backbone = "resnet18" + config.use_pretrained_backbone = True + config.use_timm_backbone = True + config.backbone_config = None + # For transformer backbones we can't set the out_indices or just return the features + config.backbone_kwargs = {"out_indices": (-2, -1)} + _validate_backbone_init() + + # Load a HF backbone + config.backbone = "facebook/dinov2-small" + config.use_pretrained_backbone = True + config.use_timm_backbone = False + config.backbone_config = None + config.backbone_kwargs = {"out_indices": [-2, -1]} + _validate_backbone_init() + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/detr/test_modeling_detr.py b/tests/models/detr/test_modeling_detr.py index f6277cced35f99..0aea506b5fb8c2 100644 --- a/tests/models/detr/test_modeling_detr.py +++ b/tests/models/detr/test_modeling_detr.py @@ -476,6 +476,41 @@ def 
test_different_timm_backbone(self): self.assertTrue(outputs) + def test_hf_backbone(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Load a pretrained HF checkpoint as backbone + config.backbone = "microsoft/resnet-18" + config.backbone_config = None + config.use_timm_backbone = False + config.use_pretrained_backbone = True + config.backbone_kwargs = {"out_indices": [2, 3, 4]} + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if model_class.__name__ == "DetrForObjectDetection": + expected_shape = ( + self.model_tester.batch_size, + self.model_tester.num_queries, + self.model_tester.num_labels + 1, + ) + self.assertEqual(outputs.logits.shape, expected_shape) + # Confirm out_indices was propogated to backbone + self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3) + elif model_class.__name__ == "DetrForSegmentation": + # Confirm out_indices was propogated to backbone + self.assertEqual(len(model.detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3) + else: + # Confirm out_indices was propogated to backbone + self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3) + + self.assertTrue(outputs) + def test_greyscale_images(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/dpt/test_modeling_dpt.py b/tests/models/dpt/test_modeling_dpt.py index 8c6231bc1c41d1..db35483d962e3c 100644 --- a/tests/models/dpt/test_modeling_dpt.py +++ b/tests/models/dpt/test_modeling_dpt.py @@ -276,6 +276,34 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + def test_backbone_selection(self): + def _validate_backbone_init(): + for model_class in self.all_model_classes: + model = 
model_class(config) + model.to(torch_device) + model.eval() + + if model.__class__.__name__ == "DPTForDepthEstimation": + # Confirm out_indices propogated to backbone + self.assertEqual(len(model.backbone.out_indices), 2) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_pretrained_backbone = True + config.backbone_config = None + config.backbone_kwargs = {"out_indices": [-2, -1]} + # Force load_backbone path + config.is_hybrid = False + + # Load a timm backbone + config.backbone = "resnet18" + config.use_timm_backbone = True + _validate_backbone_init() + + # Load a HF backbone + config.backbone = "facebook/dinov2-small" + config.use_timm_backbone = False + _validate_backbone_init() + @slow def test_model_from_pretrained(self): model_name = "Intel/dpt-large" diff --git a/tests/models/grounding_dino/test_modeling_grounding_dino.py b/tests/models/grounding_dino/test_modeling_grounding_dino.py index 12f80260cb3d3d..fe67deb9383cf8 100644 --- a/tests/models/grounding_dino/test_modeling_grounding_dino.py +++ b/tests/models/grounding_dino/test_modeling_grounding_dino.py @@ -501,6 +501,34 @@ def test_different_timm_backbone(self): self.assertTrue(outputs) + @require_timm + def test_hf_backbone(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Load a pretrained HF checkpoint as backbone + config.backbone = "microsoft/resnet-18" + config.backbone_config = None + config.use_timm_backbone = False + config.use_pretrained_backbone = True + config.backbone_kwargs = {"out_indices": [2, 3, 4]} + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if model_class.__name__ == "GroundingDinoForObjectDetection": + expected_shape = ( + self.model_tester.batch_size, + self.model_tester.num_queries, + config.max_text_len, + ) + 
self.assertEqual(outputs.logits.shape, expected_shape) + + self.assertTrue(outputs) + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/mask2former/test_modeling_mask2former.py b/tests/models/mask2former/test_modeling_mask2former.py index 100cbafa05380d..1065607e0be8e4 100644 --- a/tests/models/mask2former/test_modeling_mask2former.py +++ b/tests/models/mask2former/test_modeling_mask2former.py @@ -21,6 +21,7 @@ from tests.test_modeling_common import floats_tensor from transformers import Mask2FormerConfig, is_torch_available, is_vision_available from transformers.testing_utils import ( + require_timm, require_torch, require_torch_accelerator, require_torch_fp16, @@ -317,6 +318,37 @@ def test_retain_grad_hidden_states_attentions(self): self.assertIsNotNone(transformer_decoder_hidden_states.grad) self.assertIsNotNone(attentions.grad) + @require_timm + def test_backbone_selection(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + config.backbone_config = None + config.backbone_kwargs = {"out_indices": [1, 2, 3]} + config.use_pretrained_backbone = True + + # Load a timm backbone + # We can't load transformer checkpoint with timm backbone, as we can't specify features_only and out_indices + config.backbone = "resnet18" + config.use_timm_backbone = True + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device).eval() + if model.__class__.__name__ == "Mask2FormerModel": + self.assertEqual(model.pixel_level_module.encoder.out_indices, [1, 2, 3]) + elif model.__class__.__name__ == "Mask2FormerForUniversalSegmentation": + self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3]) + + # Load a HF backbone + config.backbone = "microsoft/resnet-18" + config.use_timm_backbone = False + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device).eval() + if 
model.__class__.__name__ == "Mask2FormerModel": + self.assertEqual(model.pixel_level_module.encoder.out_indices, [1, 2, 3]) + elif model.__class__.__name__ == "Mask2FormerForUniversalSegmentation": + self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3]) + TOLERANCE = 1e-4 diff --git a/tests/models/maskformer/test_modeling_maskformer.py b/tests/models/maskformer/test_modeling_maskformer.py index 7b2bec17f457bf..4c9c69ed5ff75f 100644 --- a/tests/models/maskformer/test_modeling_maskformer.py +++ b/tests/models/maskformer/test_modeling_maskformer.py @@ -22,6 +22,7 @@ from tests.test_modeling_common import floats_tensor from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available from transformers.testing_utils import ( + require_timm, require_torch, require_torch_accelerator, require_torch_fp16, @@ -444,6 +445,37 @@ def recursive_check(batched_object, single_row_object, model_name, key): continue recursive_check(model_batched_output[key], model_row_output[key], model_name, key) + @require_timm + def test_backbone_selection(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + config.backbone_config = None + config.backbone_kwargs = {"out_indices": [1, 2, 3]} + config.use_pretrained_backbone = True + + # Load a timm backbone + # We can't load transformer checkpoint with timm backbone, as we can't specify features_only and out_indices + config.backbone = "resnet18" + config.use_timm_backbone = True + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device).eval() + if model.__class__.__name__ == "MaskFormerModel": + self.assertEqual(model.pixel_level_module.encoder.out_indices, [1, 2, 3]) + elif model.__class__.__name__ == "MaskFormerForInstanceSegmentation": + self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3]) + + # Load a HF backbone + config.backbone = "microsoft/resnet-18" + config.use_timm_backbone 
= False + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device).eval() + if model.__class__.__name__ == "MaskFormerModel": + self.assertEqual(model.pixel_level_module.encoder.out_indices, [1, 2, 3]) + elif model.__class__.__name__ == "MaskFormerForInstanceSegmentation": + self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3]) + TOLERANCE = 1e-4 diff --git a/tests/models/oneformer/test_modeling_oneformer.py b/tests/models/oneformer/test_modeling_oneformer.py index 9cdc475faec91e..b5bb55cb48a9de 100644 --- a/tests/models/oneformer/test_modeling_oneformer.py +++ b/tests/models/oneformer/test_modeling_oneformer.py @@ -23,6 +23,7 @@ from tests.test_modeling_common import floats_tensor from transformers import OneFormerConfig, is_torch_available, is_vision_available from transformers.testing_utils import ( + require_timm, require_torch, require_torch_accelerator, require_torch_fp16, @@ -446,6 +447,37 @@ def test_retain_grad_hidden_states_attentions(self): self.assertIsNotNone(transformer_decoder_mask_predictions.grad) self.assertIsNotNone(attentions.grad) + @require_timm + def test_backbone_selection(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + config.backbone_config = None + config.backbone_kwargs = {"out_indices": [1, 2, 3]} + config.use_pretrained_backbone = True + + # Load a timm backbone + # We can't load transformer checkpoint with timm backbone, as we can't specify features_only and out_indices + config.backbone = "resnet18" + config.use_timm_backbone = True + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device).eval() + if model.__class__.__name__ == "OneFormerModel": + self.assertEqual(model.pixel_level_module.encoder.out_indices, [1, 2, 3]) + elif model.__class__.__name__ == "OneFormerForUniversalSegmentation": + self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3]) + + # Load a HF backbone + 
config.backbone = "microsoft/resnet-18" + config.use_timm_backbone = False + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device).eval() + if model.__class__.__name__ == "OneFormerModel": + self.assertEqual(model.pixel_level_module.encoder.out_indices, [1, 2, 3]) + elif model.__class__.__name__ == "OneFormerForUniversalSegmentation": + self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3]) + TOLERANCE = 1e-4 diff --git a/tests/models/table_transformer/test_modeling_table_transformer.py b/tests/models/table_transformer/test_modeling_table_transformer.py index e41b53a21f4e94..f6cef9e8fe4033 100644 --- a/tests/models/table_transformer/test_modeling_table_transformer.py +++ b/tests/models/table_transformer/test_modeling_table_transformer.py @@ -485,6 +485,38 @@ def test_different_timm_backbone(self): self.assertTrue(outputs) + def test_hf_backbone(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Load a pretrained HF checkpoint as backbone + config.backbone = "microsoft/resnet-18" + config.backbone_config = None + config.use_timm_backbone = False + config.use_pretrained_backbone = True + config.backbone_kwargs = {"out_indices": [2, 3, 4]} + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if model_class.__name__ == "TableTransformerForObjectDetection": + expected_shape = ( + self.model_tester.batch_size, + self.model_tester.num_queries, + self.model_tester.num_labels + 1, + ) + self.assertEqual(outputs.logits.shape, expected_shape) + # Confirm out_indices was propagated to backbone + self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3) + else: + # Confirm out_indices was propagated to backbone + 
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3) + + self.assertTrue(outputs) + def test_greyscale_images(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/tvp/test_modeling_tvp.py b/tests/models/tvp/test_modeling_tvp.py index 90050c3bdfd4b5..a7db94c1a48969 100644 --- a/tests/models/tvp/test_modeling_tvp.py +++ b/tests/models/tvp/test_modeling_tvp.py @@ -16,8 +16,8 @@ import unittest -from transformers import ResNetConfig, TvpConfig -from transformers.testing_utils import require_torch, require_vision, torch_device +from transformers import ResNetConfig, TimmBackboneConfig, TvpConfig +from transformers.testing_utils import require_timm, require_torch, require_vision, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_modeling_common import ( @@ -211,6 +211,39 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + @require_timm + def test_backbone_selection(self): + def _validate_backbone_init(): + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + # Confirm out_indices propagated to backbone + if model.__class__.__name__ == "TvpModel": + self.assertEqual(len(model.vision_model.backbone.out_indices), 2) + elif model.__class__.__name__ == "TvpForVideoGrounding": + self.assertEqual(len(model.model.vision_model.backbone.out_indices), 2) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # Force load_backbone path + config.is_hybrid = False + + # We load through configs, as the modeling file assumes config.backbone_config is always set + config.use_pretrained_backbone = False + config.backbone_kwargs = None + + # Load a timm backbone + # We hack adding hidden_sizes to the config to test the backbone loading + backbone_config = TimmBackboneConfig("resnet18", 
out_indices=[-2, -1], hidden_sizes=[64, 128]) + config.backbone_config = backbone_config + _validate_backbone_init() + + # Load a HF backbone + backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-18", out_indices=[-2, -1]) + config.backbone_config = backbone_config + _validate_backbone_init() + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/upernet/test_modeling_upernet.py b/tests/models/upernet/test_modeling_upernet.py index 79fda279fa7758..820e82acbf342d 100644 --- a/tests/models/upernet/test_modeling_upernet.py +++ b/tests/models/upernet/test_modeling_upernet.py @@ -19,7 +19,14 @@ from huggingface_hub import hf_hub_download from transformers import ConvNextConfig, UperNetConfig -from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device +from transformers.testing_utils import ( + require_timm, + require_torch, + require_torch_multi_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -240,6 +247,33 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + @require_timm + def test_backbone_selection(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + config.backbone_config = None + config.backbone_kwargs = {"out_indices": [1, 2, 3]} + config.use_pretrained_backbone = True + + # Load a timm backbone + # We can't load transformer checkpoint with timm backbone, as we can't specify features_only and out_indices + config.backbone = "resnet18" + config.use_timm_backbone = True + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device).eval() + if model.__class__.__name__ == "UperNetForSemanticSegmentation": + self.assertEqual(model.backbone.out_indices, [1, 2, 3]) + + # Load a HF backbone + 
config.backbone = "microsoft/resnet-18" + config.use_timm_backbone = False + + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device).eval() + if model.__class__.__name__ == "UperNetForSemanticSegmentation": + self.assertEqual(model.backbone.out_indices, [1, 2, 3]) + @unittest.skip(reason="UperNet does not have tied weights") def test_tied_model_weights_key_ignore(self): pass diff --git a/tests/models/vitmatte/test_modeling_vitmatte.py b/tests/models/vitmatte/test_modeling_vitmatte.py index ccdefe957c7401..07be1edeb6325a 100644 --- a/tests/models/vitmatte/test_modeling_vitmatte.py +++ b/tests/models/vitmatte/test_modeling_vitmatte.py @@ -20,6 +20,7 @@ from transformers import VitMatteConfig from transformers.testing_utils import ( + require_timm, require_torch, slow, torch_device, @@ -236,6 +237,35 @@ def check_hidden_states_output(inputs_dict, config, model_class): check_hidden_states_output(inputs_dict, config, model_class) + @require_timm + def test_backbone_selection(self): + def _validate_backbone_init(): + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + if model.__class__.__name__ == "VitMatteForImageMatting": + # Confirm out_indices propagated to backbone + self.assertEqual(len(model.backbone.out_indices), 2) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_pretrained_backbone = True + config.backbone_config = None + config.backbone_kwargs = {"out_indices": [-2, -1]} + # Force load_backbone path + config.is_hybrid = False + + # Load a timm backbone + config.backbone = "resnet18" + config.use_timm_backbone = True + _validate_backbone_init() + + # Load a HF backbone + config.backbone = "facebook/dinov2-small" + config.use_timm_backbone = False + _validate_backbone_init() + @require_torch class VitMatteModelIntegrationTest(unittest.TestCase):