diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/cg_utils.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/cg_utils.py index 357a1eb6d99..8c5f43f44ad 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/v2/cg_utils.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/cg_utils.py @@ -37,7 +37,7 @@ # ============================================================================= """Utilities to traverse model graph""" -from typing import Dict, Optional, Generator +from typing import Dict, Optional, Generator, Tuple from dataclasses import dataclass import functools @@ -61,7 +61,7 @@ class ConnectedGraphTraverser: def __init__(self, sim: QuantizationSimModel): self._sim = sim - def get_leaf_modules(self, torch_module: torch.nn.Module) -> Generator[tuple[str, torch.nn.Module], None, None]: + def get_leaf_modules(self, torch_module: torch.nn.Module) -> Generator[Tuple[str, torch.nn.Module], None, None]: """ Get all the leaf modules in the given module """ for name, module in torch_module.named_modules(): if module not in self._sim.model.modules(): diff --git a/TrainingExtensions/torch/test/python/v2/test_manual_mixed_precision.py b/TrainingExtensions/torch/test/python/v2/test_manual_mixed_precision.py index c4e4eeb4056..de0ff244dae 100644 --- a/TrainingExtensions/torch/test/python/v2/test_manual_mixed_precision.py +++ b/TrainingExtensions/torch/test/python/v2/test_manual_mixed_precision.py @@ -41,7 +41,7 @@ import pytest import torch -from torch import nn, candidate +from torch import nn from aimet_common.defs import QuantizationDataType from aimet_torch.v2.quantization.base.quantizer import QuantizerBase