diff --git a/test/smoke_test.py b/test/smoke_test.py
index 647e2f45c8f..3a44ae3efe9 100644
--- a/test/smoke_test.py
+++ b/test/smoke_test.py
@@ -102,9 +102,8 @@ def main() -> None:
         smoke_test_torchvision_resnet50_classify("cuda")
 
         # TODO: remove once pytorch/pytorch#110436 is resolved
-        # Temporary Disabling compile test. Untill triton with Manylinux2014 is available
-        # if sys.version_info < (3, 12, 0):
-        #     smoke_test_compile()
+        if sys.version_info < (3, 12, 0):
+            smoke_test_compile()
 
     if torch.backends.mps.is_available():
         smoke_test_torchvision_resnet50_classify("mps")
diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py
index befceca020e..c64e27f14ac 100644
--- a/test/test_backbone_utils.py
+++ b/test/test_backbone_utils.py
@@ -1,4 +1,5 @@
 import random
+from copy import deepcopy
 from itertools import chain
 from typing import Mapping, Sequence
 
@@ -322,3 +323,14 @@ def forward(self, x):
         out = model(self.inp)
         # And backward
         out["leaf_module"].float().mean().backward()
+
+    def test_deepcopy(self):
+        # Non-regression test for https://github.com/pytorch/vision/issues/8634
+        model = models.efficientnet_b3(weights=None)
+        extractor = create_feature_extractor(model=model, return_nodes={"classifier.0": "out"})
+
+        extractor.eval()
+        extractor.train()
+        extractor = deepcopy(extractor)
+        extractor.eval()
+        extractor.train()
diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py
index f42bc124c7b..a20387a6b89 100644
--- a/torchvision/models/feature_extraction.py
+++ b/torchvision/models/feature_extraction.py
@@ -1,3 +1,4 @@
+import copy
 import inspect
 import math
 import re
@@ -10,7 +11,7 @@
 import torch
 import torchvision
 from torch import fx, nn
-from torch.fx.graph_module import _copy_attr
+from torch.fx.graph_module import _CodeOnlyModule, _copy_attr, _USER_PRESERVED_ATTRIBUTES_KEY
 
 
 __all__ = ["create_feature_extractor", "get_graph_node_names"]
@@ -330,6 +331,55 @@ def train(self, mode=True):
             self.graph = self.eval_graph
         return super().train(mode=mode)
 
+    def _deepcopy_init(self):
+        """Return the initializer that ``__deepcopy__`` calls to build the copy."""
+        # See __deepcopy__ below
+        return DualGraphModule.__init__
+
+    def __deepcopy__(self, memo):
+        """Return a deep copy of this module that keeps both ``train_graph`` and ``eval_graph``.
+
+        Args:
+            memo (dict): Memo dictionary threaded through ``copy.deepcopy``; the
+                new object is registered in it before any attribute is copied.
+
+        Returns:
+            A new instance of ``type(self)``.
+        """
+        # Same as the base class' __deepcopy__ from pytorch, with minor
+        # modification to account for train_graph and eval_graph
+        # https://github.com/pytorch/pytorch/blob/f684dbd0026f98f8fa291cab74dbc4d61ba30580/torch/fx/graph_module.py#L875
+        #
+        # This is using a bunch of private stuff from torch, so if that breaks,
+        # we'll likely have to remove this, along with the associated
+        # non-regression test.
+        #
+        # NOTE(review): the copy is rebuilt through ``DualGraphModule.__init__``,
+        # which is not visible in this patch. Confirm that it leaves the copy in
+        # the same train/eval mode (and with the matching ``graph``) as ``self``;
+        # if it does not, ``res.train(self.training)`` before returning would.
+        res = type(self).__new__(type(self))
+        memo[id(self)] = res
+        fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
+        self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["train_graph"], fake_mod.__dict__["eval_graph"])
+
+        extra_preserved_attrs = [
+            "_state_dict_hooks",
+            "_load_state_dict_pre_hooks",
+            "_load_state_dict_post_hooks",
+            "_replace_hook",
+            "_create_node_hooks",
+            "_erase_node_hooks",
+        ]
+        for attr in extra_preserved_attrs:
+            if attr in self.__dict__:
+                setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
+        res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
+        if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
+            for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
+                setattr(res, attr_name, attr)
+        return res
+
 
 def create_feature_extractor(
     model: nn.Module,