diff --git a/TrainingExtensions/torch/src/python/aimet_torch/_base/auto_quant.py b/TrainingExtensions/torch/src/python/aimet_torch/_base/auto_quant.py index 35d34f1d96..1c22149e3d 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/_base/auto_quant.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/_base/auto_quant.py @@ -39,6 +39,7 @@ """ Implementation of AIMET AutoQuantBase """ import abc import copy +import contextlib from collections import OrderedDict, defaultdict from dataclasses import dataclass import functools @@ -1346,3 +1347,51 @@ def _optimize_main(self, fp32_model: torch.nn.Module, target_acc: float): # pyli raise RuntimeError("None of batchnorm folding, CLE, or Adaround " "has been finished successfully.") + + +@contextlib.contextmanager +def spy_auto_quant(auto_quant: AutoQuantBase): + """ + Install a spy that collects the handles to the ptq result of + each stage of AutoQuant. + + Typical usage:: + >>> auto_quant = AutoQuant(...) + ... with spy_auto_quant(auto_quant) as spy: + ... _ = auto_quant.apply(...) + ... + ... for result in spy.get_all_ptq_results(): + ... print(result.applied_techniques) + ... print(result.accuracy) + ... print(result.encoding_path) + ... model = result.load_model() + ... ... + """ + # pylint: disable=protected-access + class Spy: + """ + Spy that collects the handles to the ptq result of + each stage of AutoQuant. 
+ """ + def __init__(self, eval_manager): + self._eval_manager = eval_manager + + def get_all_ptq_results(self) -> List[PtqResult]: + """Return handles to the results of AutoQuant""" + if self._eval_manager is None: + return [] + return [sess.ptq_result for sess in self._eval_manager._all_sessions.values() + if sess.ptq_result is not None] + + spy = Spy(auto_quant.eval_manager) + + _optimize_main = auto_quant._optimize_main + + def _optimize_main_wrapper(fp32_model, target_acc): + return _optimize_main(fp32_model, target_acc) + + try: + setattr(auto_quant, "_optimize_main", _optimize_main_wrapper) + yield spy + finally: + setattr(auto_quant, "_optimize_main", _optimize_main) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v1/auto_quant.py b/TrainingExtensions/torch/src/python/aimet_torch/v1/auto_quant.py index 1539276cee..8887f3a138 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/v1/auto_quant.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/v1/auto_quant.py @@ -38,7 +38,6 @@ """ Implementation of AIMET AutoQuantBase and v1 AutoQuant """ import copy -import contextlib import functools import itertools import os @@ -52,7 +51,6 @@ AutoQuantBase, _EvalManager, _QuantSchemePair, - PtqResult, _EvalSession, cache, _MixedPrecisionArgs, @@ -85,54 +83,6 @@ _logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.AutoQuant) -@contextlib.contextmanager -def spy_auto_quant(auto_quant: AutoQuantBase): - """ - Install a spy that collects the handles to the ptq result of - each stage of AutoQuant. - - Typical usage:: - >>> auto_quant = AutoQuant(...) - ... with auto_quant_spy(auto_quant) as spy: - ... _ = auto_quant.apply(...) - ... - ... for result in spy.get_all_ptq_results(): - ... print(result.applied_techniques) - ... print(result.accuracy) - ... print(result.encoding_path) - ... model = result.load_model() - ... ... 
- """ - # pylint: disable=protected-access - class Spy: - """ - Spy that collects the handles to the ptq result of - each stage of AutoQuant. - """ - def __init__(self, eval_manager): - self._eval_manager = eval_manager - - def get_all_ptq_results(self) -> List[PtqResult]: - """Return handles to the results of AutoQuant""" - if self._eval_manager is None: - return [] - return [sess.ptq_result for sess in self._eval_manager._all_sessions.values() - if sess.ptq_result is not None] - - spy = Spy(auto_quant.eval_manager) - - _optimize_main = auto_quant._optimize_main - - def _optimize_main_wrapper(fp32_model, target_acc): - return _optimize_main(fp32_model, target_acc) - - try: - setattr(auto_quant, "_optimize_main", _optimize_main_wrapper) - yield spy - finally: - setattr(auto_quant, "_optimize_main", _optimize_main) - - class AutoQuant(AutoQuantBase): # pylint: disable=too-many-instance-attributes """ Integrate and apply post-training quantization techniques.