Migrate spy_auto_quant to _base.auto_quant (#3683)
Signed-off-by: Kyunggeun Lee <[email protected]>
quic-kyunggeu authored Dec 20, 2024
1 parent fd0a16b commit 88cea48
Showing 2 changed files with 49 additions and 50 deletions.
49 changes: 49 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/_base/auto_quant.py
@@ -39,6 +39,7 @@
""" Implementation of AIMET AutoQuantBase """
import abc
import copy
import contextlib
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
import functools
@@ -1346,3 +1347,51 @@ def _optimize_main(self, fp32_model: torch.nn.Module, target_acc: float): # pyli

        raise RuntimeError("None of batchnorm folding, CLE, or Adaround "
                           "has been finished successfully.")


@contextlib.contextmanager
def spy_auto_quant(auto_quant: AutoQuantBase):
    """
    Install a spy that collects the handles to the PTQ result of
    each stage of AutoQuant.

    Typical usage::

        >>> auto_quant = AutoQuant(...)
        ... with spy_auto_quant(auto_quant) as spy:
        ...     _ = auto_quant.apply(...)
        ...
        ... for result in spy.get_all_ptq_results():
        ...     print(result.applied_techniques)
        ...     print(result.accuracy)
        ...     print(result.encoding_path)
        ...     model = result.load_model()
        ...     ...
    """
    # pylint: disable=protected-access
    class Spy:
        """
        Spy that collects the handles to the PTQ result of
        each stage of AutoQuant.
        """
        def __init__(self, eval_manager):
            self._eval_manager = eval_manager

        def get_all_ptq_results(self) -> List[PtqResult]:
            """Return handles to the results of AutoQuant"""
            if self._eval_manager is None:
                return []
            return [sess.ptq_result for sess in self._eval_manager._all_sessions.values()
                    if sess.ptq_result is not None]

    spy = Spy(auto_quant.eval_manager)

    _optimize_main = auto_quant._optimize_main

    def _optimize_main_wrapper(fp32_model, target_acc):
        return _optimize_main(fp32_model, target_acc)

    try:
        setattr(auto_quant, "_optimize_main", _optimize_main_wrapper)
        yield spy
    finally:
        setattr(auto_quant, "_optimize_main", _optimize_main)
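
For orientation, here is a minimal usage sketch of the migrated helper; it is not part of this commit. The import path for spy_auto_quant is inferred from the commit title, the AutoQuant constructor arguments are assumptions based on typical AIMET v1 usage, and model, dummy_input, data_loader, and eval_callback are hypothetical placeholders the caller must supply.

    # Hedged sketch, not from this diff: constructor arguments and import
    # paths are assumptions; the placeholders below must be supplied by the user.
    from aimet_torch.v1.auto_quant import AutoQuant
    from aimet_torch._base.auto_quant import spy_auto_quant

    model = ...          # torch.nn.Module to quantize (placeholder)
    dummy_input = ...    # sample input tensor for tracing (placeholder)
    data_loader = ...    # unlabeled calibration data (placeholder)
    eval_callback = ...  # callable returning a float accuracy (placeholder)

    auto_quant = AutoQuant(model,
                           dummy_input=dummy_input,
                           data_loader=data_loader,
                           eval_callback=eval_callback)

    with spy_auto_quant(auto_quant) as spy:
        _ = auto_quant.apply()  # arguments elided; they vary by AIMET version

    # Each PtqResult handle records the techniques applied at that stage,
    # the accuracy reached, and the path to the exported encodings.
    for result in spy.get_all_ptq_results():
        print(result.applied_techniques, result.accuracy, result.encoding_path)

Note that the helper restores the original _optimize_main in its finally block, so the AutoQuant instance is left unpatched even if apply() raises.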
50 changes: 0 additions & 50 deletions TrainingExtensions/torch/src/python/aimet_torch/v1/auto_quant.py
@@ -38,7 +38,6 @@

""" Implementation of AIMET AutoQuantBase and v1 AutoQuant """
import copy
import contextlib
import functools
import itertools
import os
@@ -52,7 +51,6 @@
    AutoQuantBase,
    _EvalManager,
    _QuantSchemePair,
    PtqResult,
    _EvalSession,
    cache,
    _MixedPrecisionArgs,
@@ -85,54 +83,6 @@
_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.AutoQuant)


@contextlib.contextmanager
def spy_auto_quant(auto_quant: AutoQuantBase):
    """
    Install a spy that collects the handles to the PTQ result of
    each stage of AutoQuant.

    Typical usage::

        >>> auto_quant = AutoQuant(...)
        ... with spy_auto_quant(auto_quant) as spy:
        ...     _ = auto_quant.apply(...)
        ...
        ... for result in spy.get_all_ptq_results():
        ...     print(result.applied_techniques)
        ...     print(result.accuracy)
        ...     print(result.encoding_path)
        ...     model = result.load_model()
        ...     ...
    """
    # pylint: disable=protected-access
    class Spy:
        """
        Spy that collects the handles to the PTQ result of
        each stage of AutoQuant.
        """
        def __init__(self, eval_manager):
            self._eval_manager = eval_manager

        def get_all_ptq_results(self) -> List[PtqResult]:
            """Return handles to the results of AutoQuant"""
            if self._eval_manager is None:
                return []
            return [sess.ptq_result for sess in self._eval_manager._all_sessions.values()
                    if sess.ptq_result is not None]

    spy = Spy(auto_quant.eval_manager)

    _optimize_main = auto_quant._optimize_main

    def _optimize_main_wrapper(fp32_model, target_acc):
        return _optimize_main(fp32_model, target_acc)

    try:
        setattr(auto_quant, "_optimize_main", _optimize_main_wrapper)
        yield spy
    finally:
        setattr(auto_quant, "_optimize_main", _optimize_main)


class AutoQuant(AutoQuantBase): # pylint: disable=too-many-instance-attributes
"""
Integrate and apply post-training quantization techniques.
