Initial implementation of EncodingAnalyzer and MinMaxEncodingAnalyzer (#2630)
* Added implementation of EncodingAnalyzer and MinMaxEncodingAnalyzer

Signed-off-by: Priyanka Dangi <[email protected]>
quic-pdangi authored Jan 4, 2024
1 parent ce3dafe commit 4870791
Showing 5 changed files with 284 additions and 99 deletions.
@@ -2,7 +2,7 @@
 # =============================================================================
 #  @@-COPYRIGHT-START-@@
 #
-#  Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
+#  Copyright (c) 2023 Qualcomm Innovation Center, Inc. All rights reserved.
 #
 #  Redistribution and use in source and binary forms, with or without
 #  modification, are permitted provided that the following conditions are met:
@@ -34,94 +34,104 @@
 #
 #  @@-COPYRIGHT-END-@@
 # =============================================================================
-# pylint: disable=all
-from typing import TypeVar, Generic, Tuple, Type, Optional
-import abc
-from dataclasses import dataclass
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=missing-docstring
+# pylint: disable=no-member
 
-import torch
+""" Computes statistics and encodings """
 
-from aimet_torch.experimental.v2.utils import reduce
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import TypeVar, Generic, Tuple, Optional
+import torch
+from aimet_torch.experimental.v2.utils import reduce, StatisticsNotFoundError
 
 
-@dataclass(frozen=True)
+@dataclass
 class _MinMaxRange:
     min: Optional[torch.Tensor] = None
     max: Optional[torch.Tensor] = None
 
 
 class _Histogram:
-    # TODO
-    ...
+    pass
 
 
 _Statistics = TypeVar('_Statistics', _MinMaxRange, _Histogram)
 
-class _Observer(Generic[_Statistics], abc.ABC):
-    def __init__(self, shape):
-        self.shape = shape
+class _Observer(Generic[_Statistics], ABC):
+    """
+    Observes and gathers statistics
+    """
+    def __init__(self, min_max_shape: torch.Tensor):
+        self.shape = min_max_shape
 
-    @abc.abstractmethod
-    def collect_stats(self, x: torch.Tensor) -> _Statistics:
-        ...
+    @abstractmethod
+    def collect_stats(self, input_tensor: torch.Tensor) -> _Statistics:
+        pass
 
-    @abc.abstractmethod
+    @abstractmethod
     def merge_stats(self, stats: _Statistics):
-        ...
+        pass
 
-    @abc.abstractmethod
+    @abstractmethod
     def reset_stats(self):
-        ...
+        pass
 
-    @abc.abstractmethod
+    @abstractmethod
     def get_stats(self) -> _Statistics:
-        ...
+        pass
 
 
 class _MinMaxObserver(_Observer[_MinMaxRange]):
-    def __init__(self, shape):
-        super().__init__(shape)
+    """
+    Observer for Min-Max calibration technique
+    """
+    def __init__(self, min_max_shape: torch.Tensor):
+        super().__init__(min_max_shape)
         self.stats = _MinMaxRange()
 
     @torch.no_grad()
-    def collect_stats(self, x: torch.Tensor) -> _MinMaxRange:
-        min = reduce(x, shape=self.shape, reduce_op=torch.min).values
-        max = reduce(x, shape=self.shape, reduce_op=torch.max).values
-        return _MinMaxRange(min, max)
+    def collect_stats(self, input_tensor: torch.Tensor) -> _MinMaxRange:
+        new_min = reduce(input_tensor, shape=self.shape, reduce_op=torch.min).values
+        new_max = reduce(input_tensor, shape=self.shape, reduce_op=torch.max).values
+        return _MinMaxRange(new_min, new_max)
 
     @torch.no_grad()
     def merge_stats(self, new_stats: _MinMaxRange):
-        min = self.stats.min
+        updated_min = self.stats.min
         if new_stats.min is not None:
-            if min is None:
-                min = new_stats.min.clone()
+            if updated_min is None:
+                updated_min = new_stats.min.clone()
             else:
-                min = torch.minimum(min, new_stats.min)
+                updated_min = torch.minimum(updated_min, new_stats.min)
 
-        max = self.stats.max
+        updated_max = self.stats.max
         if new_stats.max is not None:
-            if max is None:
-                max = new_stats.max.clone()
+            if updated_max is None:
+                updated_max = new_stats.max.clone()
            else:
-                max = torch.maximum(max, new_stats.max)
+                updated_max = torch.maximum(updated_max, new_stats.max)
 
-        self.stats = _MinMaxRange(min, max)
+        self.stats = _MinMaxRange(updated_min, updated_max)
 
     def reset_stats(self):
         self.stats = _MinMaxRange()
 
     def get_stats(self) -> _MinMaxRange:
         return self.stats
 
 
 class _HistogramObserver(_Observer[_Histogram]):
-    def __init__(self, shape):
-        # TODO
-        raise NotImplementedError
+    """
+    Observer for Histogram based calibration techniques (percentile, MSE)
+    """
+    def __init__(self, min_max_shape: torch.Tensor):
+        super().__init__(min_max_shape)
+        self.stats = _Histogram()
 
     @torch.no_grad()
-    def collect_stats(self, x: torch.Tensor) -> _Histogram:
+    def collect_stats(self, input_tensor: torch.Tensor) -> _Histogram:
         # TODO
         raise NotImplementedError
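The running min/max merge that _MinMaxObserver.merge_stats implements can be illustrated standalone. The sketch below is not part of the commit; it uses plain PyTorch and bypasses aimet's reduce helper, whose shape semantics this diff does not show:

    import torch

    # Fold each batch's min/max into a running range, exactly as
    # merge_stats does with torch.minimum / torch.maximum.
    running_min, running_max = None, None
    for batch in [torch.tensor([-1.5, 0.2, 3.0]), torch.tensor([-4.0, 2.0])]:
        batch_min, batch_max = batch.min(), batch.max()
        running_min = batch_min.clone() if running_min is None else torch.minimum(running_min, batch_min)
        running_max = batch_max.clone() if running_max is None else torch.maximum(running_max, batch_max)

    print(running_min.item(), running_max.item())  # -4.0 3.0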

@@ -131,94 +141,133 @@ def merge_stats(self, new_stats: _Histogram):
         raise NotImplementedError
 
     def reset_stats(self):
-        # TODO
-        raise NotImplementedError
+        self.stats = _Histogram()
 
     def get_stats(self) -> _Histogram:
-        # TODO
-        raise NotImplementedError
+        return self.stats
 
 
-class _EncodingAnalyzer(Generic[_Statistics], abc.ABC):
-    observer_cls: Type[_Observer[_Statistics]]
-
-    def __init__(self, shape):
-        self.observer = self.observer_cls(shape)
+class CalibrationMethod(Enum):
+    """
+    Enum for quantization calibration method
+    """
+    MinMax = 0
+    SQNR = 1
+    Percentile = 2
+    MSE = 3
+
+def get_encoding_analyzer_cls(calibration_method: CalibrationMethod, min_max_shape: torch.Tensor):
+    """
+    Instantiates an EncodingAnalyzer based on the CalibrationMethod
+    """
+    if calibration_method == CalibrationMethod.MinMax:
+        return MinMaxEncodingAnalyzer(min_max_shape)
+    if calibration_method == CalibrationMethod.SQNR:
+        return SqnrEncodingAnalyzer(min_max_shape)
+    if calibration_method == CalibrationMethod.Percentile:
+        return PercentileEncodingAnalyzer(min_max_shape)
+    if calibration_method == CalibrationMethod.MSE:
+        return MseEncodingAnalyzer(min_max_shape)
+    raise ValueError('Calibration type must be one of the following: '
+                     'minmax, sqnr, mse, percentile')
+
+class _EncodingAnalyzer(Generic[_Statistics], ABC):
 
     @torch.no_grad()
-    def update_stats(self, x: torch.Tensor) -> _Statistics:
-        new_stats = self.observer.collect_stats(x)
+    def update_stats(self, input_tensor: torch.Tensor) -> _Statistics:
+        new_stats = self.observer.collect_stats(input_tensor)
         self.observer.merge_stats(new_stats)
         return new_stats
 
     def reset_stats(self) -> None:
         self.observer.reset_stats()
 
-    def compute_encodings(self, symmetric: bool, bitwidth: int)\
-            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
-        return self.compute_encodings_from_stats(self.observer.get_stats(), symmetric, bitwidth)
+    def compute_encodings(self, bitwidth: int, is_symmetric: bool)\
+            -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        return self.compute_encodings_from_stats(self.observer.get_stats(), bitwidth, is_symmetric)
 
-    def compute_dynamic_encodings(self, x: torch.Tensor, symmetric: bool, bitwidth: int)\
-            -> Tuple[torch.Tensor, torch.Tensor]:
-        return self.compute_encodings_from_stats(self.observer.collect_stats(x), symmetric, bitwidth)
+    def compute_dynamic_encodings(self, input_tensor: torch.Tensor, bitwidth: int,
+                                  is_symmetric: bool) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        return self.compute_encodings_from_stats(
+            self.observer.collect_stats(input_tensor), bitwidth, is_symmetric)
 
-    @abc.abstractmethod
-    def compute_encodings_from_stats(self, stats: _Statistics, symmetric: bool, bitwidth: int)\
+    @abstractmethod
+    def compute_encodings_from_stats(self, stats: _Statistics, bitwidth: int, is_symmetric: bool)\
             -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
-        ...
-
+        pass
 
 class MinMaxEncodingAnalyzer(_EncodingAnalyzer[_MinMaxRange]):
-    observer_cls = _MinMaxObserver
+    """
+    Encoding Analyzer for Min-Max calibration technique
+    """
+    def __init__(self, shape):
+        self.observer = _MinMaxObserver(shape)
 
     @torch.no_grad()
-    def compute_encodings_from_stats(self, stats: _MinMaxRange, symmetric: bool, bitwidth: int)\
+    def compute_encodings_from_stats(self, stats: _MinMaxRange, bitwidth: int, is_symmetric: bool)\
             -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if bitwidth <= 0:
+            raise ValueError('Bitwidth cannot be less than or equal to 0.')
+
         if stats.min is None or stats.max is None:
-            return None, None
+            raise StatisticsNotFoundError('No statistics present to compute encodings.')
+
+        tiny_num = torch.finfo(stats.min.dtype).tiny
+        # enforces that 0 is within the min/max
+        min_with_zero = torch.minimum(stats.min, torch.zeros_like(stats.min))
+        max_with_zero = torch.maximum(stats.max, torch.zeros_like(stats.max))
 
-        if symmetric:
-            min = torch.minimum(stats.min, -stats.max)
-            max = torch.maximum(-stats.min, stats.max)
-        else:
-            min = stats.min
-            max = stats.max
+        # adjusts any min/max pairing that are too close
+        tensor_diff = (max_with_zero - min_with_zero) / ((2 ** bitwidth) - 1)
+        update_min = torch.where(tensor_diff < tiny_num, tiny_num * (2 ** (bitwidth - 1)), 0.0)
+        update_max = torch.where(tensor_diff < tiny_num, tiny_num * ((2 ** (bitwidth - 1)) - 1), 0.0)
+        updated_max = max_with_zero + update_max
+        updated_min = min_with_zero - update_min
 
-        return min, max
+        if is_symmetric:
+            # ensures that min/max pairings are symmetric
+            symmetric_min = torch.minimum(updated_min, -updated_max)
+            symmetric_max = torch.maximum(-updated_min, updated_max)
+            return symmetric_min, symmetric_max
+
+        return updated_min, updated_max
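The widening above guards against a degenerate range whose quantization step would underflow to zero. A standalone numeric check of that behavior (illustrative only, assuming float32 statistics):

    import torch

    # With bitwidth 8 and min == max == 0.0, the step (max - min) / (2**8 - 1)
    # would be 0, so the range is widened by tiny * 2**7 below zero and
    # tiny * (2**7 - 1) above zero, where tiny is the smallest positive
    # normal float32.
    tiny = torch.finfo(torch.float32).tiny
    lo = torch.tensor(0.0)
    hi = torch.tensor(0.0)
    step = (hi - lo) / (2 ** 8 - 1)
    assert step < tiny
    lo, hi = lo - tiny * 2 ** 7, hi + tiny * (2 ** 7 - 1)
    assert (hi - lo) / (2 ** 8 - 1) >= tiny  # step no longer underflows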

 class PercentileEncodingAnalyzer(_EncodingAnalyzer[_Histogram]):
-    observer_cls = _HistogramObserver
+    """
+    Encoding Analyzer for Percentile calibration technique
+    """
+    def __init__(self, shape):
+        self.observer = _HistogramObserver(shape)
 
     @torch.no_grad()
-    def compute_encodings_from_stats(self, stats: _Histogram, symmetric: bool, bitwidth: int)\
+    def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\
             -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         # TODO
         raise NotImplementedError
 
 
 class SqnrEncodingAnalyzer(_EncodingAnalyzer[_Histogram]):
-    observer_cls = _HistogramObserver
+    """
+    Encoding Analyzer for SQNR Calibration technique
+    """
+    def __init__(self, shape):
+        self.observer = _HistogramObserver(shape)
 
     @torch.no_grad()
-    def compute_encodings_from_stats(self, stats: _Histogram, symmetric: bool, bitwidth: int)\
+    def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\
             -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         # TODO
         raise NotImplementedError
 
 
 class MseEncodingAnalyzer(_EncodingAnalyzer[_Histogram]):
-    observer_cls = _HistogramObserver
+    """
+    Encoding Analyzer for Mean Square Error (MSE) Calibration technique
+    """
+    def __init__(self, shape):
+        self.observer = _HistogramObserver(shape)
 
     @torch.no_grad()
-    def compute_encodings_from_stats(self, stats: _Histogram, symmetric: bool, bitwidth: int)\
+    def compute_encodings_from_stats(self, stats: _Histogram, bitwidth: int, is_symmetric: bool)\
             -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         # TODO
         raise NotImplementedError
-
-
-def get_encoding_analyzer_cls(qscheme):
-    if qscheme == 'minmax':
-        return MinMaxEncodingAnalyzer
-
-    raise ValueError
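Taken together, the new factory and analyzer support a flow like the sketch below. This is illustrative only: min_max_shape=(1,) assumes aimet's reduce helper collapses the input to a single min/max pair, which this diff does not show. Note also that get_encoding_analyzer_cls returns an instance rather than a class, despite its name. The quantizer-side changes, which swap the argument order to match, follow.

    import torch
    from aimet_torch.experimental.v2.quantization.encoding_analyzer import (
        CalibrationMethod, get_encoding_analyzer_cls)

    # Per-tensor min-max calibration: gather statistics over two batches,
    # then freeze the encodings at 8-bit symmetric.
    analyzer = get_encoding_analyzer_cls(CalibrationMethod.MinMax, min_max_shape=(1,))
    analyzer.update_stats(torch.randn(32, 16))
    analyzer.update_stats(torch.randn(32, 16))
    enc_min, enc_max = analyzer.compute_encodings(bitwidth=8, is_symmetric=True)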
@@ -43,7 +43,7 @@
 
 import torch
 
-from aimet_torch.experimental.v2.utils import patch_attr, patch_param
+from aimet_torch.experimental.v2.utils import patch_attr, patch_param, StatisticsNotFoundError
 from aimet_torch.experimental.v2.quantization.encoding_analyzer import get_encoding_analyzer_cls
 from aimet_torch.experimental.v2.quantization.backends import get_backend
 from aimet_torch.experimental.v2.utils import ste_round
@@ -72,7 +72,7 @@ def __init__(self, shape, bitwidth: int, symmetric: bool, qscheme):
         self.bitwidth = bitwidth
         self.symmetric = symmetric
         self.qscheme = qscheme
-        self.encoding_analyzer = get_encoding_analyzer_cls(qscheme)(shape)
+        self.encoding_analyzer = get_encoding_analyzer_cls(qscheme, shape)
 
         # Raw quantization parameters
         self.register_parameter("min", None)
@@ -155,8 +155,8 @@ def forward_wrapper(input):
             batch_statistics = self.encoding_analyzer.update_stats(input)
             dynamic_min, dynamic_max =\
                     self.encoding_analyzer.compute_encodings_from_stats(batch_statistics,
-                                                                        self.symmetric,
-                                                                        self.bitwidth)
+                                                                        self.bitwidth,
+                                                                        self.symmetric)
             with patch_param(self, 'min', dynamic_min),\
                     patch_param(self, 'max', dynamic_max):
                 return original_forward(input)
@@ -167,7 +167,10 @@ def forward_wrapper(input):
         except: # pylint: disable=try-except-raise
             raise
         else:
-            min, max = self.encoding_analyzer.compute_encodings(self.symmetric, self.bitwidth)
+            try:
+                min, max = self.encoding_analyzer.compute_encodings(self.bitwidth, self.symmetric)
+            except StatisticsNotFoundError:
+                return
 
         if min is None or max is None:
             return
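With this change, a fresh analyzer raises StatisticsNotFoundError instead of returning None, None, and the new try/except preserves the old behavior of leaving the quantization parameters untouched. A minimal illustration (not part of the commit):

    from aimet_torch.experimental.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
    from aimet_torch.experimental.v2.utils import StatisticsNotFoundError

    analyzer = MinMaxEncodingAnalyzer((1,))
    try:
        # No update_stats() call has been made, so no statistics exist yet.
        analyzer.compute_encodings(bitwidth=8, is_symmetric=False)
    except StatisticsNotFoundError:
        pass  # the quantizer returns here, keeping its current min/max parameters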
@@ -171,3 +171,8 @@ def ste_round(*args, **kwargs):
     Applies straight-through rounding
     """
     return _StraightThroughEstimator.apply(torch.round, *args, **kwargs)
+
+class StatisticsNotFoundError(RuntimeError):
+    '''
+    Error raised when compute_encodings() is invoked without statistics
+    '''