diff --git a/requirements/requirements-vision.txt b/requirements/requirements-vision.txt
index 0e74a8966..c86540c5c 100644
--- a/requirements/requirements-vision.txt
+++ b/requirements/requirements-vision.txt
@@ -1,2 +1,3 @@
+accelerate
 torchvision
 tqdm
diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py
index 55ef86a31..8387c839c 100644
--- a/src/brevitas/core/function_wrapper/learned_round.py
+++ b/src/brevitas/core/function_wrapper/learned_round.py
@@ -11,8 +11,10 @@ import brevitas
 from brevitas import config
+from brevitas.core.function_wrapper.ops_ste import TensorClampSte
 from brevitas.core.utils import SliceTensor
 from brevitas.function.ops_ste import floor_ste
+from brevitas.function.ops_ste import round_ste
 
 
 class LearnedRoundHardSigmoid(brevitas.jit.ScriptModule):
@@ -28,12 +30,17 @@ def __init__(self, learned_round_zeta: float = 1.1, learned_round_gamma: float =
         self.learned_round_gamma = learned_round_gamma
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        p = torch.sigmoid(x)
+    def forward(self, p: torch.Tensor) -> torch.Tensor:
+        p = torch.sigmoid(p)
         p = p * (self.learned_round_zeta - self.learned_round_gamma) + self.learned_round_gamma
         p = torch.clamp(p, 0.0, 1.0)
+        if not self.training:
+            return p > 0.5
         return p
 
+    def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+        return floor_ste(x) + p
+
 
 class LearnedRoundSigmoid(brevitas.jit.ScriptModule):
@@ -47,10 +54,37 @@ def __init__(self, learned_round_temperature: float = 1.) -> None:
         self.learned_round_temperature = learned_round_temperature
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        p = torch.sigmoid(x / self.learned_round_temperature)
+    def forward(self, p: torch.Tensor) -> torch.Tensor:
+        if not self.training:
+            return p > 0
+        p = torch.sigmoid(p / self.learned_round_temperature)
         return p
 
+    @brevitas.jit.script_method
+    def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+        return floor_ste(x) + p
+
+
+class LearnedRoundIdentity(brevitas.jit.ScriptModule):
+    """
+    Implementation for LearnedRound learned parameter
+    Adapted from https://arxiv.org/abs/2309.05516
+    """
+
+    def __init__(self) -> None:
+        super(LearnedRoundIdentity, self).__init__()
+        self.tensor_clamp = TensorClampSte()
+        self.upper_lower_bound = brevitas.jit.Attribute(0.5, float)
+
+    def forward(self, p: torch.Tensor) -> torch.Tensor:
+        return self.tensor_clamp(
+            p,
+            min_val=torch.tensor(-self.upper_lower_bound).type_as(p),
+            max_val=torch.tensor(self.upper_lower_bound).type_as(p))
+
+    def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+        return round_ste(x + p)
+
 
 class LearnedRoundSte(brevitas.jit.ScriptModule):
@@ -72,17 +106,10 @@ def __init__(
 
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        p = self.p_forward()
+        p = self.learned_round_impl(self.value)
         p = self.tensor_slicer(p)
-        return floor_ste(x) + p.to(x.dtype)
-
-    def p_forward(self):
-        # In eval mode, performs true quantization, otherwise "soft" quantization
-        if not self.training:
-            p = (self.value > 0)
-        else:
-            p = self.learned_round_impl(self.value)
-        return p
+        p = (p.to(x.dtype)).view_as(x)
+        return self.learned_round_impl.round_forward(x, p)
 
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index 92e3da2bf..b16faaf6b 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -24,11 +24,11 @@ from brevitas.graph.calibrate import restore_return_quant_tensor
 from brevitas.graph.gpxq import GPxQ
 from brevitas.graph.gpxq import gpxq_mode
-from brevitas.graph.gpxq import StopFwdException
 from brevitas.graph.gpxq import SUPPORTED_CONV_OP
 from brevitas.graph.gpxq import SUPPORTED_TCONV_OP
 import brevitas.nn as qnn
 from brevitas.quant_tensor import _unpack_quant_tensor
+from brevitas.utils.torch_utils import StopFwdException
 
 
 class GPFQ(GPxQ):
diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index 667e47d40..6b1945dcc 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -18,9 +18,9 @@ from brevitas import torch_version
 from brevitas.graph.gpxq import GPxQ
 from brevitas.graph.gpxq import gpxq_mode
-from brevitas.graph.gpxq import StopFwdException
 from brevitas.graph.gpxq import SUPPORTED_CONV_OP
 import brevitas.nn as qnn
+from brevitas.utils.torch_utils import StopFwdException
 
 
 class GPTQ(GPxQ):
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index e71a273c3..2af0d7f2d 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -28,10 +28,6 @@ SUPPORTED_CONV_OP = (qnn.QuantConv1d, qnn.QuantConv2d, qnn.QuantConv3d,
     *SUPPORTED_TCONV_OP)
 
 
-class StopFwdException(Exception):
-    pass
-
-
 @dataclass
 class LayerHandler:
     layer_names: Set = field(default_factory=set)
diff --git a/src/brevitas/inject/enum.py b/src/brevitas/inject/enum.py
index 129a55252..fbac29176 100644
--- a/src/brevitas/inject/enum.py
+++ b/src/brevitas/inject/enum.py
@@ -53,6 +53,7 @@ class LearnedRoundImplType(AutoName):
     """
     HARD_SIGMOID = auto()
    SIGMOID = auto()
+    IDENTITY = auto()
 
 
 class ScalingImplType(AutoName):
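The rounding implementations above differ only in how the learned tensor is mapped to a perturbation of the weight. A minimal sketch in plain torch (training-mode behaviour only, with `floor`/`round` standing in for the STE wrappers):

    import torch

    v = torch.randn(4)  # learned rounding parameter
    x = torch.randn(4)  # weight already divided by its quantization scale

    # HARD_SIGMOID / SIGMOID: v is squashed into a soft offset in [0, 1],
    # and the quantized value is floor(x) plus that offset
    zeta, gamma = 1.1, -0.1
    h = torch.clamp(torch.sigmoid(v) * (zeta - gamma) + gamma, 0.0, 1.0)
    soft_round = torch.floor(x) + h

    # IDENTITY: v itself is the offset, clamped to [-0.5, 0.5], and the
    # perturbed weight is rounded to the nearest integer
    p = torch.clamp(v, -0.5, 0.5)
    linear_round = torch.round(x + p)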
diff --git a/src/brevitas/optim/sign_sgd.py b/src/brevitas/optim/sign_sgd.py
new file mode 100644
index 000000000..70090916a
--- /dev/null
+++ b/src/brevitas/optim/sign_sgd.py
@@ -0,0 +1,451 @@
+"""
+Copyright (C) 2024, Advanced Micro Devices, Inc.
+Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+Copyright (c) 2011-2013 NYU (Clement Farabet)
+Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU,
+   NEC Laboratories America and IDIAP Research Institute nor the names
+   of its contributors may be used to endorse or promote products derived
+   from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
+"""
+
+# mypy: allow-untyped-defs
+from typing import List, Optional
+
+import torch
+from torch import Tensor
+
+try:
+    from torch.optim.optimizer import _default_to_fused_or_foreach
+    from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
+except:
+    _default_to_fused_or_foreach = None
+    _get_fused_kernels_supported_devices = None
+try:
+    from torch.optim.optimizer import _use_grad_for_differentiable
+except:
+    # Ensure backward compatibility with PyTorch < 1.13.0
+    _use_grad_for_differentiable = torch.no_grad
+
+from torch.optim.optimizer import Optimizer
+
+__all__ = ["SignSGD", "sign_sgd"]
+
+
+class SignSGD(Optimizer):
+    r"""Implements signed stochastic gradient descent (optionally with momentum).
+
+    .. math::
+        \begin{aligned}
+            &\rule{110mm}{0.4pt} \\
+            &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
+                \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
+            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
+                \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
+            &\rule{110mm}{0.4pt} \\
+            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
+            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
+            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
+            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
+            &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
+            &\hspace{10mm}\textbf{if} \: t > 1 \\
+            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
+            &\hspace{10mm}\textbf{else} \\
+            &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
+            &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
+            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
+            &\hspace{10mm}\textbf{else} \\[-1.ex]
+            &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
+            &\hspace{5mm}\textbf{if} \: \textit{maximize} \\
+            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma \text{sign}(g_t) \\[-1.ex]
+            &\hspace{5mm}\textbf{else} \\[-1.ex]
+            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \text{sign}(g_t) \\[-1.ex]
+            &\rule{110mm}{0.4pt} \\[-1.ex]
+            &\bf{return} \: \theta_t \\[-1.ex]
+            &\rule{110mm}{0.4pt} \\[-1.ex]
+        \end{aligned}
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        momentum (float, optional): momentum factor (default: 0)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        dampening (float, optional): dampening for momentum (default: 0)
+        nesterov (bool, optional): enables Nesterov momentum (default: False)
+
+    Example:
+        >>> optimizer = SignSGD(model.parameters(), lr=0.1, momentum=0.9)
+        >>> optimizer.zero_grad()
+        >>> loss_fn(model(input), target).backward()
+        >>> optimizer.step()
+    """
+
+    def __init__(
+            self,
+            params,
+            lr: float = 1e-3,
+            momentum: float = 0,
+            dampening: float = 0,
+            weight_decay: float = 0,
+            nesterov=False,
+            *,
+            maximize: bool = False,
+            foreach: Optional[bool] = None,
+            differentiable: bool = False,
+            fused: Optional[bool] = None,
+    ):
+        if lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if momentum < 0.0:
+            raise ValueError(f"Invalid momentum value: {momentum}")
+        if weight_decay < 0.0:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            maximize=maximize,
+            foreach=foreach,
+            differentiable=differentiable,
+            fused=fused,
+        )
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+        super().__init__(params, defaults)
+
+        if fused:
+            self._step_supports_amp_scaling = True
+
+            fused_supported_devices = _get_fused_kernels_supported_devices()
+            if not all(p.device.type in fused_supported_devices and torch.is_floating_point(p)
+                       for pg in self.param_groups
+                       for p in pg["params"]):
+                raise RuntimeError(
+                    "`fused=True` requires all the params to be floating point Tensors of "
+                    f"supported devices: {fused_supported_devices}.")
+            if differentiable:
+                raise RuntimeError("`fused` does not support `differentiable`")
+            if foreach:
+                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault("nesterov", False)
+            group.setdefault("maximize", False)
+            group.setdefault("foreach", None)
+            group.setdefault("differentiable", False)
+            group.setdefault("fused", False)
+
+    def _init_group(self, group, params, grads, momentum_buffer_list):
+        has_sparse_grad = False
+
+        for p in group["params"]:
+            if p.grad is not None:
+                params.append(p)
+                grads.append(p.grad)
+                if p.grad.is_sparse:
+                    has_sparse_grad = True
+
+                if group["momentum"] != 0:
+                    state = self.state[p]
+                    momentum_buffer_list.append(state.get("momentum_buffer"))
+
+        return has_sparse_grad
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params: List[Tensor] = [] + grads: List[Tensor] = [] + momentum_buffer_list: List[Optional[Tensor]] = [] + + has_sparse_grad = self._init_group(group, params, grads, momentum_buffer_list) + + sign_sgd( + params, + grads, + momentum_buffer_list, + weight_decay=group["weight_decay"], + momentum=group["momentum"], + lr=group["lr"], + dampening=group["dampening"], + nesterov=group["nesterov"], + maximize=group["maximize"], + has_sparse_grad=has_sparse_grad, + foreach=group["foreach"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + if group["momentum"] != 0: + # update momentum_buffers in state + for p, momentum_buffer in zip(params, momentum_buffer_list): + state = self.state[p] + state["momentum_buffer"] = momentum_buffer + + return loss + + +def sign_sgd( + params: List[Tensor], + d_p_list: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + has_sparse_grad: bool = False, + foreach: Optional[bool] = None, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, +): + r"""Functional API that performs Sign SGD algorithm computation. + + See :class:`~torch.optim.SignSGD` for details. + """ + + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if foreach is None and fused is None: + # why must we be explicit about an if statement for torch.jit.is_scripting here? 
+
+
+def sign_sgd(
+        params: List[Tensor],
+        d_p_list: List[Tensor],
+        momentum_buffer_list: List[Optional[Tensor]],
+        # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
+        # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
+        has_sparse_grad: bool = False,
+        foreach: Optional[bool] = None,
+        fused: Optional[bool] = None,
+        grad_scale: Optional[Tensor] = None,
+        found_inf: Optional[Tensor] = None,
+        *,
+        weight_decay: float,
+        momentum: float,
+        lr: float,
+        dampening: float,
+        nesterov: bool,
+        maximize: bool,
+):
+    r"""Functional API that performs Sign SGD algorithm computation.
+
+    See :class:`~brevitas.optim.sign_sgd.SignSGD` for details.
+    """
+
+    # Respect when the user inputs False/True for foreach or fused. We only want to change
+    # the default when neither have been user-specified. Note that we default to foreach
+    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
+    # bake-in time before making it the default, even if it is typically faster.
+    if foreach is None and fused is None:
+        # why must we be explicit about an if statement for torch.jit.is_scripting here?
+        # because JIT can't handle Optionals nor fancy conditionals when scripting
+        if not torch.jit.is_scripting() and _default_to_fused_or_foreach is not None:
+            fused, foreach = _default_to_fused_or_foreach(
+                params, differentiable=False, use_fused=False)
+        else:
+            foreach = False
+            fused = False
+    if foreach is None:
+        foreach = False
+    if fused is None:
+        fused = False
+
+    if foreach and torch.jit.is_scripting():
+        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
+    if fused and torch.jit.is_scripting():
+        raise RuntimeError("torch.jit.script not supported with fused optimizers")
+
+    if foreach and not torch.jit.is_scripting():
+        func = _multi_tensor_sign_sgd
+    elif fused and not torch.jit.is_scripting():
+        func = _fused_sign_sgd
+    else:
+        func = _single_tensor_sign_sgd
+
+    func(
+        params,
+        d_p_list,
+        momentum_buffer_list,
+        weight_decay=weight_decay,
+        momentum=momentum,
+        lr=lr,
+        dampening=dampening,
+        nesterov=nesterov,
+        has_sparse_grad=has_sparse_grad,
+        maximize=maximize,
+        grad_scale=grad_scale,
+        found_inf=found_inf,
+    )
+
+
+def _single_tensor_sign_sgd(
+        params: List[Tensor],
+        grads: List[Tensor],
+        momentum_buffer_list: List[Optional[Tensor]],
+        grad_scale: Optional[Tensor],
+        found_inf: Optional[Tensor],
+        *,
+        weight_decay: float,
+        momentum: float,
+        lr: float,
+        dampening: float,
+        nesterov: bool,
+        maximize: bool,
+        has_sparse_grad: bool,
+):
+    assert grad_scale is None and found_inf is None
+
+    for i, param in enumerate(params):
+        grad = grads[i] if not maximize else -grads[i]
+
+        if weight_decay != 0:
+            grad = grad.add(param, alpha=weight_decay)
+
+        if momentum != 0:
+            buf = momentum_buffer_list[i]
+
+            if buf is None:
+                buf = torch.clone(grad).detach()
+                momentum_buffer_list[i] = buf
+            else:
+                buf.mul_(momentum).add_(grad, alpha=1 - dampening)
+
+            if nesterov:
+                grad = grad.add(buf, alpha=momentum)
+            else:
+                grad = buf
+
+        param.add_(torch.sign(grad), alpha=-lr)
+
+
+def _multi_tensor_sign_sgd(
+        params: List[Tensor],
+        grads: List[Tensor],
+        momentum_buffer_list: List[Optional[Tensor]],
+        grad_scale: Optional[Tensor],
+        found_inf: Optional[Tensor],
+        *,
+        weight_decay: float,
+        momentum: float,
+        lr: float,
+        dampening: float,
+        nesterov: bool,
+        maximize: bool,
+        has_sparse_grad: bool,
+):
+    assert grad_scale is None and found_inf is None
+
+    if len(params) == 0:
+        return
+
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, momentum_buffer_list],
+        with_indices=True  # type: ignore[list-item]
+    )
+    for (
+            device_params,
+            device_grads,
+            device_momentum_buffer_list,
+    ), indices in grouped_tensors.values():
+        device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads)
+
+        if maximize:
+            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]
+
+        if weight_decay != 0:
+            # Re-use the intermediate memory (device_grads) already allocated for maximize
+            if maximize:
+                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
+            else:
+                device_grads = torch._foreach_add(  # type: ignore[assignment]
+                    device_grads, device_params, alpha=weight_decay)
+
+        if momentum != 0:
+            bufs = []
+
+            all_states_with_momentum_buffer = True
+            for i in range(len(device_momentum_buffer_list)):
+                if device_momentum_buffer_list[i] is None:
+                    all_states_with_momentum_buffer = False
+                    break
+                else:
+                    bufs.append(device_momentum_buffer_list[i])
+
+            if all_states_with_momentum_buffer:
+                torch._foreach_mul_(bufs, momentum)
+                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
+            else:
+                bufs = []
+                for i in range(len(device_momentum_buffer_list)):
+                    if device_momentum_buffer_list[i] is None:
+                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[
+                            indices[i]] = torch.clone(device_grads[i]).detach()
+                    else:
+                        buf = device_momentum_buffer_list[i]
+                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)
+
+                    bufs.append(buf)
+
+            if nesterov:
+                torch._foreach_add_(device_grads, bufs, alpha=momentum)
+            else:
+                device_grads = bufs
+
+        if not device_has_sparse_grad:
+            # handle internal item() call if lr is a tensor
+            if isinstance(lr, torch.Tensor) and torch._utils.is_compiling():
+                grads_x_lr = torch._foreach_mul(torch._foreach_sign(device_grads), -lr)
+                torch._foreach_add_(device_params, grads_x_lr)
+            else:
+                torch._foreach_add_(device_params, torch._foreach_sign(device_grads), alpha=-lr)
+        else:
+            # foreach APIs don't support sparse
+            for i in range(len(device_params)):
+                device_params[i].add_(torch.sign(device_grads[i]), alpha=-lr)
+
+
+def _fused_sign_sgd(
+        params: List[Tensor],
+        grads: List[Tensor],
+        momentum_buffer_list: List[Optional[Tensor]],
+        grad_scale: Optional[Tensor],
+        found_inf: Optional[Tensor],
+        *,
+        weight_decay: float,
+        momentum: float,
+        lr: float,
+        dampening: float,
+        nesterov: bool,
+        maximize: bool,
+        has_sparse_grad: bool,
+) -> None:
+    raise NotImplementedError("Fused SignSGD is not implemented.")
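A quick sanity check of the update rule implemented above (a hypothetical snippet, not part of the patch): every parameter entry moves by exactly `lr`, in the direction opposite the sign of its gradient, regardless of the gradient's magnitude.

    import torch
    from brevitas.optim.sign_sgd import SignSGD

    w = torch.nn.Parameter(torch.tensor([1.0, -2.0, 3.0]))
    opt = SignSGD([w], lr=0.1)

    loss = (w * torch.tensor([0.5, -1.0, 2.0])).sum()  # dloss/dw = [0.5, -1.0, 2.0]
    loss.backward()
    opt.step()

    print(w.data)  # tensor([0.9000, -1.9000, 2.9000]); each entry moved by lr * sign(grad)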
diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py
index a4930e43d..69b4c9438 100644
--- a/src/brevitas/quant/solver/common.py
+++ b/src/brevitas/quant/solver/common.py
@@ -4,6 +4,7 @@ from brevitas.core.bit_width import *
 from brevitas.core.function_wrapper import *
 from brevitas.core.function_wrapper.learned_round import LearnedRoundHardSigmoid
+from brevitas.core.function_wrapper.learned_round import LearnedRoundIdentity
 from brevitas.core.function_wrapper.learned_round import LearnedRoundSigmoid
 from brevitas.core.function_wrapper.learned_round import LearnedRoundSte
 from brevitas.core.function_wrapper.stochastic_round import StochasticRoundSte
@@ -147,6 +148,8 @@ def learned_round_impl(learned_round_impl_type):
         return LearnedRoundSigmoid
     if learned_round_impl_type == LearnedRoundImplType.HARD_SIGMOID:
         return LearnedRoundHardSigmoid
+    if learned_round_impl_type == LearnedRoundImplType.IDENTITY:
+        return LearnedRoundIdentity
 
     @value
     def learned_round_init(tracked_parameter_list):
diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py
index 8942c513a..461c7bc92 100644
--- a/src/brevitas/utils/torch_utils.py
+++ b/src/brevitas/utils/torch_utils.py
@@ -11,6 +11,11 @@ from brevitas.function.ops_ste import floor_ste
 
 
+class StopFwdException(Exception):
+    """Used to throw and catch an exception to stop traversing the graph."""
+    pass
+
+
 class TupleSequential(Sequential):
 
     def output(self, mod, input):
diff --git a/src/brevitas_examples/common/accelerate_utils/accelerate.py b/src/brevitas_examples/common/accelerate_utils/accelerate.py
index ead616ed2..876837081 100644
--- a/src/brevitas_examples/common/accelerate_utils/accelerate.py
+++ b/src/brevitas_examples/common/accelerate_utils/accelerate.py
@@ -405,8 +405,13 @@ def offload_model(
         device_map = infer_fx_auto_device_map(model, memory_map)
         offload_call_function(model, device_map)
     else:
+        # Some models do not have the attribute _no_split_modules, so a check is needed
+        # to prevent this call from crashing.
         device_map = infer_auto_device_map(
-            model, memory_map, no_split_module_classes=model._no_split_modules)
+            model,
+            memory_map,
+            no_split_module_classes=model._no_split_modules
+            if hasattr(model, "_no_split_modules") else None)
 
     model = dispatch_model(model, device_map)
diff --git a/src/brevitas_examples/common/learned_round/learned_round_method.py b/src/brevitas_examples/common/learned_round/learned_round_method.py
new file mode 100644
index 000000000..6fb929136
--- /dev/null
+++ b/src/brevitas_examples/common/learned_round/learned_round_method.py
@@ -0,0 +1,181 @@
+# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from abc import ABC
+from abc import abstractmethod
+from typing import Callable, Dict, Generator, List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from brevitas.core.function_wrapper.learned_round import LearnedRoundSte
+from brevitas.inject.enum import FloatToIntImplType
+from brevitas.inject.enum import LearnedRoundImplType
+from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
+
+
+class LearnedRoundLoss(ABC):
+
+    @abstractmethod
+    def __init__(self, block: nn.Module, learned_round_modules: List[nn.Module], **kwargs) -> None:
+        pass
+
+    @abstractmethod
+    def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]:
+        pass
+
+    @abstractmethod
+    def format_loss_components(self, *args) -> str:
+        pass
+
+
+def learned_round_value_init_non_linear(
+        layer: nn.Module,
+        learned_round_zeta: float = 1.1,
+        learned_round_gamma: float = -0.1,
+        **learned_round_impl_kwargs,
+) -> torch.Tensor:
+    floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale)
+    delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight
+    value = -torch.log((learned_round_zeta - learned_round_gamma) /
+                       (delta - learned_round_gamma) - 1)
+    return value
+
+
+def learned_round_value_init_linear(
+        layer: nn.Module,
+        **learned_round_impl_kwargs,
+) -> torch.Tensor:
+    value = torch.zeros_like(layer.weight.data)
+    return value
+
+
+LEARNED_ROUND_VALUE_INIT_MAP = {
+    LearnedRoundImplType.HARD_SIGMOID.value: learned_round_value_init_non_linear,
+    LearnedRoundImplType.SIGMOID.value: learned_round_value_init_non_linear,
+    LearnedRoundImplType.IDENTITY.value: learned_round_value_init_linear,}
+
+
+class LearnedRound(ABC):
+
+    def __init__(
+            self,
+            learned_round_impl_type: LearnedRoundImplType = LearnedRoundImplType.HARD_SIGMOID,
+            learned_round_value_init_fn: Optional[Callable] = None,
+            **learned_round_impl_kwargs,
+    ) -> None:
+        self.learned_round_impl_type = learned_round_impl_type
+        self.learned_round_value_init_fn = learned_round_value_init_fn
+        self.learned_round_impl_kwargs = learned_round_impl_kwargs
+
+    def learned_round_value_init(
+            self,
+            layer: nn.Module,
+    ) -> torch.Tensor:
+        # A custom initialization function for the learned round parameter can be passed
+        if self.learned_round_value_init_fn is not None:
+            return self.learned_round_value_init_fn(layer, **self.learned_round_impl_kwargs)
+        # If not provided, the default function, as defined in LEARNED_ROUND_VALUE_INIT_MAP,
+        # is leveraged
+        return LEARNED_ROUND_VALUE_INIT_MAP[self.learned_round_impl_type.value](
+            layer, **self.learned_round_impl_kwargs)
+
+    def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None:
+        value = self.learned_round_value_init(layer)
+        layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let(
+            float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND,
+            learned_round_impl_type=self.learned_round_impl_type,
+            learned_round_init=value,
+            **self.learned_round_impl_kwargs,
+        )
+        layer.weight_quant.init_tensor_quant(preserve_state_dict=True)
+
+    def insert_learned_round_quantizers(self, model: nn.Module) -> None:
+        for module in model.modules():
+            if isinstance(module, QuantWBIOL) and len(
+                    self.return_learned_round_quantizers(module)) == 0:
+                self._insert_learned_round_quantizer_to_layer(module)
+
+    def return_learned_round_quantizers(self, block: nn.Module) -> List[nn.Module]:
+        return [module for module in block.modules() if isinstance(module, LearnedRoundSte)]
+
+
+class LinearTempDecay:
+
+    def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2):
+        self.t_max = t_max
+        self.start_decay = rel_start_decay * t_max
+        self.start_b = start_b
+        self.end_b = end_b
+
+    def __call__(self, t):
+        if t < self.start_decay:
+            return self.start_b
+        else:
+            rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
+            return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t))
+
+
+class RegularisedMSELoss(LearnedRoundLoss):
+
+    def __init__(
+            self,
+            module: nn.Module,
+            learned_round_modules: List[nn.Module],
+            weight: float = 0.01,
+            max_count: int = 1000,
+            b_range: Tuple = (20, 2),
+            warmup: float = 0.2,
+            decay_start: float = 0.0,
+            **kwargs) -> None:
+        # This loss operates in a layer-wise manner, so integrity needs to be checked
+        assert isinstance(module, QuantWBIOL), "Regularised MSE loss can only accept a single QuantWBIOL layer."
+        assert len(learned_round_modules) == 1, "Regularised MSE loss can only accept a single learned round module."
+
+        self.weight = weight
+        self.module = module
+        self.loss_start = max_count * warmup
+        self.temp_decay = LinearTempDecay(
+            max_count,
+            start_b=b_range[0],
+            end_b=b_range[1],
+            rel_start_decay=warmup + (1.0 - warmup) * decay_start)
+        self.iter = 0
+        self.learned_round_module = learned_round_modules[0]
+
+    def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]:
+        self.iter += 1
+
+        rec_loss = F.mse_loss(pred, tgt, reduction='none').sum(1).mean()
+
+        if self.iter < self.loss_start:
+            b = self.temp_decay(self.iter)
+            round_loss = 0.0
+        else:  # 1 - |(h-0.5)*2|**b
+            b = self.temp_decay(self.iter)
+            # p_forward was removed from LearnedRoundSte, so the soft rounding values
+            # are obtained by evaluating the rounding impl on the learned parameter
+            round_vals = self.learned_round_module.learned_round_impl(
+                self.learned_round_module.value)
+            round_loss = self.weight * (1 - ((round_vals - 0.5).abs() * 2).pow(b)).sum()
+
+        total_loss = rec_loss + round_loss
+        return total_loss, (total_loss, rec_loss, round_loss, b)
+
+    def format_loss_components(self, loss: float, rec_loss: float, round_loss: float, b) -> str:
+        return "Loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format(
+            loss,
+            rec_loss.detach().cpu().item(),
+            round_loss if isinstance(round_loss, float) else round_loss.detach().cpu().item(),
+            b)
+
+
+class MSELoss(LearnedRoundLoss):
+
+    def __init__(self, block: nn.Module, learned_round_modules: List[nn.Module], **kwargs) -> None:
+        pass
+
+    def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]:
+        loss = F.mse_loss(pred, tgt)
+        return loss, (loss.detach().cpu().item(),)
+
+    def format_loss_components(self, loss: float) -> str:
+        return "Loss = {:.4f}".format(loss)
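The interplay between `LinearTempDecay` and the rounding regulariser in `RegularisedMSELoss` can be seen in isolation; a toy sketch with illustrative values only:

    import torch
    from brevitas_examples.common.learned_round.learned_round_method import LinearTempDecay

    decay = LinearTempDecay(t_max=1000, rel_start_decay=0.2, start_b=20, end_b=2)
    h = torch.tensor([0.5, 0.75, 0.99])  # soft rounding values

    for t in [100, 400, 1000]:
        b = decay(t)
        reg = (1 - ((h - 0.5).abs() * 2).pow(b)).sum()
        print(f"t={t:4d}  b={b:5.2f}  reg={reg.item():.4f}")
    # b stays at start_b during warmup, then decays linearly towards end_b; a
    # smaller b sharpens the penalty's gradient for intermediate h, pushing the
    # soft rounding values towards binary 0/1 decisions.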
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+"""
+
+from abc import ABC
+from abc import abstractmethod
+import copy
+from functools import partial
+import itertools
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+import warnings
+
+from accelerate.utils.operations import send_to_device
+import torch
+from torch import autocast
+from torch import nn
+from torch.optim.optimizer import Optimizer
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+
+from brevitas import config
+from brevitas.core.function_wrapper.learned_round import LearnedRoundSte
+from brevitas.graph.calibrate import disable_return_quant_tensor
+from brevitas.graph.calibrate import DisableEnableQuantization
+from brevitas.graph.calibrate import restore_return_quant_tensor
+from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase
+from brevitas.utils.torch_utils import StopFwdException
+from brevitas_examples.common.accelerate_utils.accelerate import offload_model
+from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
+from brevitas_examples.common.learned_round.learned_round_method import LearnedRound
+from brevitas_examples.common.learned_round.learned_round_method import LearnedRoundLoss
+
+config.IGNORE_MISSING_KEYS = True
+
+
+def get_blocks(model: nn.Module, block_check_fn: Callable[[nn.Module, str],
+                                                          bool]) -> List[nn.Module]:
+    blocks = []
+
+    # Iterating over .modules() might have been more readable but
+    # with this recursive implementation, once a block is reached,
+    # its subtree of modules is not expanded.
+    def _get_blocks(module: nn.Module):
+        for module_name, module_child in module.named_children():
+            if block_check_fn(module_child, module_name):
+                blocks.append(module_child)
+            else:
+                _get_blocks(module_child)
+
+    # Run recursive function that updates the list blocks
+    _get_blocks(model)
+    return blocks
+
+
+def return_scale_parameters(block: nn.Module) -> List[nn.Parameter]:
+
+    scale_parameters = []
+
+    def _get_scale_parameters(module: nn.Module):
+        for module_child in module.children():
+            if isinstance(module, WeightQuantProxyFromInjectorBase):
+                for submodule_name, submodule in module_child.named_parameters():
+                    if submodule_name.endswith('scaling_impl.value'):
+                        scale_parameters.append(submodule)
+            else:
+                _get_scale_parameters(module_child)
+
+    # Run recursion from block
+    _get_scale_parameters(block)
+    return scale_parameters
+
+
+class Cache(ABC):
+
+    @abstractmethod
+    def __len__(self) -> int:
+        pass
+
+    @abstractmethod
+    def store_inputs(self, args: Any, kwargs: Any) -> None:
+        pass
+
+    @abstractmethod
+    def store_output(self, output: Any) -> None:
+        pass
+
+    @abstractmethod
+    def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]:
+        pass
+
+    @abstractmethod
+    def initialize_cache(self) -> None:
+        pass
+
+    @abstractmethod
+    def clear_cache(self) -> None:
+        pass
+
+
+class DataSaverHook:
+
+    def __init__(
+            self,
+            cache: Cache,
+            store_inputs: bool = True,
+            store_output: bool = True,
+            keep_gpu: bool = True,
+    ) -> None:
+        self.cache = cache
+        self.store_inputs = store_inputs
+        self.store_output = store_output
+        self.keep_gpu = keep_gpu
+
+    def __call__(self, module, args, kwargs, output) -> None:
+        if self.store_inputs:
+            if not self.keep_gpu:
+                args = send_to_device(args, 'cpu')
+                kwargs = send_to_device(kwargs, 'cpu')
+            self.cache.store_inputs(args, kwargs)
+        if self.store_output:
+            if not self.keep_gpu:
+                output = send_to_device(output, 'cpu')
+            self.cache.store_output(output)
+
+        raise StopFwdException
+
+
+def save_inputs_output(
+        model: nn.Module,
+        model_forward: Callable,
+        module: nn.Module,
+        dataloader: DataLoader,
+        cache: Cache,
+        store_inputs: bool = True,
+        store_output: bool = False,
+        keep_gpu: bool = True,
+        disable_quant: bool = False) -> None:
+    if disable_quant:
+        disable_quant_class = DisableEnableQuantization()
+        disable_quant_class.disable_act_quantization(model, False)
+        disable_quant_class.disable_param_quantization(model, False)
+        return_quant_tensor_state = disable_return_quant_tensor(model)
+
+    data_saver = DataSaverHook(
+        cache, store_inputs=store_inputs, store_output=store_output, keep_gpu=keep_gpu)
+    handle = module.register_forward_hook(data_saver, with_kwargs=True)
+    with torch.no_grad():
+        for inps in dataloader:
+            try:
+                model_forward(model, inps)
+            except StopFwdException:
+                pass
+    handle.remove()
+    if disable_quant:
+        disable_quant_class.enable_act_quantization(model, False)
+        disable_quant_class.enable_param_quantization(model, False)
+        restore_return_quant_tensor(model, return_quant_tensor_state)
+
+
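+# For illustration, a minimal (hypothetical) Cache could back each hook call with
+# plain lists, treating every captured forward pass as one sample:
+#
+#     class ListCache(Cache):
+#         def __init__(self):
+#             self.args_list, self.kwargs_list, self.outputs = [], [], []
+#         def __len__(self):
+#             return len(self.args_list)
+#         def store_inputs(self, args, kwargs):
+#             self.args_list.append(args)
+#             self.kwargs_list.append(kwargs)
+#         def store_output(self, output):
+#             self.outputs.append(output)
+#         def sample_batch(self, indices):
+#             return (torch.cat([self.args_list[i][0] for i in indices]),
+#                     torch.cat([self.outputs[i] for i in indices]))
+#         def initialize_cache(self):
+#             self.args_list, self.kwargs_list, self.outputs = [], [], []
+#         def clear_cache(self):
+#             self.initialize_cache()
+
+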
+class LearnedRoundOptimizer:
+
+    def __init__(
+            self,
+            learned_round: LearnedRound,
+            learned_round_loss_class: Type[LearnedRoundLoss],
+            optimizer_class: Type[Optimizer],
+            *,
+            scale_optimizer_class: Optional[Type[Optimizer]] = None,
+            lr_scheduler_class: Optional[Type] = None,
+            optimizer_lr: float = 5e-3,
+            optimizer_scale_lr: float = 5e-3,
+            batch_size: int = 8,
+            iters: int = 200,
+            learn_scale: bool = False,
+            use_best_model: bool = True,
+            use_amp: bool = True,
+            amp_dtype: torch.dtype = torch.float16,
+            loss_scaling_factor: float = 1000.,
+            learned_round_loss_kwargs: Optional[Dict] = None,
+            optimizer_kwargs: Optional[Dict] = None,
+            lr_scheduler_kwargs: Optional[Dict] = None,
+    ) -> None:
+        # Verify that an optimizer is passed for optimizing the scale if learn_scale=True
+        assert not (learn_scale and scale_optimizer_class is None), "An optimizer needs to be passed for the scale if learn_scale is set to True."
+
+        self.learned_round = learned_round
+        self.optimizer_class = optimizer_class
+        self.scale_optimizer_class = scale_optimizer_class
+        self.lr_scheduler_class = lr_scheduler_class
+        self.optimizer_lr = optimizer_lr
+        self.optimizer_scale_lr = optimizer_scale_lr
+        self.batch_size = batch_size
+        self.iters = iters
+        self.learn_scale = learn_scale
+        self.use_best_model = use_best_model
+        self.use_amp = use_amp
+        self.amp_dtype = amp_dtype
+        self.loss_scaling_factor = loss_scaling_factor
+        self.optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
+        self.lr_scheduler_kwargs = {} if lr_scheduler_kwargs is None else lr_scheduler_kwargs
+        self.lr_scheduler_kwargs["total_iters"] = self.iters
+
+        learned_round_loss_kwargs = {} if learned_round_loss_kwargs is None else learned_round_loss_kwargs
+        self.learned_round_loss_init = partial(
+            learned_round_loss_class, **learned_round_loss_kwargs)
+
+    @torch.no_grad()
+    def _load_round_params(self, block: nn.Module, round_params: Dict) -> None:
+        for n, m in block.named_modules():
+            if n in round_params:
+                m.load_state_dict(round_params[n])
+
+    @torch.no_grad()
+    def _collect_round_params(self, block: nn.Module) -> Dict:
+        params = {}
+        for n, m in block.named_modules():
+            if isinstance(m, LearnedRoundSte):
+                params[n] = copy.deepcopy(m.state_dict())
+        return params
+
+    def _optim_step(self, *optimizers: Optimizer) -> None:
+        for optimizer in optimizers:
+            if optimizer:
+                optimizer.step()
+                optimizer.zero_grad()
+
+    def _lr_sched_step(self, *lr_schedulers: Any) -> None:
+        for lr_scheduler in lr_schedulers:
+            if lr_scheduler:
+                lr_scheduler.step()
+
+    def _step(self, optimizers: List[Optimizer], lr_schedulers: List[Any]) -> None:
+        for optimizer in optimizers:
+            if optimizer:
+                optimizer.step()
+                optimizer.zero_grad()
+        for lr_scheduler in lr_schedulers:
+            if lr_scheduler:
+                lr_scheduler.step()
+
+    def _populate_cache(
+            self,
+            cache: Cache,
+            model: nn.Module,
+            model_forward: Callable,
+            block: nn.Module,
+            data_loader: DataLoader,
+            keep_gpu: bool = True,
+            capture_quant_input: bool = True,
+            capture_quant_output: bool = False,
+    ) -> None:
+        # Populate the cache with new inputs and outputs
+        save_inputs_output(
+            model,
+            model_forward,
+            block,
+            data_loader,
+            cache,
+            store_inputs=True,
+            store_output=capture_quant_input == capture_quant_output,
+            keep_gpu=keep_gpu,
+            disable_quant=not capture_quant_input,
+        )
+        if capture_quant_input != capture_quant_output:
+            save_inputs_output(
+                model,
+                model_forward,
+                block,
+                data_loader,
+                cache,
+                store_inputs=False,
+                store_output=True,
+                keep_gpu=keep_gpu,
+                disable_quant=not capture_quant_output,
+            )
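+
+    # Note: the `model_forward` and `block_forward` callables are supplied by the
+    # caller and are model-family specific. A hypothetical vision-style pair,
+    # shown only to illustrate the expected signatures:
+    #
+    #     def model_forward(model, inps):
+    #         images, _ = inps        # e.g. an (images, labels) batch
+    #         model(images)           # outputs discarded; hooks capture the data
+    #
+    #     def block_forward(block, inputs):
+    #         return block(inputs)    # returned tensor feeds the rounding loss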
+
+    def _optimize_learned_round_block(
+            self,
+            block: nn.Module,
+            block_learned_round_modules: List[nn.Module],
+            cache: Cache,
+            block_loss: LearnedRoundLoss,
+            block_forward: Callable,
+            scale_params: Optional[nn.Parameter] = None,
+    ) -> Tuple[float, float, int]:
+        # Move block to GPU if available
+        if torch.cuda.is_available():
+            try:
+                block.cuda()
+            except RuntimeError as exc:
+                if 'out of memory' in str(exc):
+                    warnings.warn(
+                        "Out of memory error was raised when moving the block to GPU. Defaulting to CPU."
+                    )
+                else:
+                    raise exc
+
+        # Initialize optimizer and LR scheduler
+        optimizer = self.optimizer_class(
+            itertools.chain(
+                *[
+                    block_learned_round_module.parameters()
+                    for block_learned_round_module in block_learned_round_modules]),
+            lr=self.optimizer_lr,
+            **self.optimizer_kwargs,
+        )
+        lr_scheduler = (
+            self.lr_scheduler_class(optimizer, **self.lr_scheduler_kwargs)
+            if self.lr_scheduler_class else None)
+
+        # Initialize optimizer/LR scheduler for the scale parameters if enabled
+        if self.learn_scale and scale_params is not None:
+            optimizer_scale = self.scale_optimizer_class(
+                scale_params,
+                lr=self.optimizer_scale_lr,
+                **self.optimizer_kwargs,
+            )
+            lr_scheduler_scale = (
+                self.lr_scheduler_class(
+                    optimizer_scale, start_factor=1, end_factor=0, total_iters=600)
+                if self.lr_scheduler_class else None)
+        else:
+            optimizer_scale = None
+            lr_scheduler_scale = None
+
+        # Variables needed for printing
+        best_loss = torch.finfo(torch.float).max
+        init_loss = -1.0
+        last_best_iter = self.iters
+
+        # Dictionary to store the rounding parameters yielding the lowest
+        # training loss
+        optimal_rounding_params = {}
+
+        torch.autograd.set_detect_anomaly(True)
+        n_samples = len(cache)
+        pbar = tqdm(range(self.iters), desc='')
+        for i in pbar:
+            # Sample mini-batch from cache
+            idxs = torch.randperm(n_samples)[:self.batch_size]
+            inputs, fp_outs = cache.sample_batch(idxs)
+
+            # Run block forward to obtain quant outputs
+            quant_outs = block_forward(block, inputs)
+            fp_outs = send_to_device(fp_outs, quant_outs.device)
+            if self.use_amp:
+                with autocast(device_type="cuda" if torch.cuda.is_available() else "cpu",
+                              dtype=self.amp_dtype):
+                    loss, loss_components = block_loss(quant_outs, fp_outs)
+            else:
+                loss, loss_components = block_loss(
+                    quant_outs.to(torch.float32), fp_outs.to(torch.float32))
+
+            # Save best parameters before taking gradient step
+            curr_loss = loss.detach().cpu().item()
+            init_loss = curr_loss if i == 0 else init_loss
+            if loss < best_loss:
+                best_loss = curr_loss
+                last_best_iter = i + 1
+                if self.use_best_model:
+                    optimal_rounding_params = self._collect_round_params(block)
+
+            # Scale loss and perform gradient step
+            loss = loss * self.loss_scaling_factor
+            loss.backward()
+            self._optim_step(optimizer, optimizer_scale)
+            self._lr_sched_step(lr_scheduler, lr_scheduler_scale)
+
+            # Update progress bar
+            pbar.set_description("{}".format(block_loss.format_loss_components(*loss_components)))
+
+        # Make sure no updates are received in the progress bar
+        pbar.close()
+
+        if self.use_best_model:
+            self._load_round_params(block, optimal_rounding_params)
+        else:
+            # Override if the model with the lowest training error is not used
+            best_loss = curr_loss
+            last_best_iter = self.iters
+
+        # Move the block back to CPU
+        block.cpu()
+
+        return init_loss, best_loss, last_best_iter
+
+    def apply_learned_round(
+            self,
+            model: nn.Module,
+            model_forward: Callable,
+            block_forward: Callable,
+            data_loader: DataLoader,
+            cache: Cache,
+            get_blocks_fn: Callable,
+            model_prepare_fn: Optional[Callable] = None,
+            model_finish_fn: Optional[Callable] = None,
+            keep_gpu: bool = True) -> None:
+
+        # Perform any needed preprocessing before rounding optimisation, e.g. disabling caching in LLMs
+        model_dict = None if model_prepare_fn is None else model_prepare_fn(model)
+
+        # Insert quantizers within the appropriate model blocks
+        self.learned_round.insert_learned_round_quantizers(model)
+
+        # Retrieve blocks using the appropriate function to check blocks
+        blocks = get_blocks_fn(model)
+
+        print(f"Total Iterations per block {self.iters}")
+        print(f"Number of blocks {len(blocks)}")
+
+        # Initialize cache to store partial inputs and outputs for each block
+        cache.initialize_cache()
+
+        # Iterate over blocks and optimise the rounding parameters within each of them
+        for block_idx, block in enumerate(blocks):
+            # Distribute the model across devices to run a forward pass to capture
+            # inputs/outputs to the given block
+            model = offload_model(model)
+            # Cache needs to be cleared before populating it with the inputs and outputs
+            # to the block under optimization.
+            self._populate_cache(
+                cache,
+                model,
+                model_forward,
+                block,
+                data_loader,
+                keep_gpu=keep_gpu,
+                capture_quant_input=True,
+                capture_quant_output=False,
+            )
+            # Remove hooks needed to offload the model blocks to cpu
+            remove_hooks(model)
+
+            # Retrieve scales
+            scale_params = return_scale_parameters(block)
+
+            # The parameters of the block that are not part of the rounding quantizers
+            # need to be frozen, as only the rounding needs to be optimized.
+            block.eval()
+            for params in block.parameters():
+                params.requires_grad = False
+            # However, the rounding parameters are tuned
+            block_learned_round_modules = self.learned_round.return_learned_round_quantizers(block)
+            for block_learned_round_module in block_learned_round_modules:
+                block_learned_round_module.train()
+                for params in block_learned_round_module.parameters():
+                    params.requires_grad = True
+            # As well as the scale parameters, if enabled
+            if self.learn_scale:
+                for params in scale_params:
+                    params.requires_grad = True
+
+            # Move block to GPU if available
+            if torch.cuda.is_available():
+                block.cuda()
+
+            # Loss function for computing the rounding loss within each block
+            block_loss = self.learned_round_loss_init(
+                block,
+                block_learned_round_modules,
+            )
+
+            # Optimize block rounding
+            init_loss, best_loss, last_best_iter = self._optimize_learned_round_block(
+                block=block,
+                block_learned_round_modules=block_learned_round_modules,
+                cache=cache,
+                block_loss=block_loss,
+                block_forward=block_forward,
+                scale_params=scale_params,
+            )
+
+            print(
+                f"Quantized block {block_idx+1}/{len(blocks)}, "
+                f"initial loss: {init_loss:.6f}, best loss: {best_loss:.6f}, at iteration {last_best_iter}."
+            )
+
+            # After finishing the optimization, the block rounding parameters are frozen
+            for block_learned_round_module in block_learned_round_modules:
+                block_learned_round_module.eval()
+                for params in block_learned_round_module.parameters():
+                    params.requires_grad = False
+            for params in scale_params:
+                params.requires_grad = False
+
+            # Move the block back to CPU
+            block.cpu()
+
+            # Reset cache after optimisation
+            cache.clear_cache()
+
+        # The original configuration of the model is restored after finishing the optimization
+        if model_finish_fn is not None:
+            model_finish_fn(model, model_dict)
diff --git a/src/brevitas_examples/common/learned_round/learned_round_parser.py b/src/brevitas_examples/common/learned_round/learned_round_parser.py
new file mode 100644
index 000000000..d9ece26d2
--- /dev/null
+++ b/src/brevitas_examples/common/learned_round/learned_round_parser.py
@@ -0,0 +1,104 @@
+# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+import warnings
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+from brevitas.inject.enum import LearnedRoundImplType
+from brevitas.optim.sign_sgd import SignSGD
+from brevitas_examples.common.learned_round.learned_round_method import LearnedRound
+from brevitas_examples.common.learned_round.learned_round_method import LearnedRoundLoss
+from brevitas_examples.common.learned_round.learned_round_method import MSELoss
+from brevitas_examples.common.learned_round.learned_round_method import RegularisedMSELoss
+
+LEARNED_ROUND_MAP = {
+    "linear_round": LearnedRoundImplType.IDENTITY,
+    "hard_sigmoid_round": LearnedRoundImplType.HARD_SIGMOID,
+    "sigmoid_round": LearnedRoundImplType.SIGMOID,}
+LEARNED_ROUND_LOSS_MAP = {
+    "mse": MSELoss,
+    "regularised_mse": RegularisedMSELoss,}
+OPTIMIZER_MAP = {
+    "sign_sgd": SignSGD,}
+LR_SCHEDULER_MAP = {}
+
+
+def parse_learned_round(learned_round_str: str) -> LearnedRound:
+    if learned_round_str not in LEARNED_ROUND_MAP:
+        raise ValueError(f"Learned round method {learned_round_str} is not available.")
+    return LearnedRound(learned_round_impl_type=LEARNED_ROUND_MAP[learned_round_str])
+
+
+def parse_learned_round_loss_class(learned_round_loss_str: str) -> Type[LearnedRoundLoss]:
+    if learned_round_loss_str not in LEARNED_ROUND_LOSS_MAP:
+        raise ValueError(f"Learned round loss {learned_round_loss_str} is not available.")
+    return LEARNED_ROUND_LOSS_MAP[learned_round_loss_str]
+
+
+def parse_optimizer_class(optimizer_str: str) -> Type[Optimizer]:
+    if optimizer_str in OPTIMIZER_MAP:
+        optimizer_class = OPTIMIZER_MAP[optimizer_str]
+    else:
+        optimizer_keys = [
+            optimizer_key for optimizer_key in torch.optim.__dict__.keys()
+            # Check to make sure that only valid Optimizer implementations are
+            # retrieved, when matching against the string passed by the user
+            if (
+                # Verify that the key starts with the one passed by the user
+                optimizer_key.lower().startswith(optimizer_str.lower()) and
+                # Verify that the key corresponds to a class
+                isinstance(torch.optim.__dict__[optimizer_key], type) and
+                # Make sure the abstract class is not used
+                optimizer_key != "Optimizer" and
+                # An optimizer implements zero_grad and step. Check that this
+                # is the case for the class retrieved from torch.optim
+                hasattr(torch.optim.__dict__[optimizer_key], 'step') and
+                callable(torch.optim.__dict__[optimizer_key].step) and
+                hasattr(torch.optim.__dict__[optimizer_key], 'zero_grad') and
+                callable(torch.optim.__dict__[optimizer_key].zero_grad))]
+        if len(optimizer_keys) == 0:
+            raise ValueError(f"{optimizer_str} is not a valid optimizer.")
+        else:
+            if len(optimizer_keys) > 1:
+                warnings.warn(
+                    f"There are multiple potential matches for optimizer {optimizer_str}. "
+                    f"Defaulting to {optimizer_keys[0]}")
+            optimizer_class = getattr(torch.optim, optimizer_keys[0])
+
+    return optimizer_class
+
+
+def parse_lr_scheduler_class(lr_scheduler_str: str) -> Type:
+    if lr_scheduler_str in LR_SCHEDULER_MAP:
+        lr_scheduler_class = LR_SCHEDULER_MAP[lr_scheduler_str]
+    else:
+        lr_scheduler_keys = [
+            lr_scheduler_key for lr_scheduler_key in torch.optim.lr_scheduler.__dict__.keys()
+            # Check to make sure that only valid LRScheduler implementations are
+            # retrieved, when matching against the string passed by the user
+            if (
+                lr_scheduler_key.lower().startswith(lr_scheduler_str.lower()) and
+                # Verify that the key corresponds to a class
+                isinstance(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], type) and
+                # Make sure the abstract class is not retrieved
+                lr_scheduler_key != "LRScheduler" and
+                # A learning rate scheduler implements step. Check that this
+                # is the case for the class retrieved from torch.optim.lr_scheduler
+                hasattr(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], 'step') and
+                callable(torch.optim.lr_scheduler.__dict__[lr_scheduler_key].step))]
+        if len(lr_scheduler_keys) == 0:
+            warnings.warn(
+                f"There are no matches for LR scheduler {lr_scheduler_str}. "
+                f"No LR scheduler is going to be used.")
+            lr_scheduler_class = None
+        else:
+            if len(lr_scheduler_keys) > 1:
+                warnings.warn(
+                    f"There are multiple potential matches for LR scheduler {lr_scheduler_str}. "
+                    f"Defaulting to {lr_scheduler_keys[0]}")
+            lr_scheduler_class = getattr(torch.optim.lr_scheduler, lr_scheduler_keys[0])
+
+    return lr_scheduler_class
-import numpy as np +import functools +import re +from typing import Any, Callable, Dict, Optional, Tuple, Union +import warnings + +from accelerate.utils.operations import send_to_device import torch -import torch.nn.functional as F +from torch import nn +from torch.utils.data.dataloader import DataLoader from brevitas import config -from brevitas.core.function_wrapper.learned_round import LearnedRoundSte -from brevitas.graph.calibrate import disable_return_quant_tensor -from brevitas.graph.calibrate import DisableEnableQuantization -from brevitas.graph.calibrate import restore_return_quant_tensor -from brevitas.inject.enum import FloatToIntImplType -from brevitas.inject.enum import LearnedRoundImplType from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.quant_tensor import QuantTensor +from brevitas_examples.common.learned_round.learned_round_optimizer import Cache +from brevitas_examples.common.learned_round.learned_round_optimizer import get_blocks +from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer +from brevitas_examples.common.learned_round.learned_round_parser import parse_learned_round +from brevitas_examples.common.learned_round.learned_round_parser import \ + parse_learned_round_loss_class +from brevitas_examples.common.learned_round.learned_round_parser import parse_lr_scheduler_class +from brevitas_examples.common.learned_round.learned_round_parser import parse_optimizer_class config.IGNORE_MISSING_KEYS = True -class StopFwdException(Exception): - """Used to throw and catch an exception to stop traversing the graph.""" - pass +def is_block(module: nn.Module, module_name: str, reg_exp: str = r"layer\d+") -> bool: + return (re.search(reg_exp, module_name) is not None) + + +def is_layer(module: nn.Module, module_name: str) -> bool: + return isinstance(module, QuantWBIOL) + +BLOCK_CHECK_MAP = { + "layerwise": is_layer, + "blockwise": is_block,} -class DataSaverHook: - def __init__(self, store_output: False): - self.store_output = store_output - self.input_store = None - self.output_store = None +class CacheVision(Cache, dict): - def __call__(self, module, input_batch, output_batch): - input_batch = input_batch[0] + def __init__(self) -> None: + super().__init__() + self.batch_dim = 0 + + def store_inputs(self, args, kwargs) -> None: + input_batch = args[0] if isinstance(input_batch, QuantTensor): input_batch = input_batch.value if hasattr(input_batch, 'names') and 'N' in input_batch.names: - batch_dim = input_batch.names.index('N') - + self.batch_dim = input_batch.names.index('N') input_batch.rename_(None) - input_batch = input_batch.transpose(0, batch_dim) - if self.store_output: - output_batch.rename_(None) - output_batch = output_batch.transpose(0, batch_dim) - - if self.store_output: - self.output_store = output_batch - self.input_store = input_batch - raise StopFwdException - - -class LinearTempDecay: - - def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2): - self.t_max = t_max - self.start_decay = rel_start_decay * t_max - self.start_b = start_b - self.end_b = end_b - - def __call__(self, t): - if t < self.start_decay: - return self.start_b - else: - rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) - return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t)) - - -class Loss: - - def __init__( - self, - module, - learned_round_module, - weight=0.01, - max_count=1000, - b_range=(20, 2), - warmup=0.2, - 
decay_start=0.0): - self.weight = weight - self.module = module - self.loss_start = max_count * warmup - self.temp_decay = LinearTempDecay( - max_count, - start_b=b_range[0], - end_b=b_range[1], - rel_start_decay=warmup + (1.0 - warmup) * decay_start) - self.iter = 0 - self.learned_round_module = learned_round_module - - def __call__(self, pred, tgt): - self.iter += 1 - - rec_loss = F.mse_loss(pred, tgt, reduction='none').sum(1).mean() - - if self.iter < self.loss_start: - b = self.temp_decay(self.iter) - round_loss = 0 - else: # 1 - |(h-0.5)*2|**b - b = self.temp_decay(self.iter) - round_vals = self.learned_round_module.p_forward() - round_loss = self.weight * (1 - ((round_vals - 0.5).abs() * 2).pow(b)).sum() - - total_loss = rec_loss + round_loss - return total_loss, rec_loss, round_loss, b - - -def find_learned_round_module(module): - for submodule in module.modules(): - if isinstance(submodule, LearnedRoundSte): - return submodule - return False - - -def insert_learned_round_quantizer(layer, learned_round_zeta=1.1, learned_round_gamma=-0.1): - if isinstance(layer, QuantWBIOL): - if not find_learned_round_module(layer): - floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale) - delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight - value = -torch.log((learned_round_zeta - learned_round_gamma) / - (delta - learned_round_gamma) - 1) - layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( - float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND, - learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID, - learned_round_gamma=learned_round_gamma, - learned_round_zeta=learned_round_zeta, - learned_round_init=value) - layer.weight_quant.init_tensor_quant(preserve_state_dict=True) - - -def split_layers(model, layers): - for module in model.children(): - if isinstance(module, QuantWBIOL): - layers.append(module) - else: - split_layers(module, layers) - - -def learned_round_iterator(layers, iters=1000): - for layer in layers: - insert_learned_round_quantizer(layer) - - for p in layer.parameters(): - p.requires_grad = False - - learned_round_module = find_learned_round_module(layer) - learned_round_module.value.requires_grad = True - layer_loss = Loss(module=layer, learned_round_module=learned_round_module, max_count=iters) - yield layer, layer_loss, learned_round_module - layer.eval() - - -def save_inp_out_data( - model, - module, - dataloader: torch.utils.data.DataLoader, - store_inp=False, - store_out=False, - keep_gpu: bool = True, - disable_quant=False): - if disable_quant: - disable_quant_class = DisableEnableQuantization() - disable_quant_class.disable_act_quantization(model, False) - disable_quant_class.disable_param_quantization(model, False) - return_quant_tensor_state = disable_return_quant_tensor(model) + input_batch = input_batch.transpose(0, self.batch_dim) + + self["inputs"].append(input_batch) + + def store_output(self, output) -> None: + if self.batch_dim is not None: + output.rename_(None) + output = output.transpose(0, self.batch_dim) + + self["output"].append(output) + + def initialize_cache(self) -> None: + self["inputs"] = [] + self["output"] = [] + + def clear_cache(self) -> None: + del self["inputs"] + del self["output"] + self["inputs"] = [] + self["output"] = [] + + def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: + if isinstance(self["inputs"], list): + self["inputs"] = torch.cat(self["inputs"], dim=self.batch_dim) + if isinstance(self["output"], list): + self["output"] = 
torch.cat(self["output"], dim=self.batch_dim)
+
+        return self["inputs"][indices], self["output"][indices]
+
+    def __len__(self):
+        return (
+            len(self["inputs"])
+            if isinstance(self["inputs"], list) else self["inputs"].shape[self.batch_dim])
+
+
+def vision_forward(model: nn.Module, inputs: Any) -> None:
     device = next(model.parameters()).device
-    data_saver = DataSaverHook(store_output=store_out)
-    handle = module.register_forward_hook(data_saver)
-    cached = [[], []]
-    with torch.no_grad():
-        for img, t in dataloader:
-            try:
-                _ = model(img.to(device))
-            except StopFwdException:
-                pass
-            if store_inp:
-                if keep_gpu:
-                    cached[0].append(data_saver.input_store.detach())
-                else:
-                    cached[0].append(data_saver.input_store.detach().cpu())
-            if store_out:
-                if keep_gpu:
-                    cached[1].append(data_saver.output_store.detach())
-                else:
-                    cached[1].append(data_saver.output_store.detach().cpu())
-    if store_inp:
-        cached[0] = torch.cat([x for x in cached[0]])
-    if store_out:
-        cached[1] = torch.cat([x for x in cached[1]])
-    handle.remove()
-    if disable_quant:
-        disable_quant_class.enable_act_quantization(model, False)
-        disable_quant_class.enable_param_quantization(model, False)
-        restore_return_quant_tensor(model, return_quant_tensor_state)
-    return cached
+    img, _ = inputs
+    img = send_to_device(img, device)
+    model(img)
+
+
+def vision_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor:
+    device = next(block.parameters()).device
+    inputs = send_to_device(inputs, device)
+    return block(inputs)
+
+
+def apply_learned_round(
+        model: nn.Module,
+        calibration_loader: DataLoader,
+        iters: int = 1000,
+        learned_round: str = "hard_sigmoid_round",
+        learned_round_loss: str = "regularised_mse",
+        block_name_attribute: str = r"layer\d+",
+        optimizer: str = "adam",
+        lr_scheduler: Optional[str] = None,
+        optimizer_lr: float = 1e-3,
+        batch_size: int = 1,
+        use_best_model: bool = False,
+        use_amp: bool = True,
+        amp_dtype: torch.dtype = torch.float16,
+        loss_scaling_factor: float = 1.,
+        learned_round_loss_kwargs: Optional[Dict] = None,
+        optimizer_kwargs: Optional[Dict] = None,
+        lr_scheduler_kwargs: Optional[Dict] = None,
+        learned_round_mode: str = "layerwise",
+) -> None:
+    # Parse strings to obtain the arguments for the optimizer
+    learned_round = parse_learned_round(learned_round)
+    learned_round_loss_class = parse_learned_round_loss_class(learned_round_loss)
+    optimizer_class = parse_optimizer_class(optimizer)
+    lr_scheduler_class = parse_lr_scheduler_class(lr_scheduler)
+
+    # Parse method to retrieve the model blocks
+    if learned_round_mode == "layerwise":
+        block_check_fn = is_layer
+    elif learned_round_mode == "blockwise":
+        block_check_fn = functools.partial(is_block, reg_exp=block_name_attribute)
+    else:
+        block_check_fn = is_layer
+        warnings.warn(
+            f"{learned_round_mode} is not a valid learned round mode.
Defaulting to layerwise.") + get_blocks_fn = functools.partial(get_blocks, block_check_fn=block_check_fn) + + lr_scheduler_kwargs = { + "start_factor": 1.0, + "end_factor": 0.0, + "verbose": False,} if lr_scheduler_kwargs is None else lr_scheduler_kwargs + learned_round_optimizer = LearnedRoundOptimizer( + learned_round=learned_round, + learned_round_loss_class=learned_round_loss_class, + optimizer_class=optimizer_class, + lr_scheduler_class=lr_scheduler_class, + optimizer_lr=optimizer_lr, + batch_size=batch_size, + iters=iters, + use_best_model=use_best_model, + use_amp=use_amp, + amp_dtype=amp_dtype, + loss_scaling_factor=loss_scaling_factor, + learned_round_loss_kwargs=learned_round_loss_kwargs, + optimizer_kwargs=optimizer_kwargs, + lr_scheduler_kwargs=lr_scheduler_kwargs) + cache = CacheVision() + learned_round_optimizer.apply_learned_round( + model=model, + model_forward=vision_forward, + block_forward=vision_block_forward, + data_loader=calibration_loader, + cache=cache, + get_blocks_fn=get_blocks_fn, + keep_gpu=True, + ) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 38ed85678..bd245b7e6 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -72,9 +72,6 @@ from brevitas_examples.common.axe import A2GPTQ from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import learned_round_iterator -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import split_layers # Every element of the Batch will have its own scale factor and zero point @@ -645,35 +642,6 @@ def apply_gpfq( gpfq.update() -def apply_learned_round_learning( - model, dataloader, optimizer_class=torch.optim.Adam, iters=1000, optimizer_lr=1e-1): - layers = [] - split_layers(model, layers) - print(f"Total Iterations per layer {iters}") - print(f"Number of layers {len(layers)}") - - for layer, layer_loss, learned_round_module in learned_round_iterator(layers, iters=iters): - optimizer = optimizer_class(learned_round_module.parameters(), lr=optimizer_lr) - _, all_fp_out = save_inp_out_data(model, layer, dataloader, store_inp=False, store_out=True, keep_gpu=True, disable_quant=True) - all_quant_inp, _ = save_inp_out_data(model, layer, dataloader, store_inp=True, store_out=True, keep_gpu=True, disable_quant=False) - max_size = len(all_fp_out) - pbar = tqdm(range(iters), desc='') - for i in pbar: - idx = torch.randint(0, max_size, (dataloader.batch_size,)) - quant_inp, fp_out = all_quant_inp[idx], all_fp_out[idx] - layer.train() - - optimizer.zero_grad() - quant_out = layer(quant_inp) - loss, rec_loss, round_loss, b = layer_loss(quant_out, fp_out) - - loss.backward() - optimizer.step() - pbar.set_description( - "loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format( - loss, rec_loss, round_loss, b)) - - def check_positive_int(*args): """ We check that every inputted value is positive, and an integer. 
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index 34bdfbc96..ff221e5ba 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -21,11 +21,11 @@
 from brevitas.export.inference import quant_inference_mode
 from brevitas.graph.quantize import preprocess_for_quantize
 from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
+from brevitas_examples.imagenet_classification.ptq.learned_round_utils import apply_learned_round
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq
-from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning
 from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate
 from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate_bn
 from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model
@@ -158,16 +158,49 @@
     default=20,
     type=int,
     help='Numbers of iterations for graph equalization (default: 20)')
+parser.add_argument(
+    '--learned-round',
+    default=None,
+    type=str,
+    choices=[None, 'linear_round', 'hard_sigmoid_round', 'sigmoid_round'],
+    help='Learned round type (default: None)')
+parser.add_argument(
+    '--learned-round-block-name',
+    type=str,
+    default=r"layer\d+",
+    help='Block name for learned round. It works only if FX is not needed (default: %(default)s)')
+parser.add_argument(
+    '--learned-round-loss',
+    default='regularised_mse',
+    type=str,
+    choices=['regularised_mse', 'mse'],
+    help='Learned round loss (default: %(default)s)')
+parser.add_argument(
+    '--learned-round-mode',
+    default='layerwise',
+    choices=['layerwise', 'blockwise'],
+    help='Learned round mode (default: %(default)s)')
 parser.add_argument(
     '--learned-round-iters',
     default=1000,
     type=int,
     help='Numbers of iterations for learned round for each layer (default: 1000)')
+parser.add_argument(
+    '--learned-round-lr-scheduler',
+    default=None,
+    type=str,
+    choices=[None, 'linear'],
+    help='Learning rate scheduler for learned round (default: None)')
 parser.add_argument(
     '--learned-round-lr',
    default=1e-3,
     type=float,
     help='Learning rate for learned round (default: 1e-3)')
+parser.add_argument(
+    '--learned-round-batch-size',
+    default=1,
+    type=int,
+    help='Batch size for learned round (default: %(default)d)')
 parser.add_argument(
     '--act-quant-percentile',
     default=99.999,
@@ -250,6 +283,11 @@
     help=
     'Split Ratio for Channel Splitting. When set to 0.0, Channel Splitting will not be applied. (default: 0.0)'
)
+parser.add_argument(
+    '--optimizer',
+    default='adam',
+    choices=['adam', 'sign_sgd'],
+    help='Optimizer to use with learnable rounding (default: %(default)s)')
 add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
 add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
 add_bool_arg(
@@ -264,7 +302,6 @@
     'gpxq-create-weight-orig',
     default=False,
     help='Maintain original weights for non-quant forward pass (default: disabled)')
-add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)')
 add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)')
 add_bool_arg(
     parser,
@@ -321,7 +358,7 @@
         f"{'gptq_' if args.gptq else ''}"
         f"{'gpfq_' if args.gpfq else ''}"
         f"{'gpxq_act_order_' if args.gpxq_act_order else ''}"
-        f"{'learned_round_' if args.learned_round else ''}"
+        f"{'learned_round_' if args.learned_round is not None else ''}"
         f"{'weight_narrow_range_' if args.weight_narrow_range else ''}"
         f"{args.bias_bit_width}bias_"
         f"{args.weight_quant_granularity}_"
@@ -344,7 +381,7 @@
         f"GPFQ: {args.gpfq} - "
         f"GPxQ Act Order: {args.gpxq_act_order} - "
         f"GPxQ Accumulator Bit Width: {args.gpxq_accumulator_bit_width} - "
-        f"Learned Round: {args.learned_round} - "
+        f"Learned Round method: {args.learned_round} - "
         f"Weight narrow range: {args.weight_narrow_range} - "
         f"Bias bit width: {args.bias_bit_width} - "
         f"Weight scale factors type: {args.weight_quant_granularity} - "
@@ -398,20 +435,21 @@
             equalize_merge_bias=args.graph_eq_merge_bias,
             merge_bn=not args.calibrate_bn)
     elif args.target_backend == 'fx' or args.target_backend == 'layerwise':
-        model = preprocess_for_quantize(
-            model,
-            equalize_iters=args.graph_eq_iterations,
-            equalize_merge_bias=args.graph_eq_merge_bias,
-            merge_bn=args.merge_bn,
-            channel_splitting_ratio=args.channel_splitting_ratio,
-            channel_splitting_split_input=args.channel_splitting_split_input)
+        if args.learned_round_mode != "blockwise":
+            model = preprocess_for_quantize(
+                model,
+                equalize_iters=args.graph_eq_iterations,
+                equalize_merge_bias=args.graph_eq_merge_bias,
+                merge_bn=args.merge_bn,
+                channel_splitting_ratio=args.channel_splitting_ratio,
+                channel_splitting_split_input=args.channel_splitting_split_input)
     else:
         raise RuntimeError(f"{args.target_backend} backend not supported.")
 
+    device = (torch.device(f"cuda:{args.gpu}") if args.gpu is not None else torch.device("cpu"))
+    model = model.to(device=device)
     # If available, use the selected GPU
     if args.gpu is not None:
-        torch.cuda.set_device(args.gpu)
-        model = model.cuda(args.gpu)
         cudnn.benchmark = False
 
     if args.act_equalization is not None:
@@ -477,11 +515,19 @@
 
     if args.learned_round:
         print("Applying Learned Round:")
-        apply_learned_round_learning(
-            quant_model,
-            calib_loader,
+        apply_learned_round(
+            model=quant_model,
+            calibration_loader=calib_loader,
             iters=args.learned_round_iters,
-            optimizer_lr=args.learned_round_lr)
+            learned_round=args.learned_round,
+            learned_round_loss=args.learned_round_loss,
+            block_name_attribute=args.learned_round_block_name,
+            optimizer=args.optimizer,
+            lr_scheduler=args.learned_round_lr_scheduler,
+            optimizer_lr=args.learned_round_lr,
+            batch_size=args.learned_round_batch_size,
+            learned_round_mode=args.learned_round_mode,
+        )
 
     if args.calibrate_bn:
         print("Calibrate BN:")
diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py b/src/brevitas_examples/llm/llm_quant/gpxq.py
index
5e61306d4..26e984cfd 100644 --- a/src/brevitas_examples/llm/llm_quant/gpxq.py +++ b/src/brevitas_examples/llm/llm_quant/gpxq.py @@ -15,8 +15,8 @@ from brevitas.graph.gpfq import GPFQv2 from brevitas.graph.gptq import GPTQ from brevitas.graph.gptq import gptq_mode -from brevitas.graph.gpxq import StopFwdException from brevitas.utils.python_utils import recurse_getattr +from brevitas.utils.torch_utils import StopFwdException from brevitas_examples.common.axe import A2GPFQ from brevitas_examples.common.axe import A2GPTQ diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py new file mode 100644 index 000000000..fa3ca0048 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -0,0 +1,170 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import functools +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +from accelerate.utils.operations import send_to_device +import torch +from torch import nn +from torch.utils.data.dataloader import DataLoader + +from brevitas.utils.python_utils import recurse_getattr +from brevitas_examples.common.learned_round.learned_round_optimizer import Cache +from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer +from brevitas_examples.common.learned_round.learned_round_parser import parse_learned_round +from brevitas_examples.common.learned_round.learned_round_parser import \ + parse_learned_round_loss_class +from brevitas_examples.common.learned_round.learned_round_parser import parse_lr_scheduler_class +from brevitas_examples.common.learned_round.learned_round_parser import parse_optimizer_class + + +class CacheLLM(Cache, dict): + + def __init__(self) -> None: + super().__init__() + + def store_inputs(self, args, kwargs) -> None: + self["args"].append(args) + self["kwargs"].append(kwargs) + + def store_output(self, output) -> None: + if isinstance(output, (tuple, list)): + output = output[0] + self["output"].append(output) + + def initialize_cache(self) -> None: + self["args"] = [] + self["kwargs"] = [] + self["output"] = [] + + def clear_cache(self) -> None: + del self["args"] + del self["kwargs"] + del self["output"] + self["args"] = [] + self["kwargs"] = [] + self["output"] = [] + + def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: + cache_args, cache_kwargs, cache_outs = self["args"], self["kwargs"], self["output"] + # Positional arguments + args = [cache_args[i] for i in indices] + args = tuple(torch.cat(arg_tensor, dim=0) for arg_tensor in zip(*args)) + # Keyword arguments + kwargs_dict = [cache_kwargs[i] for i in indices] + kwargs = {} + for curr_dict in kwargs_dict: + for key, value in curr_dict.items(): + if isinstance(value, torch.Tensor): + if key not in kwargs: + kwargs[key] = [] + kwargs[key].append(value) + else: + if key not in kwargs: + kwargs[key] = value + for key, value in kwargs.items(): + if isinstance(value, list) and len(value) > 0: + kwargs[key] = torch.cat(kwargs[key], dim=0) + # FP outputs + outs = torch.cat([cache_outs[i] for i in indices], dim=0) + return (args, kwargs), outs + + def __len__(self): + return len(self["args"]) + + +def llm_learned_round_prepare_fn(model: nn.Module) -> None: + llm_cache_state = model.config.use_cache + model.config.use_cache = False + return llm_cache_state + + +def llm_learned_round_finish_fn(model: nn.Module, llm_cache_state: Dict) -> None: 
+ model.config.use_cache = llm_cache_state + + +def llm_forward(model: nn.Module, inputs: Any) -> None: + device = next(model.parameters()).device + if device != torch.device("meta"): + inputs = send_to_device(inputs, device) + model(**inputs) + + +def llm_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor: + device = next(block.parameters()).device + args, kwargs = inputs + args = send_to_device(args, device) + kwargs = send_to_device(kwargs, device) + out = block(*args, **kwargs) + if isinstance(out, tuple): + out = out[0] + return out + + +def get_blocks(model: nn.Module, block_name_attribute: str) -> List[nn.Module]: + return recurse_getattr(model, block_name_attribute) + + +def apply_learned_round( + model: nn.Module, + calibration_loader: DataLoader, + iters: int = 200, + learned_round: str = "linear_round", + learned_round_loss: str = "mse", + block_name_attribute: str = "layers", + optimizer: str = "sign_sgd", + optimizer_lr: float = 5e-3, + optimizer_scale_lr: float = 5e-3, + batch_size: int = 8, + learn_scale: bool = False, + use_best_model: bool = True, + use_amp: bool = True, + amp_dtype: torch.dtype = torch.float16, + loss_scaling_factor: float = 1000, + lr_scheduler: Optional[str] = "linear", + optimizer_kwargs: Optional[Dict] = None, + lr_scheduler_kwargs: Optional[Dict] = None, + learned_round_loss_kwargs: Optional[Dict] = None, +) -> None: + # Parse strings to obtain the arguments for the optimizer + learned_round = parse_learned_round(learned_round) + learned_round_loss_class = parse_learned_round_loss_class(learned_round_loss) + optimizer_class = parse_optimizer_class(optimizer) + lr_scheduler_class = parse_lr_scheduler_class(lr_scheduler) + + llm_block_check_fn = functools.partial(get_blocks, block_name_attribute=block_name_attribute) + + lr_scheduler_kwargs = { + "start_factor": 1.0, + "end_factor": 0.0, + "verbose": False,} if lr_scheduler_kwargs is None else lr_scheduler_kwargs + learned_round_optimizer = LearnedRoundOptimizer( + learned_round=learned_round, + learned_round_loss_class=learned_round_loss_class, + optimizer_class=optimizer_class, + lr_scheduler_class=lr_scheduler_class, + optimizer_lr=optimizer_lr, + optimizer_scale_lr=optimizer_scale_lr, + batch_size=batch_size, + iters=iters, + learn_scale=learn_scale, + use_best_model=use_best_model, + use_amp=use_amp, + amp_dtype=amp_dtype, + loss_scaling_factor=loss_scaling_factor, + learned_round_loss_kwargs=learned_round_loss_kwargs, + optimizer_kwargs=optimizer_kwargs, + lr_scheduler_kwargs=lr_scheduler_kwargs) + cache = CacheLLM() + learned_round_optimizer.apply_learned_round( + model=model, + model_forward=llm_forward, + block_forward=llm_block_forward, + data_loader=calibration_loader, + cache=cache, + get_blocks_fn=llm_block_check_fn, + model_prepare_fn=llm_learned_round_prepare_fn, + model_finish_fn=llm_learned_round_finish_fn, + keep_gpu=False, + ) diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py index 2a71546e4..ee2f0b3e6 100644 --- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py +++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py @@ -1,18 +1,24 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + import warnings +from packaging import version import torch +import transformers from transformers.models.opt.modeling_opt import OPTAttention -from transformers.models.opt.modeling_opt import OPTSdpaAttention from brevitas.graph import ModuleToModuleByClass from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention QUANTIZABLE_MHA_MAP = { OPTAttention: (QuantizableOPTAttention, { - 'batch_first': True}), - OPTSdpaAttention: (QuantizableOPTAttention, { 'batch_first': True}),} +if version.parse(transformers.__version__) >= version.parse('4.46.0'): + from transformers.models.opt.modeling_opt import OPTSdpaAttention + QUANTIZABLE_MHA_MAP[OPTSdpaAttention] = (QuantizableOPTAttention, {'batch_first': True}) + def replace_mha_with_quantizable_layers(model, dtype): rewriters = [] diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 8b3ae4888..46d974d23 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -34,6 +34,7 @@ from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq from brevitas_examples.llm.llm_quant.gpxq import apply_gptq +from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_to_rmsnorm from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch @@ -367,6 +368,22 @@ def main(args): with torch.no_grad(): model(**calibration_loader[0]) + if args.learned_round: + print("Applying learned round...") + remove_hooks(model) + apply_learned_round( + model, + calibration_loader, + iters=args.learned_round_iters, + block_name_attribute=args.gpxq_block_name, + optimizer_lr=args.learned_round_lr, + optimizer_scale_lr=args.learned_round_scale_lr, + learn_scale=args.learned_round_scale, + ) + print("Learned round applied.") + + model = offload_model(model) + if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader) @@ -551,6 +568,25 @@ def parse_args(args): type=int, default=64, help='Group size for per_group input quantization. Default: 64.') + parser.add_argument( + '--learned-round-lr', + type=float, + default=5e-3, + help='Learning rate for learned round parameter optimization. Default: %(default)s') + parser.add_argument( + '--learned-round-scale-lr', + type=float, + default=5e-3, + help='Learning rate for scale optimization during round learning. Default: %(default)s') + parser.add_argument( + '--learned-round-iters', + type=int, + default=200, + help='Number of iterations for learned round. Default: 200.') + parser.add_argument( + '--learned-round-scale', + action='store_true', + help='Learned scale factor together with round.') parser.add_argument( '--quantize-input-zero-point', action='store_true', help='Quantize input zero-point.') parser.add_argument( @@ -658,6 +694,11 @@ def parse_args(args): help= "Whether to merge the dataset sequences in case they are shorter than the requested number of samples per sequence. This is useful in case you would like to quantize or evaluate on long sequences (default: %(default)s).", ) + parser.add_argument( + '--learned-round', + default=None, + choices=[None, 'linear_round'], + help='Whether to use learned round. 
If `None`, RTN is used (default: %(default)s)')
     return parser.parse_args(args)
diff --git a/tests/brevitas/core/test_float_to_int.py b/tests/brevitas/core/test_float_to_int.py
index ef635c279..2701e78a3 100644
--- a/tests/brevitas/core/test_float_to_int.py
+++ b/tests/brevitas/core/test_float_to_int.py
@@ -1,15 +1,19 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from hypothesis import given
+from hypothesis.strategies import floats
 import pytest
 import pytest_cases
 import torch
 
 from brevitas import config
 from brevitas.core.function_wrapper.learned_round import LearnedRoundHardSigmoid
+from brevitas.core.function_wrapper.learned_round import LearnedRoundIdentity
 from brevitas.core.function_wrapper.learned_round import LearnedRoundSigmoid
 from brevitas.core.function_wrapper.learned_round import LearnedRoundSte
 import brevitas.nn as qnn
+from tests.brevitas.hyp_helper import two_float_tensor_random_shape_st
 
 OUT_CH = 16
 IN_CH = 8
@@ -18,11 +22,78 @@
     LearnedRoundSigmoid(),  # Sigmoid Implementation
     LearnedRoundSigmoid(learned_round_temperature=2.),  # Sigmoid + Temperature
     LearnedRoundHardSigmoid(),  # Hard Sigmoid
-]
+    LearnedRoundIdentity(),]
 
 
 class TestLearnedRound():
 
+    def instantiate_learnedround_float_to_int_impl(self, impl, weights, value):
+        impl = LearnedRoundSte(impl, torch.full(weights.shape, 0.))
+        if isinstance(impl.learned_round_impl, LearnedRoundIdentity):
+            min_value, max_value = torch.min(value), torch.max(value)
+            # Prevent division by zero when all the elements of the tensor are the same
+            if max_value - min_value < 1e-8:
+                # Make sure that the division is safe
+                if torch.abs(max_value) > 1e-8:
+                    value = value / max_value - 0.5
+            else:
+                value = (value - min_value) / (max_value - min_value) - 0.5
+        # Simulate learned round
+        impl.value.data = value
+        return impl
+
+    # NOTE: The min/max values are set to the exactly representable float32
+    # closest to sys.maxsize for a given machine.
+    @pytest_cases.parametrize('impl', LEARNEDROUND_IMPL)
+    @pytest_cases.parametrize('training', [True, False])
+    @given(
+        weights_value=two_float_tensor_random_shape_st(
+            min_val=-9.223372036854776e+18, max_val=9.223372036854776e+18))
+    def test_learnedround(self, impl, training, weights_value):
+        # Unpack tuple of hypothesis-generated tensors
+        weights, value = weights_value
+        # Instantiate LearnedRoundSte using the factory method
+        impl = self.instantiate_learnedround_float_to_int_impl(impl, weights, value)
+        impl.train(training)
+        print(impl.value)
+        out = impl(weights)
+        # The FP values and their quantized counterparts must differ by at most +/- 1
+        assert torch.all(torch.abs(out - weights) <= 1)
+        if not isinstance(impl.learned_round_impl, LearnedRoundIdentity):
+            if training:
+                # Soft quantization. All values are at most distant +/- 1 from the nearest integer
+                assert torch.all(torch.abs(out - torch.round(out)) <= 1)
+            else:
+                # Hard quantization.
All values are integers + assert torch.allclose(out, torch.round(out)) + else: + # All values should be integers for LearnedRoundIdentity + assert torch.allclose(out, torch.round(out)) + + @given( + learned_round_zeta=floats(min_value=0.0, max_value=3.0), + learned_round_gamma=floats(min_value=-3.0, max_value=-0.05), + value=floats(min_value=-5.0, max_value=5.0), + ) + def test_learnedround_float_to_int_impl_hard_sigmoid( + self, learned_round_zeta, learned_round_gamma, value): + value = torch.tensor([value], dtype=torch.float32) + weight = torch.zeros_like(value) + # Initialise learned round script module + learned_round_hard_sigmoid = LearnedRoundHardSigmoid( + learned_round_zeta=learned_round_zeta, + learned_round_gamma=learned_round_gamma, + ) + learned_round_hard_sigmoid.train(False) + value_eval = learned_round_hard_sigmoid(value) + learned_round_hard_sigmoid.train(True) + value_train = learned_round_hard_sigmoid(value) + + out_eval = weight + value_eval + out_train = weight + (value_train > 0.5) + + assert torch.allclose(out_eval, out_train) + @pytest_cases.fixture() @pytest_cases.parametrize('impl', LEARNEDROUND_IMPL) def learnedround_float_to_int_impl(self, impl): @@ -30,29 +101,26 @@ def learnedround_float_to_int_impl(self, impl): impl = LearnedRoundSte(impl, torch.full(sample_weight.shape, 0.)) # Simulate learned parameter - impl.value.data = torch.randn_like(impl.value) - return impl, sample_weight - - @pytest_cases.parametrize('training', [True, False]) - def test_learnedround(self, learnedround_float_to_int_impl, training): - impl, sample_weight = learnedround_float_to_int_impl - impl.train(training) - - out = impl(sample_weight) - if training: - # Soft quantization. All values are at most distant +/- 1 from the nearest integer - assert torch.all(torch.abs(out - torch.round(out)) < 1) - else: - # Hard quantization. All values are integers - assert torch.allclose(out, torch.round(out)) + value = torch.randn_like(impl.value) + impl.value.data = value + return impl, sample_weight, value def test_learnedround_load_dict(self, learnedround_float_to_int_impl): config.IGNORE_MISSING_KEYS = True - impl, _ = learnedround_float_to_int_impl + impl, _, _ = learnedround_float_to_int_impl quant_conv = qnn.QuantConv2d(IN_CH, OUT_CH, KERNEL_SIZE, weight_float_to_int_impl=impl) fp_conv = torch.nn.Conv2d(IN_CH, OUT_CH, KERNEL_SIZE) try: quant_conv.load_state_dict(fp_conv.state_dict()) except RuntimeError as e: pytest.fail(str(e)) + + def test_learnedround_state_dict(self, learnedround_float_to_int_impl): + impl, _, value = learnedround_float_to_int_impl + state_dict = impl.state_dict() + + # Verify that the state dict contains the entry corresponding to the + # learnable round parameter. + assert len(state_dict.keys()) == 1 + assert torch.allclose(state_dict["value"], value) diff --git a/tests/brevitas/hyp_helper.py b/tests/brevitas/hyp_helper.py index 1a9157214..c3fd6a82a 100644 --- a/tests/brevitas/hyp_helper.py +++ b/tests/brevitas/hyp_helper.py @@ -174,14 +174,18 @@ def float_tensor_random_size_st( @st.composite def two_float_tensor_random_shape_st( - draw, min_dims=1, max_dims=4, max_size=3, width=FP32_BIT_WIDTH): + draw, min_dims=1, max_dims=4, max_size=3, min_val=None, max_val=None, width=FP32_BIT_WIDTH): """ Generate a tuple of float tensors of the same random shape. 
""" shape = draw(random_tensor_shape_st(min_dims, max_dims, max_size)) size = reduce(mul, shape, 1) - float_list1 = draw(st.lists(float_st(width=width), min_size=size, max_size=size)) - float_list2 = draw(st.lists(float_st(width=width), min_size=size, max_size=size)) + float_list1 = draw( + st.lists( + float_st(min_val=min_val, max_val=max_val, width=width), min_size=size, max_size=size)) + float_list2 = draw( + st.lists( + float_st(min_val=min_val, max_val=max_val, width=width), min_size=size, max_size=size)) tensor1 = torch.tensor(float_list1).view(shape) tensor2 = torch.tensor(float_list2).view(shape) return tensor1, tensor2 diff --git a/tests/brevitas/optim/test_sign_sgd.py b/tests/brevitas/optim/test_sign_sgd.py new file mode 100644 index 000000000..2aca86c82 --- /dev/null +++ b/tests/brevitas/optim/test_sign_sgd.py @@ -0,0 +1,179 @@ +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU, + NEC Laboratories America and IDIAP Research Institute nor the names + of its contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. 
+""" + +from copy import deepcopy +from itertools import product + +from packaging.version import parse +import pytest +import pytest_cases +import torch +from torch.nn import Parameter +from torch.optim.lr_scheduler import LinearLR + +from brevitas import torch_version +from tests.conftest import SEED +from tests.marker import requires_pt_ge + +torch.manual_seed(SEED) + +REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]]) +REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010], + [1.4573, -0.9074, -0.2708]]) +REFERENCE_WEIGHTS_GRAD = torch.tensor([[1.0023, 0.000, 1.4604], [-0.2918, -1.8218, -0.7010], + [1.4573, -0.9074, -0.2708]]) +REFERENCE_WEIGHTS_SIGN_GRAD = torch.tensor([[1.0000, 0.0000, 1.0000], [-1.0000, -1.0000, -1.0000], + [1.0000, -1.0000, -1.0000]]) + +OPTIMIZER_KWARGS = [{}, {"maximize": True}, {"lr": 1e-2}, {"lr": torch.tensor(0.001)}] +LR_SCHEDULER_ARGS = [ + None, + (LinearLR, { + "start_factor": 1.0, "end_factor": 0.0, "total_iters": 20}),] +DEVICES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] +DTYPES = [torch.float16, torch.float32] + +device_dtype_parametrize = pytest_cases.parametrize("device, dtype", list(product(DEVICES, DTYPES))) + + +class TestOptimSignSGD: + + @device_dtype_parametrize + @pytest_cases.parametrize("lr", [0.1]) + @requires_pt_ge('2.1') # TODO: revisit this + def test_sign_sgd_single_update(self, device, dtype, lr): + from brevitas.optim.sign_sgd import SignSGD + + # Initialize weights and grads + weights = Parameter(REFERENCE_WEIGHTS.to(device=device, dtype=dtype)) + # Initialize tensors to compute expected result + initial_weights = REFERENCE_WEIGHTS.to(device=device, dtype=dtype, copy=True) + weight_grad = REFERENCE_WEIGHTS_GRAD.to(device=device, dtype=dtype) + weight_sign_grad = REFERENCE_WEIGHTS_SIGN_GRAD.to(device=device, dtype=dtype) + + optimizer = SignSGD([weights], lr=lr) + + # Perform a SignSGD update + optimizer.zero_grad() + weights.grad = weight_grad + optimizer.step() + + assert torch.allclose(weights, initial_weights - lr * weight_sign_grad) + + @device_dtype_parametrize + @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS) + @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) + @requires_pt_ge('2.1') + def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args): + from brevitas.optim.sign_sgd import SignSGD + + # PyTorch version previous to 2.3.1. might no have mv (addmv_impl_cpu) implemented for Half + if dtype == torch.float16 and device == "cpu" and torch_version < parse('2.3.1'): + pytest.xfail( + "PyTorch versions previous to 2.3.1. 
might no have mv (addmv_impl_cpu) implemented for Half" + ) + + optim_cls = SignSGD + weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) + bias = Parameter(torch.randn((10), device=device, dtype=dtype)) + input = torch.randn(5, device=device, dtype=dtype) + + optimizer = optim_cls([weight, bias], **deepcopy(optimizer_kwargs)) + scheduler = None if lr_scheduler_args is None else lr_scheduler_args[0]( + optimizer, **lr_scheduler_args[1]) + + def closure(): + optimizer.zero_grad() + loss = (weight.mv(input) + bias).pow(2).sum() + loss.backward() + return loss + + initial_value = closure().item() + for _ in range(20): + closure() + optimizer.step() + print(bias) + if scheduler is not None: + scheduler.step() + + if optimizer_kwargs.get("maximize", False): + assert closure().item() > initial_value + else: + assert closure().item() < initial_value + + @pytest.mark.skipif( + torch.cuda.device_count() <= 1, reason="At least two GPUs are required for this test.") + @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS) + @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) + @pytest_cases.parametrize("dtype", [torch.float16, torch.float32]) + @requires_pt_ge('2.1') + def test_forloop_goes_right_direction_multigpu( + self, dtype, optimizer_kwargs, lr_scheduler_args): + from brevitas.optim.sign_sgd import SignSGD + optim_cls = SignSGD + # Learnable parameters + weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype)) + bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype)) + input = torch.randn(5, device="cuda:0", dtype=dtype) + + optimizer = optim_cls([weight, bias], **deepcopy(optimizer_kwargs)) + scheduler = None if lr_scheduler_args is None else lr_scheduler_args[0]( + optimizer, **lr_scheduler_args[1]) + + def closure(): + optimizer.zero_grad() + loss = (weight.mv(input).cuda(1) + bias).pow(2).sum() + loss.backward() + return loss + + initial_value = closure().item() + for _ in range(20): + closure() + optimizer.step() + + if scheduler is not None: + scheduler.step() + + if optimizer_kwargs.get("maximize", False): + assert closure().item() > initial_value + else: + assert closure().item() < initial_value diff --git a/tests/brevitas_examples/test_learned_round_utils.py b/tests/brevitas_examples/test_learned_round_utils.py new file mode 100644 index 000000000..751d7790d --- /dev/null +++ b/tests/brevitas_examples/test_learned_round_utils.py @@ -0,0 +1,335 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Union + +from accelerate.utils.operations import send_to_device +from hypothesis import given +import pytest +import pytest_cases +from pytest_cases import fixture +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.utils.data import Dataset + +from brevitas import config +from brevitas.core.function_wrapper.learned_round import LearnedRoundSte +from brevitas.inject.enum import LearnedRoundImplType +import brevitas.nn as qnn +from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL +from brevitas.quant_tensor.base_quant_tensor import QuantTensor +from brevitas_examples.common.learned_round.learned_round_method import LearnedRound +from brevitas_examples.common.learned_round.learned_round_optimizer import get_blocks +from brevitas_examples.common.learned_round.learned_round_optimizer import save_inputs_output + +config.IGNORE_MISSING_KEYS = True + + +class QuantBlock(nn.Module): + + def __init__(self, in_features: int, hidden_dim: int, out_features: int) -> None: + super().__init__() + self.layer1 = qnn.QuantLinear(in_features=in_features, out_features=hidden_dim) + self.layer2 = qnn.QuantLinear(in_features=hidden_dim, out_features=out_features) + self.relu = qnn.QuantReLU(return_quant_tensor=True) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + out = self.layer1(x) + out = self.relu(out) + out = self.layer2(out) + return self.relu(out) + + +class TestQuantModel(nn.Module): + + def __init__(self, in_features: int, out_features: int, hidden_dim: int) -> None: + super().__init__() + self.in_proj_mlp = QuantBlock( + in_features=in_features, hidden_dim=hidden_dim, out_features=hidden_dim) + self.hidden_mlp = QuantBlock( + in_features=hidden_dim, hidden_dim=hidden_dim, out_features=hidden_dim) + self.out_proj_mlp = QuantBlock( + in_features=hidden_dim, hidden_dim=hidden_dim, out_features=out_features) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + out = self.in_proj_mlp(x, block1_kwarg=0., **kwargs) + out = self.hidden_mlp(out, block2_kwarg=0., **kwargs) + return self.out_proj_mlp(out, block3_kwarg=0., **kwargs) + + +class Block(nn.Module): + + def __init__(self, in_features: int, hidden_dim: int, out_features: int) -> None: + super().__init__() + self.layer1 = nn.Linear(in_features=in_features, out_features=hidden_dim) + self.layer2 = nn.Linear(in_features=hidden_dim, out_features=out_features) + self.relu = F.relu + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + out = self.layer1(x) + out = self.relu(out) + out = self.layer2(out) + return self.relu(out) + + +class TestModel(nn.Module): + + def __init__(self, in_features: int, out_features: int, hidden_dim: int) -> None: + super().__init__() + self.in_proj_mlp = Block( + in_features=in_features, hidden_dim=hidden_dim, out_features=hidden_dim) + self.hidden_mlp = Block( + in_features=hidden_dim, hidden_dim=hidden_dim, out_features=hidden_dim) + self.out_proj_mlp = Block( + in_features=hidden_dim, hidden_dim=hidden_dim, out_features=out_features) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + out = self.in_proj_mlp(x, block1_kwarg=0., **kwargs) + out = self.hidden_mlp(out, block2_kwarg=0., **kwargs) + return self.out_proj_mlp(out, block3_kwarg=0., **kwargs) + + +class TestDataset(Dataset): + + def __init__(self): + self.data = [ + { + "x": torch.tensor([1.0, 2.0]), + "tensor": torch.tensor([0.0]), + 
"bool": True, + "float": 0.0,},] + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +class TestCache: + + def __init__(self) -> None: + self.args = [] + self.kwargs = [] + self.output = [] + + def __len__(self) -> int: + return len(self.args) + + def store_inputs(self, args: Any, kwargs: Any) -> None: + self.args.append(args) + self.kwargs.append(kwargs) + + def store_output(self, output: Any) -> None: + self.output.append(output) + + def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: + pass + + def initialize_cache(self) -> None: + pass + + def clear_cache(self) -> None: + pass + + def reset_cache(self) -> None: + pass + + def cache_to_dataset(self) -> Dataset: + pass + + def collate_fn(self, batch: Any) -> Any: + pass + + +class TestLearnedRound: + + @fixture + def quant_model(): + return TestQuantModel(in_features=2, out_features=1, hidden_dim=4) + + @fixture + def model(): + return TestModel(in_features=2, out_features=1, hidden_dim=4) + + @fixture + def data_loader(): + return DataLoader(TestDataset(), batch_size=1, shuffle=False) + + def test_get_blocks(self, quant_model: nn.Module): + + def _is_block(module: nn.Module, module_name: str) -> bool: + return module_name in ["hidden_mlp"] + + expected_blocks = [quant_model.hidden_mlp] + blocks = get_blocks(quant_model, _is_block) + + assert expected_blocks == blocks + + def test_get_layers(self, quant_model: nn.Module): + + def _is_layer(module: nn.Module, module_name: str) -> bool: + return isinstance(module, QuantWBIOL) + + expected_layers = [ + quant_model.in_proj_mlp.layer1, + quant_model.in_proj_mlp.layer2, + quant_model.hidden_mlp.layer1, + quant_model.hidden_mlp.layer2, + quant_model.out_proj_mlp.layer1, + quant_model.out_proj_mlp.layer2] + layers = get_blocks(quant_model, _is_layer) + + assert expected_layers == layers + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") + @pytest.mark.parametrize("store_inputs", [True, False]) + @pytest.mark.parametrize("store_output", [True, False]) + @pytest.mark.parametrize("keep_gpu", [True, False]) + @pytest.mark.parametrize("disable_quant", [True, False]) + def test_save_inp_out_data( + self, model, quant_model, data_loader, store_inputs, store_output, keep_gpu, + disable_quant): + + def _compare_tensors(cache_tensor, gt_tensor, disable_quant, keep_gpu): + # The elements should be of the same type + assert isinstance(cache_tensor, torch.Tensor) if disable_quant else isinstance( + cache_tensor, QuantTensor) + # If they are QuantTensors extract their data + if isinstance(cache_tensor, QuantTensor): + cache_tensor, gt_tensor = cache_tensor.value, gt_tensor.value + # Verify that the tensor is in GPU if keep_gpu=True + assert keep_gpu == cache_tensor.is_cuda + # Make sure tensors are in the same device before comparison + cache_tensor = cache_tensor.cpu() + gt_tensor = gt_tensor.cpu() + # Check that their contents match + assert torch.allclose(cache_tensor, gt_tensor) + # Return True to be used in the helper + return True + + def model_forward(model, inputs): + device = next(model.parameters()).device + inputs = send_to_device(inputs, device) + model(**inputs) + + # Make sure that the quant and FP models share the same weights + quant_model.load_state_dict(model.state_dict()) + + # Prepare models + model.eval() + model = model.cuda() + quant_model.eval() + quant_model = quant_model.cuda() + + device = next(quant_model.parameters()).device + + # Compute ground-truth inputs/outputs + 
fp_args, fp_kwargs, fp_outs = [], [], []
+        quant_args, quant_kwargs, quant_outs = [], [], []
+        # Compute ground-truth inputs and outputs
+        with torch.no_grad():
+            for inputs in data_loader:
+                inputs = send_to_device(inputs, device)
+                kwargs = {k: v for k, v in inputs.items() if k != "x"}
+                # Compute quant inputs to the module
+                quant_arg = quant_model.in_proj_mlp(**inputs)
+                quant_kwarg = {"block2_kwarg": 0.0, **kwargs}
+                # Compute quant outputs of the module
+                quant_out = quant_model.hidden_mlp(quant_arg, **quant_kwarg)
+
+                quant_args.append((quant_arg,))
+                quant_kwargs.append(quant_kwarg)
+                quant_outs.append(quant_out)
+
+                # Compute FP inputs to the module
+                fp_arg = model.in_proj_mlp(**inputs)
+                fp_kwarg = {"block2_kwarg": 0.0, **kwargs}
+                # Compute FP outputs of the module
+                fp_out = model.hidden_mlp(fp_arg, **fp_kwarg)
+
+                fp_args.append((fp_arg,))
+                fp_kwargs.append(fp_kwarg)
+                fp_outs.append(fp_out)
+
+        # Prepare to capture inputs/outputs through the test cache
+        cache = TestCache()
+
+        # Retrieve module from quant_model
+        module = quant_model.hidden_mlp
+
+        # Make call to save inputs/outputs
+        save_inputs_output(
+            model=quant_model,
+            model_forward=model_forward,
+            module=module,
+            dataloader=data_loader,
+            cache=cache,
+            store_inputs=store_inputs,
+            store_output=store_output,
+            keep_gpu=keep_gpu,
+            disable_quant=disable_quant,
+        )
+
+        # Verify that the lengths of the lists match
+        if store_inputs:
+            assert len(cache.args) == len(fp_args) and len(cache.kwargs) == len(fp_kwargs)
+        else:
+            assert len(cache.args) == 0 and len(cache.kwargs) == 0
+
+        if store_output:
+            assert len(cache.output) == len(fp_outs)
+        else:
+            assert len(cache.output) == 0
+
+        # Verify that the arguments are the same
+        for cache_arg, gt_arg in zip(cache.args, fp_args if disable_quant else quant_args):
+            _compare_tensors(cache_arg[0], gt_arg[0], disable_quant, keep_gpu)
+
+        for cache_kwarg, gt_kwarg in zip(cache.kwargs, fp_kwargs if disable_quant else quant_kwargs):
+            # Compare the contents within each dictionary
+            same_contents = all(
+                torch.allclose(cache_kwarg.get(gt_kwarg_k, None).cpu(), gt_kwarg_v.cpu())
+                if isinstance(gt_kwarg_v, torch.Tensor) else
+                cache_kwarg.get(gt_kwarg_k, None) == gt_kwarg_v
+                for gt_kwarg_k, gt_kwarg_v in gt_kwarg.items())
+            # Verify that the dictionaries have the same keys and content
+            assert set(cache_kwarg.keys()) == set(gt_kwarg.keys()) and same_contents
+
+        # Verify that the outputs are the same
+        for cache_output, gt_output in zip(cache.output, fp_outs if disable_quant else quant_outs):
+            _compare_tensors(cache_output, gt_output, disable_quant, keep_gpu)
+
+    @pytest.mark.parametrize(
+        "learned_round",
+        [
+            LearnedRound(learned_round_impl_type=LearnedRoundImplType.IDENTITY),
+            LearnedRound(learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID)])
+    def test_insert_learned_round_quantizers(self, quant_model, learned_round):
+        block = quant_model.in_proj_mlp
+        learned_round.insert_learned_round_quantizers(block)
+
+        for module in block.modules():
+            if hasattr(module, "weight_quant"):
+                assert module.weight_quant.rounding_mode == "LEARNED_ROUND"
+                assert isinstance(
+                    module.weight_quant.tensor_quant.int_quant.float_to_int_impl, LearnedRoundSte)
+
+    @pytest.mark.parametrize(
+        "learned_round",
+        [
+            LearnedRound(learned_round_impl_type=LearnedRoundImplType.IDENTITY),
+            LearnedRound(learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID)])
+    @pytest.mark.parametrize(
+        "block_strs, num_round_modules", [([], 0), (["hidden_mlp"], 2),
+                                          (["in_proj_mlp", "out_proj_mlp"], 4)])
+    def
test_return_learned_round_quantizers( + self, quant_model, learned_round, block_strs, num_round_modules): + # Inject quantizers in quant model + for block_str in block_strs: + block = getattr(quant_model, block_str) + learned_round.insert_learned_round_quantizers(block) + learned_round_modules = learned_round.return_learned_round_quantizers(quant_model) + assert len(learned_round_modules) == num_round_modules diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index e59973b95..60dd33ac2 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -525,3 +525,51 @@ def test_small_models_torch_export(caplog, torch_export_args): filepath = args.export_prefix + ".pt" torchscript_model = torch.jit.load(filepath) os.remove(filepath) + + +@pytest_cases.fixture( + ids=[ + "llama", + "mistral",], + params=[ + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "learned_round": "linear_round", + "learned_round_iters": 1, + "gpxq_block_name": "model.layers", + "float_ppl": 33238.8984375, + "quant_ppl": 33252.21484375}, + { + "model": "hf-internal-testing/tiny-random-MistralForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "learned_round": "linear_round", + "learned_round_iters": 1, + "gpxq_block_name": "model.layers", + "float_ppl": 31275.958984375, + "quant_ppl": 31337.4921875},]) +def learned_round_ppl_args_and_ppl(default_run_args, request): + args = default_run_args + run_dict = request.param + float_ppl = run_dict["float_ppl"] + quant_ppl = run_dict["quant_ppl"] + del run_dict["float_ppl"] + del run_dict["quant_ppl"] + args.update(**run_dict) + yield args, float_ppl, quant_ppl + + +@pytest.mark.llm +@requires_pt_ge('2.2') +def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): + caplog.set_level(logging.INFO) + args, exp_float_ppl, exp_quant_ppl = learned_round_ppl_args_and_ppl + float_ppl, quant_ppl, model = validate_args_and_run_main(args) + float_ppl = float_ppl.detach().cpu().numpy() + quant_ppl = quant_ppl.detach().cpu().numpy() + assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
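For reference, the update rule the SignSGD tests above verify is w <- w - lr * sign(grad); a standalone sketch of a single step under that rule (not the brevitas.optim.sign_sgd implementation itself):

    import torch

    def sign_sgd_step(params, lr: float) -> None:
        # One SignSGD update: only the sign of each gradient is used, so every
        # coordinate moves by exactly lr, as test_sign_sgd_single_update checks.
        with torch.no_grad():
            for p in params:
                if p.grad is not None:
                    p.add_(torch.sign(p.grad), alpha=-lr)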