|  | 
|  | 1 | +from contextlib import contextmanager, nullcontext | 
| 1 | 2 | import dataclasses | 
| 2 | 3 | import logging | 
| 3 | 4 | import os | 
| 4 | 5 | from copy import deepcopy | 
| 5 | 6 | from pathlib import Path | 
| 6 |  | -from typing import Any, Callable, Dict, List, Optional, Tuple, Union | 
|  | 7 | +from typing import Any, Callable, Dict, Optional, Tuple, Union | 
| 7 | 8 | 
 | 
|  | 9 | +import torch | 
| 8 | 10 | from torch import nn as nn | 
| 9 | 11 | from torch.hub import load_state_dict_from_url | 
| 10 | 12 | 
 | 
| @@ -360,6 +362,27 @@ def resolve_pretrained_cfg( | 
| 360 | 362 |     return pretrained_cfg | 
| 361 | 363 | 
 | 
| 362 | 364 | 
 | 
@contextmanager
def make_meta_init(*classes):
    """Temporarily patch ``__init__`` of the given ``nn.Module`` classes so
    their parameters are constructed on the ``meta`` device.

    This skips allocating and randomly initializing weights that are about to
    be overwritten by pretrained weights, speeding up model construction.

    Args:
        *classes: module classes (e.g. ``nn.Linear``, ``nn.Conv2d``) whose
            ``__init__`` should be redirected to the meta device.

    Yields:
        None. Inside the ``with`` block, instantiating any of ``classes``
        produces meta-device parameters.
    """
    def _wrap_init(orig_init):
        def meta_init(self, *args, **kwargs):
            # Force construction on the meta device: no memory is allocated
            # and no RNG state is consumed for (soon-discarded) weight init.
            kwargs.update(device="meta")
            orig_init(self, *args, **kwargs)
        return meta_init

    # Capture the original __init__ methods before patching.
    originals = {cls: cls.__init__ for cls in classes}
    try:
        for cls in classes:
            cls.__init__ = _wrap_init(cls.__init__)
        yield
    finally:
        # Restore the original __init__ even if the with-body raised —
        # otherwise the meta-device patch would leak globally and break all
        # later instantiations of these classes.
        for cls, orig_init in originals.items():
            cls.__init__ = orig_init
|  | 385 | + | 
| 363 | 386 | def build_model_with_cfg( | 
| 364 | 387 |         model_cls: Callable, | 
| 365 | 388 |         variant: str, | 
| @@ -419,11 +442,27 @@ def build_model_with_cfg( | 
| 419 | 442 |         if 'feature_cls' in kwargs: | 
| 420 | 443 |             feature_cfg['feature_cls'] = kwargs.pop('feature_cls') | 
| 421 | 444 | 
 | 
|  | 445 | +    # use meta-device init to speed up loading pretrained weights. | 
|  | 446 | +    # when num_classes is changed, we can't use meta device init since we need | 
|  | 447 | +    # the original __init__() to initialize head from scratch. | 
|  | 448 | +    num_classes = 0 if features else kwargs.get("num_classes", pretrained_cfg["num_classes"]) | 
|  | 449 | +    use_meta_init = ( | 
|  | 450 | +        pretrained | 
|  | 451 | +        and (num_classes == 0 or num_classes == pretrained_cfg["num_classes"]) | 
|  | 452 | +    ) | 
|  | 453 | + | 
| 422 | 454 |     # Instantiate the model | 
| 423 |  | -    if model_cfg is None: | 
| 424 |  | -        model = model_cls(**kwargs) | 
| 425 |  | -    else: | 
| 426 |  | -        model = model_cls(cfg=model_cfg, **kwargs) | 
|  | 455 | +    base_classes = [nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.LayerNorm] | 
|  | 456 | +    with make_meta_init(*base_classes) if use_meta_init else nullcontext(): | 
|  | 457 | +        if model_cfg is None: | 
|  | 458 | +            model = model_cls(**kwargs) | 
|  | 459 | +        else: | 
|  | 460 | +            model = model_cls(cfg=model_cfg, **kwargs) | 
|  | 461 | + | 
|  | 462 | +    # convert meta-device tensors to concrete tensors | 
|  | 463 | +    device = kwargs.get("device", torch.get_default_device()) | 
|  | 464 | +    model._apply(lambda t: (torch.empty_like(t, device=device) if t.is_meta else t)) | 
|  | 465 | + | 
| 427 | 466 |     model.pretrained_cfg = pretrained_cfg | 
| 428 | 467 |     model.default_cfg = model.pretrained_cfg  # alias for backwards compat | 
| 429 | 468 | 
 | 
|  | 
0 commit comments