Skip to content

Commit

Permalink
2024-11-29 nightly release (8f8a195)
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Nov 29, 2024
1 parent d3df9e2 commit 98116a6
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 4 deletions.
5 changes: 2 additions & 3 deletions test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ def main() -> None:
smoke_test_torchvision_resnet50_classify("cuda")

# TODO: remove once pytorch/pytorch#110436 is resolved
# Temporary Disabling compile test. Untill triton with Manylinux2014 is available
# if sys.version_info < (3, 12, 0):
# smoke_test_compile()
if sys.version_info < (3, 12, 0):
smoke_test_compile()

if torch.backends.mps.is_available():
smoke_test_torchvision_resnet50_classify("mps")
Expand Down
12 changes: 12 additions & 0 deletions test/test_backbone_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
from copy import deepcopy
from itertools import chain
from typing import Mapping, Sequence

Expand Down Expand Up @@ -322,3 +323,14 @@ def forward(self, x):
out = model(self.inp)
# And backward
out["leaf_module"].float().mean().backward()

def test_deepcopy(self):
# Non-regression test for https://github.com/pytorch/vision/issues/8634
model = models.efficientnet_b3(weights=None)
extractor = create_feature_extractor(model=model, return_nodes={"classifier.0": "out"})

extractor.eval()
extractor.train()
extractor = deepcopy(extractor)
extractor.eval()
extractor.train()
37 changes: 36 additions & 1 deletion torchvision/models/feature_extraction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import inspect
import math
import re
Expand All @@ -10,7 +11,7 @@
import torch
import torchvision
from torch import fx, nn
from torch.fx.graph_module import _copy_attr
from torch.fx.graph_module import _CodeOnlyModule, _copy_attr, _USER_PRESERVED_ATTRIBUTES_KEY


__all__ = ["create_feature_extractor", "get_graph_node_names"]
Expand Down Expand Up @@ -330,6 +331,40 @@ def train(self, mode=True):
self.graph = self.eval_graph
return super().train(mode=mode)

def _deepcopy_init(self):
# See __deepcopy__ below
return DualGraphModule.__init__

def __deepcopy__(self, memo):
# Same as the base class' __deepcopy__ from pytorch, with minor
# modification to account for train_graph and eval_graph
# https://github.com/pytorch/pytorch/blob/f684dbd0026f98f8fa291cab74dbc4d61ba30580/torch/fx/graph_module.py#L875
#
# This is using a bunch of private stuff from torch, so if that breaks,
# we'll likely have to remove this, along with the associated
# non-regression test.
res = type(self).__new__(type(self))
memo[id(self)] = res
fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["train_graph"], fake_mod.__dict__["eval_graph"])

extra_preserved_attrs = [
"_state_dict_hooks",
"_load_state_dict_pre_hooks",
"_load_state_dict_post_hooks",
"_replace_hook",
"_create_node_hooks",
"_erase_node_hooks",
]
for attr in extra_preserved_attrs:
if attr in self.__dict__:
setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
setattr(res, attr_name, attr)
return res


def create_feature_extractor(
model: nn.Module,
Expand Down

0 comments on commit 98116a6

Please sign in to comment.