diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
index 4d22af150..7b96b450e 100644
--- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
+++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import copy
 from abc import abstractmethod
 from functools import partial
-from typing import Tuple, Any, Dict, List, Union, Callable
+from typing import Tuple, Any, Dict, List, Callable
 
 import torch
 import numpy as np
@@ -30,7 +31,6 @@
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
 from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
-from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
 from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
@@ -224,8 +224,10 @@ def __init__(self,
         """
         super(PytorchModel, self).__init__()
 
-        self.graph = graph
-        self.node_sort = list(topological_sort(graph))
+        self.graph = copy.deepcopy(graph)
+        delattr(self.graph, 'tpc')
+
+        self.node_sort = list(topological_sort(self.graph))
         self.node_to_activation_quantization_holder = {}
         self.append2output = append2output
         self.return_float_outputs = return_float_outputs
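
Reviewer note (not part of the patch): a minimal sketch of the pattern this change introduces, assuming the intent is that the built torch module works on its own copy of the graph and does not retain a reference to the graph's 'tpc' attribute. ToyGraph below is a hypothetical stand-in used only for illustration, not MCT's real Graph class.

import copy


class ToyGraph:
    """Hypothetical stand-in for MCT's internal Graph class (illustration only)."""

    def __init__(self, nodes, tpc):
        self.nodes = nodes
        self.tpc = tpc  # object we do not want the built model to keep referencing


graph = ToyGraph(nodes=['conv', 'relu'], tpc=object())

# Pattern from the diff: take a private deep copy of the graph and drop the
# 'tpc' attribute, so the caller's graph is left untouched and the model's
# copy no longer holds the tpc reference.
model_graph = copy.deepcopy(graph)
delattr(model_graph, 'tpc')

assert hasattr(graph, 'tpc')            # caller's graph still has its tpc
assert not hasattr(model_graph, 'tpc')  # the model's private copy does not

Deep-copying before delattr means the caller's graph cannot be mutated by the model builder; the trade-off is the memory cost of duplicating the graph.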