Enable HF pretrained backbones (huggingface#31145)
* Enable loading HF or timm backbone checkpoints

* Fix up

* Fix test - pass in proper out_indices

* Update docs

* Fix tvp tests

* Fix doc examples

* Fix doc examples

* Try to resolve DPT backbone param init

* Don't conditionally set to None

* Add condition based on whether backbone is defined

* Address review comments
amyeroberts authored Jun 6, 2024
1 parent a3d351c commit bdf36dc
Showing 27 changed files with 546 additions and 69 deletions.
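The headline change is that `use_pretrained_backbone=True` no longer requires a timm backbone: pretrained backbone weights can now come straight from a Hugging Face Hub checkpoint. A minimal sketch of the enabled usage, assuming the `MaskFormerConfig` arguments shown in the documentation diff below:

```python
from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

# Pretrained ResNet weights are pulled from the Hub checkpoint rather than timm;
# before this commit, use_pretrained_backbone=True required use_timm_backbone=True.
config = MaskFormerConfig(
    backbone="microsoft/resnet-50",
    use_pretrained_backbone=True,
    use_timm_backbone=False,
)
model = MaskFormerForInstanceSegmentation(config)
```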
54 changes: 36 additions & 18 deletions docs/source/en/create_a_model.md
@@ -327,31 +327,21 @@ For example, to load a [ResNet](../model_doc/resnet) backbone into a [MaskFormer
 Set `use_pretrained_backbone=True` to load pretrained ResNet weights for the backbone.
 
 ```py
-from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation, ResNetConfig
+from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation
 
-config = MaskFormerConfig(backbone="microsoft/resnet50", use_pretrained_backbone=True) # backbone and neck config
+config = MaskFormerConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=True) # backbone and neck config
 model = MaskFormerForInstanceSegmentation(config) # head
 ```
 
-You could also load the backbone config separately and then pass it to the model config.
-
-```py
-from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation, ResNetConfig
-
-backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-50")
-config = MaskFormerConfig(backbone_config=backbone_config)
-model = MaskFormerForInstanceSegmentation(config)
-```
-
 </hfoption>
 <hfoption id="random weights">
 
 Set `use_pretrained_backbone=False` to randomly initialize a ResNet backbone.
 
 ```py
-from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation, ResNetConfig
+from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation
 
-config = MaskFormerConfig(backbone="microsoft/resnet50", use_pretrained_backbone=False) # backbone and neck config
+config = MaskFormerConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=False) # backbone and neck config
 model = MaskFormerForInstanceSegmentation(config) # head
 ```
 
@@ -366,15 +356,43 @@ model = MaskFormerForInstanceSegmentation(config)
 ```
 
 </hfoption>
-</hfoptions>
+</hfoptions id="timm backbone">
 
-[timm](https://hf.co/docs/timm/index) models are loaded with [`TimmBackbone`] and [`TimmBackboneConfig`].
+[timm](https://hf.co/docs/timm/index) models are loaded within a model with `use_timm_backbone=True` or with [`TimmBackbone`] and [`TimmBackboneConfig`].
+
+Use `use_timm_backbone=True` and `use_pretrained_backbone=True` to load pretrained timm weights for the backbone.
+
+```python
+from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation
+
+config = MaskFormerConfig(backbone="resnet50", use_pretrained_backbone=True, use_timm_backbone=True) # backbone and neck config
+model = MaskFormerForInstanceSegmentation(config) # head
+```
+
+Set `use_timm_backbone=True` and `use_pretrained_backbone=False` to load a randomly initialized timm backbone.
+
+```python
+from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation
+
+config = MaskFormerConfig(backbone="resnet50", use_pretrained_backbone=False, use_timm_backbone=True) # backbone and neck config
+model = MaskFormerForInstanceSegmentation(config) # head
+```
+
+You could also load the backbone config and use it to create a `TimmBackbone` or pass it to the model config. Timm backbones will load pretrained weights by default. Set `use_pretrained_backbone=False` to load randomly initialized weights.
 
 ```python
 from transformers import TimmBackboneConfig, TimmBackbone
 
-backbone_config = TimmBackboneConfig("resnet50")
-model = TimmBackbone(config=backbone_config)
+backbone_config = TimmBackboneConfig("resnet50", use_pretrained_backbone=False)
+
+# Create a backbone class
+backbone = TimmBackbone(config=backbone_config)
+
+# Create a model with a timm backbone
+from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation
+
+config = MaskFormerConfig(backbone_config=backbone_config)
+model = MaskFormerForInstanceSegmentation(config)
 ```
 
 ## Feature extractor
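The `out_indices` handling mentioned in the commit message ("Fix test - pass in proper out_indices") also applies when the backbone is an HF checkpoint: `backbone_kwargs` on the model config is forwarded to the backbone when it is built. A hedged sketch mirroring the new test at the bottom of the page (the MaskFormer head here is illustrative only):

```python
from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

config = MaskFormerConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=True)
# Forwarded to the backbone, e.g. to choose which feature maps it returns
config.backbone_kwargs = {"out_indices": [1, 2, 3, 4]}
model = MaskFormerForInstanceSegmentation(config)
```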
@@ -378,7 +378,14 @@ def __init__(self, config):
             self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
         )
 
-        backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
+        backbone_model_type = None
+        if config.backbone is not None:
+            backbone_model_type = config.backbone
+        elif config.backbone_config is not None:
+            backbone_model_type = config.backbone_config.model_type
+        else:
+            raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
+
         if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
                 if config.use_timm_backbone:
@@ -449,7 +449,14 @@ def __init__(self, config):
             self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
         )
 
-        backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
+        backbone_model_type = None
+        if config.backbone is not None:
+            backbone_model_type = config.backbone
+        elif config.backbone_config is not None:
+            backbone_model_type = config.backbone_config.model_type
+        else:
+            raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
+
         if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
                 if config.use_timm_backbone:
@@ -129,6 +129,8 @@ def __init__(
         self.backbone_config = backbone_config
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
+        self.backbone_kwargs = backbone_kwargs
         self.reassemble_hidden_size = reassemble_hidden_size
         self.patch_size = patch_size
         self.initializer_range = initializer_range
@@ -28,7 +28,7 @@
 from ...modeling_outputs import DepthEstimatorOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
 from .configuration_depth_anything import DepthAnythingConfig
 
 
@@ -365,9 +365,7 @@ class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.backbone = AutoBackbone.from_config(
-            config.backbone_config, attn_implementation=config._attn_implementation
-        )
+        self.backbone = load_backbone(config)
         self.neck = DepthAnythingNeck(config)
         self.head = DepthAnythingDepthEstimationHead(config)
 
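The pattern above — swapping `AutoBackbone.from_config` for `load_backbone(config)` — is what lets a model accept either a `backbone_config` or a `backbone` checkpoint name. A rough sketch of how the helper is driven; `DepthAnythingConfig` and the `facebook/dinov2-base` checkpoint are used purely for illustration:

```python
from transformers import DepthAnythingConfig
from transformers.utils.backbone_utils import load_backbone

# Default behaviour: the backbone is built from config.backbone_config, as before.
config = DepthAnythingConfig()
backbone = load_backbone(config)

# New behaviour: a Hub checkpoint name plus use_pretrained_backbone=True loads
# pretrained transformers weights for the backbone (no timm involved).
config = DepthAnythingConfig(backbone="facebook/dinov2-base", use_pretrained_backbone=True)
backbone = load_backbone(config)
```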
9 changes: 8 additions & 1 deletion src/transformers/models/detr/modeling_detr.py
@@ -373,7 +373,14 @@ def __init__(self, config):
             self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
         )
 
-        backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
+        backbone_model_type = None
+        if config.backbone is not None:
+            backbone_model_type = config.backbone
+        elif config.backbone_config is not None:
+            backbone_model_type = config.backbone_config.model_type
+        else:
+            raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
+
         if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
                 if config.use_timm_backbone:
52 changes: 27 additions & 25 deletions src/transformers/models/dpt/configuration_dpt.py
@@ -182,17 +182,16 @@ def __init__(
 
         use_autobackbone = False
         if self.is_hybrid:
-            if backbone_config is None and backbone is None:
-                logger.info("Initializing the config with a `BiT` backbone.")
+            if backbone_config is None:
                 backbone_config = {
                     "global_padding": "same",
                     "layer_type": "bottleneck",
                     "depths": [3, 4, 9],
                     "out_features": ["stage1", "stage2", "stage3"],
                     "embedding_dynamic_padding": True,
                 }
-                backbone_config = BitConfig(**backbone_config)
-            elif isinstance(backbone_config, dict):
+
+            if isinstance(backbone_config, dict):
                 logger.info("Initializing the config with a `BiT` backbone.")
                 backbone_config = BitConfig(**backbone_config)
             elif isinstance(backbone_config, PretrainedConfig):
@@ -208,9 +207,8 @@ def __init__(
             if readout_type != "project":
                 raise ValueError("Readout type must be 'project' when using `DPT-hybrid` mode.")
 
-        elif backbone_config is not None:
+        elif backbone is not None or backbone_config is not None:
             use_autobackbone = True
-
             if isinstance(backbone_config, dict):
                 backbone_model_type = backbone_config.get("model_type")
                 config_class = CONFIG_MAPPING[backbone_model_type]
@@ -219,33 +217,37 @@ def __init__(
             self.backbone_config = backbone_config
             self.backbone_featmap_shape = None
             self.neck_ignore_stages = []
+
+            # We only use load_backbone when config.is_hybrid is False
+            verify_backbone_config_arguments(
+                use_timm_backbone=use_timm_backbone,
+                use_pretrained_backbone=use_pretrained_backbone,
+                backbone=backbone,
+                backbone_config=backbone_config,
+                backbone_kwargs=backbone_kwargs,
+            )
         else:
-            self.backbone_config = backbone_config
+            self.backbone_config = None
             self.backbone_featmap_shape = None
             self.neck_ignore_stages = []
 
-        verify_backbone_config_arguments(
-            use_timm_backbone=use_timm_backbone,
-            use_pretrained_backbone=use_pretrained_backbone,
-            backbone=backbone,
-            backbone_config=backbone_config,
-            backbone_kwargs=backbone_kwargs,
-        )
-
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
         self.use_timm_backbone = use_timm_backbone
         self.backbone_kwargs = backbone_kwargs
-        self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
-        self.num_attention_heads = None if use_autobackbone else num_attention_heads
-        self.intermediate_size = None if use_autobackbone else intermediate_size
-        self.hidden_dropout_prob = None if use_autobackbone else hidden_dropout_prob
-        self.attention_probs_dropout_prob = None if use_autobackbone else attention_probs_dropout_prob
-        self.layer_norm_eps = None if use_autobackbone else layer_norm_eps
-        self.image_size = None if use_autobackbone else image_size
-        self.patch_size = None if use_autobackbone else patch_size
-        self.num_channels = None if use_autobackbone else num_channels
-        self.qkv_bias = None if use_autobackbone else qkv_bias
+
+        # ViT parameters used if not using a hybrid backbone
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.layer_norm_eps = layer_norm_eps
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.qkv_bias = qkv_bias
+        self.use_autobackbone = use_autobackbone
         self.backbone_out_indices = None if use_autobackbone else backbone_out_indices
 
         if readout_type not in ["ignore", "add", "project"]:
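The config hunk above is the "Don't conditionally set to None" item from the commit message: the ViT-specific fields keep their values even when an auto-loaded backbone is used. A small sketch of the observable difference, assuming the default field values are unchanged:

```python
from transformers import DPTConfig

config = DPTConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=False)
# Previously these were forced to None whenever an auto-loaded backbone was used;
# now they keep their (default) values.
print(config.num_hidden_layers, config.patch_size)
```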
6 changes: 3 additions & 3 deletions src/transformers/models/dpt/modeling_dpt.py
@@ -1071,10 +1071,10 @@ def __init__(self, config):
         super().__init__(config)
 
         self.backbone = None
-        if config.is_hybrid or config.backbone_config is None:
-            self.dpt = DPTModel(config, add_pooling_layer=False)
-        else:
-            self.backbone = load_backbone(config)
+        if config.is_hybrid is False and (config.backbone_config is not None or config.backbone is not None):
+            self.backbone = load_backbone(config)
+        else:
+            self.dpt = DPTModel(config, add_pooling_layer=False)
 
         # Neck
         self.neck = DPTNeck(config)
@@ -463,7 +463,14 @@ def __init__(self, config):
             self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
         )
 
-        backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
+        backbone_model_type = None
+        if config.backbone is not None:
+            backbone_model_type = config.backbone
+        elif config.backbone_config is not None:
+            backbone_model_type = config.backbone_config.model_type
+        else:
+            raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
+
         if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
                 if config.use_timm_backbone:
@@ -305,7 +305,14 @@ def __init__(self, config):
             self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
         )
 
-        backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
+        backbone_model_type = None
+        if config.backbone is not None:
+            backbone_model_type = config.backbone
+        elif config.backbone_config is not None:
+            backbone_model_type = config.backbone_config.model_type
+        else:
+            raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
+
         if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
                 if config.use_timm_backbone:
@@ -50,8 +50,10 @@ def __init__(self, config, **kwargs):
         if config.backbone is None:
             raise ValueError("backbone is not set in the config. Please set it to a timm model name.")
 
-        if config.backbone not in timm.list_models():
-            raise ValueError(f"backbone {config.backbone} is not supported by timm.")
+        # Certain timm models have the structure `model_name.version` e.g. vit_large_patch14_dinov2.lvd142m
+        base_backbone_model = config.backbone.split(".")[0]
+        if base_backbone_model not in timm.list_models():
+            raise ValueError(f"backbone {base_backbone_model} is not supported by timm.")
 
         if hasattr(config, "out_features") and config.out_features is not None:
             raise ValueError("out_features is not supported by TimmBackbone. Please use out_indices instead.")
12 changes: 11 additions & 1 deletion src/transformers/models/tvp/modeling_tvp.py
@@ -143,8 +143,18 @@ class TvpVisionModel(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.backbone = load_backbone(config)
+
+        if config.backbone_config is not None:
+            in_channels = config.backbone_config.hidden_sizes[-1]
+        elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_sizes"):
+            in_channels = self.backbone.config.hidden_sizes[-1]
+        elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_size"):
+            in_channels = self.backbone.config.hidden_size
+        else:
+            raise ValueError("Backbone config not found")
+
         self.grid_encoder_conv = nn.Conv2d(
-            config.backbone_config.hidden_sizes[-1],
+            in_channels,
             config.hidden_size,
             kernel_size=3,
             stride=1,
7 changes: 6 additions & 1 deletion src/transformers/models/vitmatte/modeling_vitmatte.py
@@ -115,7 +115,12 @@ class VitMatteConvStream(nn.Module):
     def __init__(self, config):
         super().__init__()
 
-        in_channels = config.backbone_config.num_channels
+        # We use a default in-case there isn't a backbone config set. This is for backwards compatibility and
+        # to enable loading HF backbone models.
+        in_channels = 4
+        if config.backbone_config is not None:
+            in_channels = config.backbone_config.num_channels
+
         out_channels = config.convstream_hidden_sizes
 
         self.convs = nn.ModuleList()
5 changes: 0 additions & 5 deletions src/transformers/utils/backbone_utils.py
@@ -368,11 +368,6 @@ def verify_backbone_config_arguments(
     """
     Verify that the config arguments to be passed to load_backbone are valid
     """
-    if not use_timm_backbone and use_pretrained_backbone:
-        raise ValueError(
-            "Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
-        )
-
     if backbone_config is not None and backbone is not None:
         raise ValueError("You can't specify both `backbone` and `backbone_config`.")
 
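With the check above removed, the combination `use_pretrained_backbone=True` with `use_timm_backbone=False` is no longer rejected, which is exactly what the new tests below exercise. A quick sketch, with argument values borrowed from the test:

```python
from transformers.utils.backbone_utils import verify_backbone_config_arguments

# Before this commit, this combination raised a ValueError stating that
# use_timm_backbone had to be True when use_pretrained_backbone=True.
verify_backbone_config_arguments(
    use_timm_backbone=False,
    use_pretrained_backbone=True,
    backbone="microsoft/resnet-18",
    backbone_config=None,
    backbone_kwargs={"out_indices": [2, 3, 4]},
)
```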
36 changes: 36 additions & 0 deletions tests/models/conditional_detr/test_modeling_conditional_detr.py
@@ -476,6 +476,42 @@ def test_different_timm_backbone(self):
 
         self.assertTrue(outputs)
 
+    @require_timm
+    def test_hf_backbone(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        # Load a pretrained HF checkpoint as backbone
+        config.backbone = "microsoft/resnet-18"
+        config.backbone_config = None
+        config.use_timm_backbone = False
+        config.use_pretrained_backbone = True
+        config.backbone_kwargs = {"out_indices": [2, 3, 4]}
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            if model_class.__name__ == "ConditionalDetrForObjectDetection":
+                expected_shape = (
+                    self.model_tester.batch_size,
+                    self.model_tester.num_queries,
+                    self.model_tester.num_labels,
+                )
+                self.assertEqual(outputs.logits.shape, expected_shape)
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
+            elif model_class.__name__ == "ConditionalDetrForSegmentation":
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.conditional_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
+            else:
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
+
+        self.assertTrue(outputs)
+
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
