From dfb85791a5a10f4647fc7414e88807623a1f6323 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Wed, 16 Oct 2024 15:01:38 +0100 Subject: [PATCH 01/48] AutoRound standalone implementation --- .../core/function_wrapper/auto_round.py | 48 ++ src/brevitas/optim/sign_sgd.py | 448 ++++++++++++++++++ .../imagenet_classification/ptq/ptq_common.py | 35 ++ tests/brevitas/core/test_float_to_int.py | 42 ++ tests/brevitas/optim/test_sign_sgd.py | 271 +++++++++++ 5 files changed, 844 insertions(+) create mode 100644 src/brevitas/core/function_wrapper/auto_round.py create mode 100644 src/brevitas/optim/sign_sgd.py create mode 100644 tests/brevitas/optim/test_sign_sgd.py diff --git a/src/brevitas/core/function_wrapper/auto_round.py b/src/brevitas/core/function_wrapper/auto_round.py new file mode 100644 index 000000000..7e7f40d6b --- /dev/null +++ b/src/brevitas/core/function_wrapper/auto_round.py @@ -0,0 +1,48 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +""" +Implementation of AutoRound +""" + +from typing import Optional + +import torch + +import brevitas +from brevitas import config +from brevitas.core.utils import SliceTensor +from brevitas.function.ops_ste import round_ste + + +class AutoRoundSte(brevitas.jit.ScriptModule): + """ + This Module implements AutoRound representation, where each weight has a learnable parameter + that decides if "ceil" or "floor" rounding type has to be used. + """ + + def __init__( + self, + learned_round_init: torch.Tensor, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None) -> None: + super(AutoRoundSte, self).__init__() + learned_round_init = learned_round_init.to(device=device, dtype=dtype) + self.tensor_slicer = SliceTensor() + self.value = torch.nn.Parameter(learned_round_init) + + @brevitas.jit.script_method + def forward(self, x: torch.Tensor) -> torch.Tensor: + # p should be between [-0.5, 0.5], so this learnable parameter decides whether to "ceil" or "floor" + p = self.value + p = self.tensor_slicer(p) + return round_ste(x + p.to(x.dtype)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + super(AutoRoundSte, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + value_key = prefix + 'value' + if config.IGNORE_MISSING_KEYS and value_key in missing_keys: + missing_keys.remove(value_key) diff --git a/src/brevitas/optim/sign_sgd.py b/src/brevitas/optim/sign_sgd.py new file mode 100644 index 000000000..dd5d62365 --- /dev/null +++ b/src/brevitas/optim/sign_sgd.py @@ -0,0 +1,448 @@ +# mypy: allow-untyped-defs +from typing import List, Optional + +import torch +from torch import Tensor +from torch.optim.optimizer import _default_to_fused_or_foreach +from torch.optim.optimizer import _differentiable_doc +from torch.optim.optimizer import _foreach_doc +from torch.optim.optimizer import _fused_doc +from torch.optim.optimizer import _maximize_doc +from torch.optim.optimizer import _use_grad_for_differentiable +from torch.optim.optimizer import DeviceDict +from torch.optim.optimizer import Optimizer +from torch.utils._foreach_utils import _get_fused_kernels_supported_devices + +__all__ = ["SignSGD", "sign_sgd"] + + +class SignSGD(Optimizer): + + def __init__( + self, + params, + lr: float = 1e-3, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov=False, + *, + maximize: bool = False, + foreach: Optional[bool] = None, + differentiable: bool = False, + fused: Optional[bool] = None, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + maximize=maximize, + foreach=foreach, + differentiable=differentiable, + fused=fused, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) + + if fused: + self._step_supports_amp_scaling = True + + fused_supported_devices = _get_fused_kernels_supported_devices() + if not all(p.device.type in fused_supported_devices and torch.is_floating_point(p) + for pg in self.param_groups + for p in pg["params"]): + raise RuntimeError( + "`fused=True` requires all the params to be floating point Tensors of " + f"supported devices: {fused_supported_devices}.") + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("differentiable", False) + group.setdefault("fused", False) + + def _init_group(self, group, params, grads, momentum_buffer_list): + has_sparse_grad = False + + for p in group["params"]: + if p.grad is not None: + params.append(p) + grads.append(p.grad) + if p.grad.is_sparse: + has_sparse_grad = True + + if group["momentum"] != 0: + state = self.state[p] + momentum_buffer_list.append(state.get("momentum_buffer")) + + return has_sparse_grad + + @_use_grad_for_differentiable + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params: List[Tensor] = [] + grads: List[Tensor] = [] + momentum_buffer_list: List[Optional[Tensor]] = [] + + has_sparse_grad = self._init_group(group, params, grads, momentum_buffer_list) + + sign_sgd( + params, + grads, + momentum_buffer_list, + weight_decay=group["weight_decay"], + momentum=group["momentum"], + lr=group["lr"], + dampening=group["dampening"], + nesterov=group["nesterov"], + maximize=group["maximize"], + has_sparse_grad=has_sparse_grad, + foreach=group["foreach"], + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + if group["momentum"] != 0: + # update momentum_buffers in state + for p, momentum_buffer in zip(params, momentum_buffer_list): + state = self.state[p] + state["momentum_buffer"] = momentum_buffer + + return loss + + +SignSGD.__doc__ = ( + r"""Implements stochastic gradient descent (optionally with momentum). + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta) + \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\ + &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)}, + \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\ + &\hspace{10mm}\textbf{if} \: t > 1 \\ + &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\ + &\hspace{10mm}\textbf{else} \\ + &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\ + &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\ + &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\ + &\hspace{10mm}\textbf{else} \\[-1.ex] + &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex] + &\hspace{5mm}\textbf{else} \\[-1.ex] + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + Nesterov momentum is based on the formula from + `On the importance of initialization and momentum in deep learning`__. + """ + rf""" + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + {_maximize_doc} + {_foreach_doc} + {_differentiable_doc} + {_fused_doc} + """ + r""" + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + + __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf + + .. note:: + The implementation of SGD with Momentum/Nesterov subtly differs from + Sutskever et al. and implementations in some other frameworks. + + Considering the specific case of Momentum, the update can be written as + + .. math:: + \begin{aligned} + v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ + p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, + \end{aligned} + + where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the + parameters, gradient, velocity, and momentum respectively. + + This is in contrast to Sutskever et al. and + other frameworks which employ an update of the form + + .. math:: + \begin{aligned} + v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ + p_{t+1} & = p_{t} - v_{t+1}. + \end{aligned} + + The Nesterov version is analogously modified. + + Moreover, the initial value of the momentum buffer is set to the + gradient value at the first step. This is in contrast to some other + frameworks that initialize it to all zeros. + + """) + + +def sign_sgd( + params: List[Tensor], + d_p_list: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + has_sparse_grad: bool = False, + foreach: Optional[bool] = None, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, +): + r"""Functional API that performs SGD algorithm computation. + + See :class:`~torch.optim.SGD` for details. + """ + + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if foreach is None and fused is None: + # why must we be explicit about an if statement for torch.jit.is_scripting here? + # because JIT can't handle Optionals nor fancy conditionals when scripting + if not torch.jit.is_scripting(): + fused, foreach = _default_to_fused_or_foreach( + params, differentiable=False, use_fused=False + ) + else: + foreach = False + fused = False + if foreach is None: + foreach = False + if fused is None: + fused = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + if fused and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with fused optimizers") + + if foreach and not torch.jit.is_scripting(): + func = _multi_tensor_sign_sgd + elif fused and not torch.jit.is_scripting(): + func = _fused_sign_sgd + else: + func = _single_tensor_sign_sgd + + func( + params, + d_p_list, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=nesterov, + has_sparse_grad=has_sparse_grad, + maximize=maximize, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +def _single_tensor_sign_sgd( + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +): + assert grad_scale is None and found_inf is None + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + if momentum != 0: + buf = momentum_buffer_list[i] + + if buf is None: + buf = torch.clone(grad).detach() + momentum_buffer_list[i] = buf + else: + buf.mul_(momentum).add_(grad, alpha=1 - dampening) + + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + param.add_(torch.sign(grad), alpha=-lr) + + +def _multi_tensor_sign_sgd( + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +): + assert grad_scale is None and found_inf is None + + if len(params) == 0: + return + + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, momentum_buffer_list], + with_indices=True # type: ignore[list-item] + ) + for ( + device_params, + device_grads, + device_momentum_buffer_list, + ), indices in grouped_tensors.values(): + device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads) + + if maximize: + device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] + + if weight_decay != 0: + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) + + if momentum != 0: + bufs = [] + + all_states_with_momentum_buffer = True + for i in range(len(device_momentum_buffer_list)): + if device_momentum_buffer_list[i] is None: + all_states_with_momentum_buffer = False + break + else: + bufs.append(device_momentum_buffer_list[i]) + + if all_states_with_momentum_buffer: + torch._foreach_mul_(bufs, momentum) + torch._foreach_add_(bufs, device_grads, alpha=1 - dampening) + else: + bufs = [] + for i in range(len(device_momentum_buffer_list)): + if device_momentum_buffer_list[i] is None: + buf = device_momentum_buffer_list[i] = momentum_buffer_list[ + indices[i]] = torch.clone(device_grads[i]).detach() + else: + buf = device_momentum_buffer_list[i] + buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening) + + bufs.append(buf) + + if nesterov: + torch._foreach_add_(device_grads, bufs, alpha=momentum) + else: + device_grads = bufs + + if not device_has_sparse_grad: + # handle internal item() call if lr is a tensor + if isinstance(lr, torch.Tensor) and torch._utils.is_compiling(): + grads_x_lr = torch._foreach_mul(torch._foreach_sign(device_grads), -lr) + torch._foreach_add_(device_params, grads_x_lr) + else: + torch._foreach_add_(device_params, torch._foreach_sign(device_grads), alpha=-lr) + else: + # foreach APIs don't support sparse + for i in range(len(device_params)): + device_params[i].add_(torch.sign(device_grads[i]), alpha=-lr) + + +def _fused_sign_sgd( + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +) -> None: + raise NotImplementedError("Fused SignSGD is not implemented.") diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 38ed85678..4433f3b5e 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -3,6 +3,7 @@ from functools import partial import math +from warnings import warn import torch from tqdm import tqdm @@ -23,6 +24,7 @@ from brevitas.graph.target.flexml import quantize_flexml from brevitas.inject import value import brevitas.nn as qnn +from brevitas.optim.sign_sgd import SignSGD from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloatMSE from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat @@ -674,6 +676,39 @@ def apply_learned_round_learning( loss, rec_loss, round_loss, b)) +def apply_auto_round_learning( + model, dataloader, optimizer_class=SignSGD, iters=1000, optimizer_lr=1e-1): + # Add message in case the range can be surpassed + if iters * optimizer_lr > 0.5: + warn("It is possible that the weights are not rounded to their floor or ceil.") + + layers = [] + split_layers(model, layers) + print(f"Total Iterations per layer {iters}") + print(f"Number of layers {len(layers)}") + + for layer, layer_loss, learned_round_module in learned_round_iterator(layers, iters=iters): + optimizer = optimizer_class(learned_round_module.parameters(), lr=optimizer_lr) + _, all_fp_out = save_inp_out_data(model, layer, dataloader, store_inp=False, store_out=True, keep_gpu=True, disable_quant=True) + all_quant_inp, _ = save_inp_out_data(model, layer, dataloader, store_inp=True, store_out=True, keep_gpu=True, disable_quant=False) + max_size = len(all_fp_out) + pbar = tqdm(range(iters), desc='') + for i in pbar: + idx = torch.randint(0, max_size, (dataloader.batch_size,)) + quant_inp, fp_out = all_quant_inp[idx], all_fp_out[idx] + layer.train() + + optimizer.zero_grad() + quant_out = layer(quant_inp) + loss, rec_loss, round_loss, b = layer_loss(quant_out, fp_out) + + loss.backward() + optimizer.step() + pbar.set_description( + "loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format( + loss, rec_loss, round_loss, b)) + + def check_positive_int(*args): """ We check that every inputted value is positive, and an integer. diff --git a/tests/brevitas/core/test_float_to_int.py b/tests/brevitas/core/test_float_to_int.py index ef635c279..1bcb79e47 100644 --- a/tests/brevitas/core/test_float_to_int.py +++ b/tests/brevitas/core/test_float_to_int.py @@ -6,6 +6,7 @@ import torch from brevitas import config +from brevitas.core.function_wrapper.auto_round import AutoRoundSte from brevitas.core.function_wrapper.learned_round import LearnedRoundHardSigmoid from brevitas.core.function_wrapper.learned_round import LearnedRoundSigmoid from brevitas.core.function_wrapper.learned_round import LearnedRoundSte @@ -56,3 +57,44 @@ def test_learnedround_load_dict(self, learnedround_float_to_int_impl): quant_conv.load_state_dict(fp_conv.state_dict()) except RuntimeError as e: pytest.fail(str(e)) + + +class TestAutoRound(): + + @pytest_cases.fixture() + def autoround_float_to_int_impl(self): + sample_weight = torch.randn(OUT_CH, IN_CH, KERNEL_SIZE, KERNEL_SIZE) + impl = AutoRoundSte(torch.full(sample_weight.shape, 0.)) + + # Simulate learned parameter, values should be in the interval (-0.5, 0.5) + impl.value.data = torch.rand_like(impl.value) * 0.5 + return impl, sample_weight + + def test_autoround(self, autoround_float_to_int_impl): + impl, sample_weight = autoround_float_to_int_impl + + out = impl(sample_weight) + # Check that all values are integers + assert torch.allclose(out, torch.round(out)) + # Check that the values differ by at most 1 unit + assert torch.all(torch.abs(sample_weight - out) < 1) + + def test_autoround_load_dict(self, autoround_float_to_int_impl): + config.IGNORE_MISSING_KEYS = True + + impl, _ = autoround_float_to_int_impl + quant_conv = qnn.QuantConv2d(IN_CH, OUT_CH, KERNEL_SIZE, weight_float_to_int_impl=impl) + fp_conv = torch.nn.Conv2d(IN_CH, OUT_CH, KERNEL_SIZE) + try: + quant_conv.load_state_dict(fp_conv.state_dict()) + except RuntimeError as e: + pytest.fail(str(e)) + + def test_autoround_edge_cases(self): + sample_weight = torch.tensor([-1.000, -0.500, 0.000, 0.500, 1.000]) + impl_data = torch.tensor([-0.500, 0.500, 0.000, -0.500, 0.500]) + impl = AutoRoundSte(impl_data) + + out = impl(sample_weight) + # Check that all values are integers + assert torch.allclose(out, torch.tensor([-2.000, 0.000, 0.000, 0.000, 2.000])) diff --git a/tests/brevitas/optim/test_sign_sgd.py b/tests/brevitas/optim/test_sign_sgd.py new file mode 100644 index 000000000..0bba58151 --- /dev/null +++ b/tests/brevitas/optim/test_sign_sgd.py @@ -0,0 +1,271 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import math +import sys +from typing import List, Union +import unittest + +from hypothesis import given +import pytest +import pytest_cases +from pytest_cases import fixture +import torch +from torch.nn import Parameter +import torch.nn as nn +from torch.optim.lr_scheduler import ConstantLR +from torch.optim.lr_scheduler import ExponentialLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import PolynomialLR +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.optim.lr_scheduler import StepLR +from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_device_type import onlyCUDA +from torch.testing._internal.common_device_type import tol +from torch.testing._internal.common_device_type import toleranceOverride +from torch.testing._internal.common_optimizers import DecorateInfo +from torch.testing._internal.common_optimizers import optim_error_inputs_func_sgd +from torch.testing._internal.common_optimizers import optim_inputs_func_sgd +from torch.testing._internal.common_optimizers import OptimizerErrorEnum +from torch.testing._internal.common_optimizers import OptimizerInfo +from torch.testing._internal.common_optimizers import optims +from torch.testing._internal.common_optimizers import skipIfTorchDynamo +from torch.testing._internal.common_utils import markDynamoStrictTest +from torch.testing._internal.common_utils import parametrize +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import skipIfTorchDynamo +from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO +from torch.testing._internal.common_utils import TestCase + +from brevitas.graph.calibrate import bias_correction_mode +from brevitas.graph.calibrate import calibration_mode +from brevitas.graph.calibrate import disable_return_quant_tensor +from brevitas.graph.calibrate import DisableEnableQuantization +from brevitas.graph.calibrate import load_quant_model_mode +from brevitas.graph.calibrate import restore_return_quant_tensor +from brevitas.inject.enum import RestrictValueType +import brevitas.nn as qnn +from brevitas.optim.sign_sgd import SignSGD +from brevitas.quant import Int8ActPerTensorFixedPoint +from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat +from brevitas.quant.scaled_int import Int8ActPerTensorFloat +from brevitas.quant_tensor import QuantTensor +# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations +from brevitas.utils.torch_utils import kthvalue +from tests.brevitas.hyp_helper import float_tensor_random_size_st +from tests.conftest import SEED + +torch.manual_seed(SEED) + +REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]]) +REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010], + [1.4573, -0.9074, -0.2708]]) +REFERENCE_WEIGHTS_GRAD = torch.tensor([[1.0023, 0.000, 1.4604], [-0.2918, -1.8218, -0.7010], + [1.4573, -0.9074, -0.2708]]) +REFERENCE_WEIGHTS_SIGN_GRAD = torch.tensor([[1.0000, 0.0000, 1.0000], [-1.0000, -1.0000, -1.0000], + [1.0000, -1.0000, -1.0000]]) + +optim_db: List[OptimizerInfo] = [ + OptimizerInfo( + SignSGD, + optim_inputs_func=optim_inputs_func_sgd, + scheduler_inputs=( + [lambda opt: StepLR(opt, gamma=0.9, step_size=10)], + [lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.8, total_iters=4)], + [ + lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.6, total_iters=4),], + [ + lambda opt: StepLR(opt, gamma=0.99, step_size=10), + lambda opt: ExponentialLR(opt, gamma=0.99), + lambda opt: ReduceLROnPlateau(opt),], + [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)], + [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)], + [ + lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: ReduceLROnPlateau(opt),], + ), + optim_error_inputs_func=optim_error_inputs_func_sgd, + supported_impls=("foreach", "differentiable", "fused"), + supports_sparse=True, + metadata_for_sparse=( + { + "lr": 4.8e-3, + "maximize": False, + "momentum": 0, + "nesterov": False, + "weight_decay": 0,}, + [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)], + ), + supports_fused_on=( + "cpu", + "cuda", + "mps", + ), + skips=(), + ),] + + +@markDynamoStrictTest +class TestOptimSignSGD(TestCase): + + @parametrize("lr", [0.1]) + @optims(optim_db, dtypes=[torch.float32]) + def test_sign_sgd_update(self, device, dtype, optim_info, lr): + optim_cls = optim_info.optim_cls + # Initialize weights and grads + weights = Parameter(REFERENCE_WEIGHTS.to(device=device, dtype=dtype)) + # Initialize tensors to compute expected result + initial_weights = REFERENCE_WEIGHTS.to(device=device, dtype=dtype, copy=True) + weight_grad = REFERENCE_WEIGHTS_GRAD.to(device=device, dtype=dtype) + weight_sign_grad = REFERENCE_WEIGHTS_SIGN_GRAD.to(device=device, dtype=dtype) + + optimizer = optim_cls([weights], lr=lr) + + # Perform a SignSGD update + optimizer.zero_grad() + weights.grad = weight_grad + optimizer.step() + + assert torch.allclose(weights, initial_weights - lr * weight_sign_grad) + + @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None], + dtypes=[torch.float32]) + def test_errors(self, device, dtype, optim_info): + optim_cls = optim_info.optim_cls + error_inputs = optim_info.optim_error_inputs_func(device=device, dtype=dtype) + + for error_input in error_inputs: + optim_input = error_input.optimizer_error_input + params, kwargs = optim_input.params, optim_input.kwargs + if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR: + if issubclass(error_input.error_type, Warning): + with self.assertWarnsRegex(error_input.error_type, error_input.error_regex): + optim_cls(params, **kwargs) + else: + with self.assertRaisesRegex(error_input.error_type, error_input.error_regex): + optim_cls(params, **kwargs) + elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR: + optim = optim_cls(params, **kwargs) + if issubclass(error_input.error_type, Warning): + with self.assertWarnsRegex(error_input.error_type, error_input.error_regex): + optim.step() + else: + with self.assertRaisesRegex(error_input.error_type, error_input.error_regex): + optim.step() + else: + raise NotImplementedError(f"Unknown error type {error_input.error_on}") + + @parametrize("contiguous", [True, False]) + @parametrize("with_lrsched", [True, False]) + @optims(optim_db, dtypes=[torch.float32]) + def test_forloop_goes_right_direction( + self, device, dtype, optim_info, contiguous, with_lrsched): + optim_cls = optim_info.optim_cls + schedulers_constructors = (optim_info.scheduler_inputs if with_lrsched else [None]) + + for schedulers_constructor in schedulers_constructors: + # with tensor LR we need fresh inputs for each scheduler + # or mutating it will carry across iters + optim_inputs = optim_info.optim_inputs_func(device=device) + for optim_input in optim_inputs: + if "foreach" in optim_info.supported_impls: + optim_input.kwargs["foreach"] = False # force forloop + if contiguous: + weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) + bias = Parameter(torch.randn((10), device=device, dtype=dtype)) + else: + weight = Parameter(torch.randn((10, 5, 2), device=device, dtype=dtype)[..., 0]) + bias = Parameter(torch.randn((10, 2), device=device, dtype=dtype)[..., 0]) + input = torch.randn(5, device=device, dtype=dtype) + + optimizer = optim_cls([weight, bias], **optim_input.kwargs) + schedulers = [ + s(optimizer) + for s in (schedulers_constructor if schedulers_constructor else [])] + + def closure(): + optimizer.zero_grad() + loss = (weight.mv(input) + bias).pow(2).sum() + loss.backward() + if optim_info.only_supports_sparse_grads: + # For this test, we naively convert the Tensor layout, which we know does + # NOT represent the expected use case for optims like SparseAdam! + weight.grad = weight.grad.to_sparse() + bias.grad = bias.grad.to_sparse() + return loss + + initial_value = closure().item() + for _ in range(20): + if optim_info.step_requires_closure: + loss = optimizer.step(closure) + else: + loss = closure() + optimizer.step() + + for scheduler in schedulers: + if isinstance(scheduler, ReduceLROnPlateau): + scheduler.step(loss) + else: + scheduler.step() + + if optim_input.kwargs.get("maximize", False): + self.assertGreater(closure().item(), initial_value) + else: + self.assertLess(closure().item(), initial_value) + + @onlyCUDA + @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") + @parametrize("with_lrsched", [True, False]) + @optims(optim_db, dtypes=[torch.float32]) + def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info, with_lrsched): + optim_cls = optim_info.optim_cls + schedulers_constructors = (optim_info.scheduler_inputs if with_lrsched else [None]) + for schedulers_constructor in schedulers_constructors: + # We need a fresh set of inputs if we have a tensor LR + # to not carry mutations across iterations. + optim_inputs = optim_info.optim_inputs_func(device=device) + for optim_input in optim_inputs: + if "foreach" in optim_info.supported_impls: + optim_input.kwargs["foreach"] = False # force forloop + + weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype)) + bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype)) + inpt = torch.randn(5, device="cuda:0", dtype=dtype) + + optimizer = optim_cls([weight, bias], **optim_input.kwargs) + schedulers = [ + s(optimizer) + for s in (schedulers_constructor if schedulers_constructor else [])] + + def closure(): + optimizer.zero_grad() + loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum() + loss.backward() + if optim_info.only_supports_sparse_grads: + # For this test, we naively convert the Tensor layout, which we know does + # NOT represent the expected use case for optims like SparseAdam! + weight.grad = weight.grad.to_sparse() + bias.grad = bias.grad.to_sparse() + return loss + + initial_value = closure().item() + for _ in range(20): + loss = optimizer.step(closure) + for scheduler in schedulers: + if isinstance(scheduler, ReduceLROnPlateau): + scheduler.step(loss) + else: + scheduler.step() + + if optim_input.kwargs.get("maximize", False): + self.assertGreater(closure().item(), initial_value) + else: + self.assertLess(closure().item(), initial_value) + + +instantiate_device_type_tests(TestOptimSignSGD, globals(), allow_mps=True) + +if __name__ == "__main__": + run_tests() From 0a132c9939b0d95585e0038540e414c9404f8f95 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Sun, 20 Oct 2024 23:18:36 +0100 Subject: [PATCH 02/48] Initial implementation --- .../core/function_wrapper/__init__.py | 1 + src/brevitas/inject/enum.py | 1 + src/brevitas/quant/solver/common.py | 2 + src/brevitas/utils/quant_utils.py | 2 + .../ptq/learned_round_utils.py | 43 ++++ .../imagenet_classification/ptq/ptq_common.py | 226 ++++++++++++++++-- .../ptq/ptq_evaluate.py | 65 +++-- tests/brevitas_examples/test_imagenet.py | 31 +++ 8 files changed, 331 insertions(+), 40 deletions(-) create mode 100644 tests/brevitas_examples/test_imagenet.py diff --git a/src/brevitas/core/function_wrapper/__init__.py b/src/brevitas/core/function_wrapper/__init__.py index 3b3e5428b..4929026f3 100644 --- a/src/brevitas/core/function_wrapper/__init__.py +++ b/src/brevitas/core/function_wrapper/__init__.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from .auto_round import AutoRoundSte from .clamp import ClampMin from .clamp import FloatClamp from .clamp import ScalarClamp diff --git a/src/brevitas/inject/enum.py b/src/brevitas/inject/enum.py index 129a55252..67122dddb 100644 --- a/src/brevitas/inject/enum.py +++ b/src/brevitas/inject/enum.py @@ -46,6 +46,7 @@ class FloatToIntImplType(AutoName): DPU = auto() LEARNED_ROUND = auto() STOCHASTIC_ROUND = auto() + AUTO_ROUND = auto() class LearnedRoundImplType(AutoName): diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index a4930e43d..509599764 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -49,6 +49,8 @@ def solve_float_to_int_impl_from_enum(impl_type): return LearnedRoundSte elif impl_type == FloatToIntImplType.STOCHASTIC_ROUND: return StochasticRoundSte + elif impl_type == FloatToIntImplType.AUTO_ROUND: + return AutoRoundSte else: raise Exception(f"{impl_type} not recognized.") diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index 62290b1de..6ba0ebf76 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -214,6 +214,8 @@ def float_to_int_impl_to_enum(module): return FloatToIntImplType.DPU elif isinstance(module, LearnedRoundSte): return FloatToIntImplType.LEARNED_ROUND + elif isinstance(module, AutoRoundSte): + return FloatToIntImplType.AUTO_ROUND elif isinstance(module, StochasticRoundSte): if module.deterministic_inference: return FloatToIntImplType.ROUND diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 7a5a283ea..2f6df217a 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -26,11 +26,15 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +from typing import List, Optional + import numpy as np import torch +from torch import nn import torch.nn.functional as F from brevitas import config +from brevitas.core.function_wrapper.auto_round import AutoRoundSte from brevitas.core.function_wrapper.learned_round import LearnedRoundSte from brevitas.graph.calibrate import disable_return_quant_tensor from brevitas.graph.calibrate import DisableEnableQuantization @@ -161,6 +165,45 @@ def split_layers(model, layers): split_layers(module, layers) +def insert_auto_round_quantizer_block_layers(block: nn.Module) -> None: + # Iterate over the layers and insert AutoRound quantizer in QuantWBIOL layers + for module in block.modules(): + if isinstance(module, QuantWBIOL): + value = torch.zeros_like(module.weight.data) + module.weight_quant.quant_injector = module.weight_quant.quant_injector.let( + float_to_int_impl_type=FloatToIntImplType.AUTO_ROUND, + learned_round_init=value, + ) + module.weight_quant.init_tensor_quant(preserve_state_dict=True) + + +def find_round_modules_in_block(block: nn.Module, + class_round_quant=AutoRoundSte) -> List[nn.Module]: + round_modules = [] + for module in block.modules(): + if isinstance(module, class_round_quant): + round_modules.append(module) + return round_modules + + +def auto_round_iterator(blocks: List[nn.Module]): + for block in blocks: + # Insert AutoRound quantizer in the block layers + insert_auto_round_quantizer_block_layers(block) + # Freeze block parameters + for params in block.parameters(): + params.requires_grad = False + # Retrieve learned round modules + learned_round_modules = find_round_modules_in_block(block, AutoRoundSte) + # Enable gradient tracking in learned round modules + for round_module in learned_round_modules: + for params in round_module.parameters(): + params.requires_grad = True + + yield block, learned_round_modules + block.eval() + + def learned_round_iterator(layers, iters=1000): for layer in layers: insert_learned_round_quantizer(layer) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 4433f3b5e..d2b2ab5b1 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -2,10 +2,15 @@ # SPDX-License-Identifier: BSD-3-Clause from functools import partial +import itertools import math +from typing import Callable, List from warnings import warn +from accelerate.utils.operations import send_to_device import torch +from torch import nn +import torch.backends.cudnn as cudnn from tqdm import tqdm from brevitas.core.function_wrapper.shape import OverBatchOverTensorView @@ -13,12 +18,16 @@ from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint from brevitas.graph.calibrate import bias_correction_mode from brevitas.graph.calibrate import calibration_mode +from brevitas.graph.calibrate import disable_return_quant_tensor +from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.calibrate import norm_correction_mode +from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.gpfq import gpfq_mode from brevitas.graph.gpfq import GPFQv2 from brevitas.graph.gptq import GPTQ from brevitas.graph.gptq import gptq_mode +from brevitas.graph.gpxq import StopFwdException from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.quantize import quantize from brevitas.graph.target.flexml import quantize_flexml @@ -74,6 +83,7 @@ from brevitas_examples.common.axe import A2GPTQ from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import auto_round_iterator from brevitas_examples.imagenet_classification.ptq.learned_round_utils import learned_round_iterator from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data from brevitas_examples.imagenet_classification.ptq.learned_round_utils import split_layers @@ -676,37 +686,203 @@ def apply_learned_round_learning( loss, rec_loss, round_loss, b)) -def apply_auto_round_learning( - model, dataloader, optimizer_class=SignSGD, iters=1000, optimizer_lr=1e-1): +# TODO: Replace by an actual function. Remove +def _is_resnet_block(module: nn.Module, module_name: str) -> bool: + import re + return (re.search(r"layer\d+", module_name) is not None) + + +def get_blocks(model: nn.Module, block_check_fn: Callable[[nn.Module, str], + bool]) -> List[nn.Module]: + blocks = [] + + # Iterating over .modules() might have been more readable but + # with this recursive implementation, once a block is reached, + # its subtree of modules is not expanded. + def _get_blocks(module: nn.Module): + for module_name, module_child in module.named_children(): + if block_check_fn(module_child, module_name): + blocks.append(module_child) + else: + _get_blocks(module_child) + + # Run recursive function that updates the list blocks + _get_blocks(model) + return blocks + + +def apply_auto_round_learning_debug( + model, dataloader, device, optimizer_class=SignSGD, iters=1000, optimizer_lr=1e-1): # Add message in case the range can be surpassed - if iters * optimizer_lr > 0.5: + if optimizer_class == SignSGD and iters * optimizer_lr > 0.5: warn("It is possible that the weights are not rounded to their floor or ceil.") - layers = [] - split_layers(model, layers) - print(f"Total Iterations per layer {iters}") - print(f"Number of layers {len(layers)}") + blocks = get_blocks(model, _is_resnet_block) + print(f"Total Iterations per block {iters}") + print(f"Number of blocks {len(blocks)}") + + # The following code reuses most of the code in auto_learned_round_learning. This implementation + # requires a forward pass through the whole net to capture inputs/outputs of a single block, + # while for auto_learned_round_learning_efficient a single forward is required. + + # Iterate over the blocks, keeping track of outputs to pass them to next block + for block_idx, (block, learned_round_modules) in enumerate(auto_round_iterator(blocks)): + # Instantiate optimizer with the parameters of the learned round modules + optimizer = optimizer_class( + itertools.chain( + *[ + learned_round_module.parameters() + for learned_round_module in learned_round_modules]), + lr=optimizer_lr) + # Use MSE loss to measure the discrepancy between quantised and unquantised outputs + mse_loss_fn = nn.MSELoss() + # Save inputs and outputs + _, all_fp_out = save_inp_out_data(model, block, dataloader, store_inp=False, store_out=True, keep_gpu=True, disable_quant=True) + all_quant_inp, _ = save_inp_out_data(model, block, dataloader, store_inp=True, store_out=True, keep_gpu=True, disable_quant=False) - for layer, layer_loss, learned_round_module in learned_round_iterator(layers, iters=iters): - optimizer = optimizer_class(learned_round_module.parameters(), lr=optimizer_lr) - _, all_fp_out = save_inp_out_data(model, layer, dataloader, store_inp=False, store_out=True, keep_gpu=True, disable_quant=True) - all_quant_inp, _ = save_inp_out_data(model, layer, dataloader, store_inp=True, store_out=True, keep_gpu=True, disable_quant=False) max_size = len(all_fp_out) - pbar = tqdm(range(iters), desc='') - for i in pbar: - idx = torch.randint(0, max_size, (dataloader.batch_size,)) - quant_inp, fp_out = all_quant_inp[idx], all_fp_out[idx] - layer.train() - - optimizer.zero_grad() - quant_out = layer(quant_inp) - loss, rec_loss, round_loss, b = layer_loss(quant_out, fp_out) + with tqdm(total=iters) as pbar: + for _ in range(iters): + idx = torch.randint(0, max_size, (dataloader.batch_size,)) + quant_inp, fp_out = all_quant_inp[idx], all_fp_out[idx] + block.train() + + optimizer.zero_grad() + quant_out = block(quant_inp) + loss = mse_loss_fn(quant_out, fp_out) + + loss.backward() + optimizer.step() + # Update progress bar + pbar.set_description( + "block = {:d}/{:d}, loss = {:.4f}".format(block_idx + 1, len(blocks), loss)) + pbar.update(1) + + +# TODO: Investigate performance drop +def apply_auto_round_learning_efficient( + model, dataloader, device, optimizer_class=SignSGD, iters=1000, optimizer_lr=1e-1): + # Add message in case the range can be surpassed + if optimizer_class == SignSGD and iters * optimizer_lr > 0.5: + warn("It is possible that the weights are not rounded to their floor or ceil.") - loss.backward() - optimizer.step() - pbar.set_description( - "loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format( - loss, rec_loss, round_loss, b)) + blocks = get_blocks(model, _is_resnet_block) + print(f"Total Iterations per block {iters}") + print(f"Number of blocks {len(blocks)}") + + # NOTE: Note that we are storing the output for each batch, thus + # resulting in a memory cost proportional to the number of samples + # in the calibration set. It might be desirable to consider + # alternatives that enable rounding in a batch-wise manner. + cached_args, cached_kwargs = [], [] + + # Method to intercept the input to the first block + def intercept_input(module: nn.Module, args, kwargs): + args = send_to_device(args, 'cpu') + kwargs = send_to_device(kwargs, 'cpu') + cached_args.append(args) + cached_kwargs.append(kwargs) + raise StopFwdException + + # Method to intercept output of each block + def intercept_output(module: nn.Module, args, kwargs, output, cache: List): + if isinstance(output, tuple): + output = output[0] + output = send_to_device(output, 'cpu') + cache.append((output,)) + raise StopFwdException + + # Assumptions of the following code: + # 1: The order in the list blocks corresponds to the order in which + # each block forward is run, when the model forward is executed. + # 2: The input to each block is the output of the previous block. + + # Disable quantisation for retrieving FP inputs + # TODO: Check value for call_act_quantizer_impl + toggle_quant_inference = DisableEnableQuantization(call_act_quantizer_impl=False) + toggle_quant_inference.disable_param_quantization(model, is_training=True) + toggle_quant_inference.disable_bias_quantization(model, is_training=True) + return_model_quant_tensor_state = disable_return_quant_tensor(model) + + # Capture inputs to the first block and store them in cached_args, cached_kwargs + hook = blocks[0].register_forward_pre_hook(intercept_input, with_kwargs=True) + with torch.no_grad(): + for img_batch, _ in dataloader: + try: + img_batch = img_batch.to(device) + model(img_batch) + except StopFwdException: + pass + hook.remove() + + # Enable quantisation for consistency + toggle_quant_inference.enable_param_quantization(model, is_training=True) + toggle_quant_inference.enable_bias_quantization(model, is_training=True) + restore_return_quant_tensor(model, return_model_quant_tensor_state) + + # Iterate over the blocks, keeping track of outputs to pass them to next block + for block_idx, (block, learned_round_modules) in enumerate(auto_round_iterator(blocks)): + # Instantiate optimizer with the parameters of the learned round modules + optimizer = optimizer_class( + itertools.chain( + *[ + learned_round_module.parameters() + for learned_round_module in learned_round_modules]), + lr=optimizer_lr) + # Use MSE loss to measure the discrepancy between quantised and unquantised outputs + mse_loss_fn = nn.MSELoss() + # Prevent needing to perform a deep copy + past_cached_args = cached_args + cached_args = [] + # Process each batch storing outputs + hook = block.register_forward_hook( + partial(intercept_output, cache=cached_args), with_kwargs=True) + # Disable quantisation to get FP outputs + toggle_quant_inference.disable_param_quantization(block, is_training=True) + toggle_quant_inference.disable_bias_quantization(block, is_training=True) + return_block_quant_tensor_state = disable_return_quant_tensor(block) + # Retrieve the FP outputs + with torch.no_grad(): + for args, kwargs in zip(past_cached_args, cached_kwargs): + try: + args = send_to_device(args, device) + kwargs = send_to_device(kwargs, device) + block(*args, **kwargs) + except StopFwdException: + pass + hook.remove() + # Enable quantisation to get Quant outputs + toggle_quant_inference.enable_param_quantization(block, is_training=True) + toggle_quant_inference.enable_bias_quantization(block, is_training=True) + restore_return_quant_tensor(block, return_block_quant_tensor_state) + + # Concatenate inputs and outputs along the batch dimension + past_cached_args_opt = tuple( + torch.cat(tensors, dim=0) for tensors in zip(*past_cached_args)) + cached_args_opt = tuple(torch.cat(tensors, dim=0) for tensors in zip(*cached_args)) + + # TODO: Verify with the implementation of AutoRound + with tqdm(total=iters) as pbar: + block.train() + for _ in range(iters): + # Subsample in calibration subset + idx = torch.randint(0, cached_args_opt[0].shape[0], (dataloader.batch_size,)) + fp_input, fp_out = past_cached_args_opt[0][idx], cached_args_opt[0][idx] + # Move tensor to appropiate devices + fp_input = send_to_device(fp_input, device) + fp_out = send_to_device(fp_out, device) + + quant_out = block(fp_input) + loss = mse_loss_fn(fp_out, quant_out) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + # Update progress bar + pbar.set_description( + "block = {:d}/{:d}, loss = {:.4f}".format(block_idx + 1, len(blocks), loss)) + pbar.update(1) def check_positive_int(*args): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 34bdfbc96..86e4f48fe 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -21,7 +21,11 @@ from brevitas.export.inference import quant_inference_mode from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize +from brevitas.optim.sign_sgd import SignSGD from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization +from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_auto_round_learning +from brevitas_examples.imagenet_classification.ptq.ptq_common import \ + apply_auto_round_learning_efficient from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq @@ -158,6 +162,11 @@ def validate_args(args): default=20, type=int, help='Numbers of iterations for graph equalization (default: 20)') +parser.add_argument( + '--learned-round-type', + default='none', + choices=['none', 'ada_round', 'auto_round'], + help='Learned round type (default: none)') parser.add_argument( '--learned-round-iters', default=1000, @@ -250,6 +259,11 @@ def validate_args(args): help= 'Split Ratio for Channel Splitting. When set to 0.0, Channel Splitting will not be applied. (default: 0.0)' ) +parser.add_argument( + '--optimizer', + default='adam', + choices=['adam', 'sign_sgd'], + help='Optimizer to use with learnable rounding (default: %(default)s)') add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)') add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)') add_bool_arg( @@ -264,7 +278,6 @@ def validate_args(args): 'gpxq-create-weight-orig', default=False, help='Maintain original weights for non-quant forward pass (default: disabled)') -add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)') add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)') add_bool_arg( parser, @@ -321,7 +334,7 @@ def main(): f"{'gptq_' if args.gptq else ''}" f"{'gpfq_' if args.gpfq else ''}" f"{'gpxq_act_order_' if args.gpxq_act_order else ''}" - f"{'learned_round_' if args.learned_round else ''}" + f"{'learned_round_type' if args.learned_round_type != "none" else ''}" f"{'weight_narrow_range_' if args.weight_narrow_range else ''}" f"{args.bias_bit_width}bias_" f"{args.weight_quant_granularity}_" @@ -343,8 +356,7 @@ def main(): f"GPTQ: {args.gptq} - " f"GPFQ: {args.gpfq} - " f"GPxQ Act Order: {args.gpxq_act_order} - " - f"GPxQ Accumulator Bit Width: {args.gpxq_accumulator_bit_width} - " - f"Learned Round: {args.learned_round} - " + f"Learned Round type: {args.learned_round_type} - " f"Weight narrow range: {args.weight_narrow_range} - " f"Bias bit width: {args.bias_bit_width} - " f"Weight scale factors type: {args.weight_quant_granularity} - " @@ -398,20 +410,25 @@ def main(): equalize_merge_bias=args.graph_eq_merge_bias, merge_bn=not args.calibrate_bn) elif args.target_backend == 'fx' or args.target_backend == 'layerwise': - model = preprocess_for_quantize( - model, - equalize_iters=args.graph_eq_iterations, - equalize_merge_bias=args.graph_eq_merge_bias, - merge_bn=args.merge_bn, - channel_splitting_ratio=args.channel_splitting_ratio, - channel_splitting_split_input=args.channel_splitting_split_input) + if args.learned_round_type != "auto_round": + model = preprocess_for_quantize( + model, + equalize_iters=args.graph_eq_iterations, + equalize_merge_bias=args.graph_eq_merge_bias, + merge_bn=args.merge_bn, + channel_splitting_ratio=args.channel_splitting_ratio, + channel_splitting_split_input=args.channel_splitting_split_input) else: raise RuntimeError(f"{args.target_backend} backend not supported.") + device = ( + torch.device(f"cuda:{args.gpu}") + if args.gpu is not None + else torch.device("cpu") + ) + model = model.to(device=device) # If available, use the selected GPU if args.gpu is not None: - torch.cuda.set_device(args.gpu) - model = model.cuda(args.gpu) cudnn.benchmark = False if args.act_equalization is not None: @@ -475,11 +492,29 @@ def main(): max_accumulator_bit_width=args.gpxq_accumulator_bit_width, max_accumulator_tile_size=args.gpxq_accumulator_tile_size) - if args.learned_round: - print("Applying Learned Round:") + if args.optimizer == "adam": + optimizer_class = torch.optim.Adam + elif args.optimizer == "sign_sgd": + optimizer_class = SignSGD + else: + raise ValueError(f"{args.optimizer} is not a valid optimizer.") + + if args.learned_round_type == "auto_round": + print("Applying Auto Round:") + apply_auto_round_learning( + quant_model, + calib_loader, + device, + optimizer_class=optimizer_class, + iters=args.learned_round_iters, + optimizer_lr=args.learned_round_lr) + + if args.learned_round_type == "ada_round": + print("Applying Learned Round (AdaRound):") apply_learned_round_learning( quant_model, calib_loader, + optimizer_class=optimizer_class, iters=args.learned_round_iters, optimizer_lr=args.learned_round_lr) diff --git a/tests/brevitas_examples/test_imagenet.py b/tests/brevitas_examples/test_imagenet.py new file mode 100644 index 000000000..4d7afdc7c --- /dev/null +++ b/tests/brevitas_examples/test_imagenet.py @@ -0,0 +1,31 @@ +from hypothesis import given +import pytest +import pytest_cases +from pytest_cases import fixture +import torch +import torch.nn as nn + +from brevitas_examples.imagenet_classification.ptq.utils import get_torchvision_model + +DTYPE = torch.float32 + + +class TestImageNet: + + @fixture + def model(): + # Get the model from torchvision + model = get_torchvision_model("resnet18") + model = model.to(DTYPE) + model.eval() + + return model + + def test_model_can_be_loaded(model): + print(f"The model class IS: {type(model)}") + assert False + + +if __name__ == "__main__": + # Run pytest on the current file + pytest.main(["-s", __file__]) From 701f8be23b2da9928cb37f94e41e28c01f6122a9 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Wed, 23 Oct 2024 17:09:51 +0100 Subject: [PATCH 03/48] Refactoring before removing legacy code --- .../ptq/learned_round_utils.py | 199 +++++++++++- .../imagenet_classification/ptq/ptq_common.py | 100 +++++- .../ptq/ptq_evaluate.py | 40 +++ .../test_learned_round_utils.py | 285 ++++++++++++++++++ 4 files changed, 619 insertions(+), 5 deletions(-) create mode 100644 tests/brevitas_examples/test_learned_round_utils.py diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 2f6df217a..45a77fb4f 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -26,7 +26,9 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import List, Optional +from abc import ABC +from abc import abstractmethod +from typing import Callable, Generator, List, Optional, Tuple import numpy as np import torch @@ -95,6 +97,176 @@ def __call__(self, t): return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t)) +class LearnedRoundLoss(ABC): + + @abstractmethod + def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: + pass + + @abstractmethod + def format_loss_components(self, *args) -> str: + pass + + +class AdaRoundLoss(LearnedRoundLoss): + + def __init__( + self, + module: nn.Module, + learned_round_modules: List[nn.Module], + weight: float = 0.01, + max_count: int = 1000, + b_range: Tuple = (20, 2), + warmup: float = 0.2, + decay_start: float = 0.0) -> None: + super().__init__() + # AdaRound operates in a layer-wise manner, so integrity needs to be checked + assert isinstance(module, QuantWBIOL), "AdaRound can only accept a single QuantWBIOL layer." + assert len(learned_round_modules) == 1, "AdaRound can only accept a single learned round module." + + self.weight = weight + self.module = module + self.loss_start = max_count * warmup + self.temp_decay = LinearTempDecay( + max_count, + start_b=b_range[0], + end_b=b_range[1], + rel_start_decay=warmup + (1.0 - warmup) * decay_start) + self.iter = 0 + self.learned_round_module = learned_round_modules[0] + + def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: + self.iter += 1 + + rec_loss = F.mse_loss(pred, tgt, reduction='none').sum(1).mean() + + if self.iter < self.loss_start: + b = self.temp_decay(self.iter) + round_loss = 0 + else: # 1 - |(h-0.5)*2|**b + b = self.temp_decay(self.iter) + round_vals = self.learned_round_module.p_forward() + round_loss = self.weight * (1 - ((round_vals - 0.5).abs() * 2).pow(b)).sum() + + total_loss = rec_loss + round_loss + return total_loss, (total_loss, rec_loss, round_loss, b) + + def format_loss_components(self, loss: float, rec_loss: float, round_loss: float, b) -> str: + return "loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format( + loss, rec_loss, round_loss, b) + + +class AutoRoundLoss(LearnedRoundLoss): + + def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: + loss = F.mse_loss(pred, tgt, reduction='none').sum(1).mean() + return loss, (loss,) + + def format_loss_components(self, loss: float) -> str: + return "loss = {:.4f}".format(loss) + + +class LearnedRound(ABC): + + def __init__(self, iters: int = 100) -> None: + self.iters = iters + + def _insert_learned_round_quantizer(self, block: nn.Module) -> None: + for module in block.modules(): + if isinstance(module, QuantWBIOL) and not find_learned_round_module(module): + self._insert_learned_round_quantizer_to_layer(module) + module.weight_quant.init_tensor_quant(preserve_state_dict=True) + + @abstractmethod + def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: + pass + + @abstractmethod + def _is_learned_round_module(self, module: nn.Module) -> bool: + pass + + @abstractmethod + def _instantiate_loss( + self, block: nn.Module, learned_round_modules: List[nn.Module]) -> LearnedRoundLoss: + pass + + def _find_learned_round_modules(self, block: nn.Module) -> List[nn.Module]: + round_modules = [] + for module in block.modules(): + if self._is_learned_round_module(module): + round_modules.append(module) + return round_modules + + def learned_round_iterator( + self, + blocks: List[nn.Module]) -> Generator[nn.Module, LearnedRoundLoss, List[nn.Module]]: + for block in blocks: + # Insert learned round quantizers into the appropiate submodules + self._insert_learned_round_quantizer(block) + # Freeze block parameters + for params in block.parameters(): + params.requires_grad = False + # Retrieve learned round modules + learned_round_modules = self._find_learned_round_modules(block) + # Enable gradient tracking in learned round modules + for round_module in learned_round_modules: + for params in round_module.parameters(): + params.requires_grad = True + block_loss = self._instantiate_loss(block, learned_round_modules) + yield block, block_loss, learned_round_modules + block.eval() + + +class AdaRound(LearnedRound): + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def _is_learned_round_module(self, module: nn.Module) -> bool: + return isinstance(module, LearnedRoundSte) + + def _insert_learned_round_quantizer_to_layer( + self, + layer: nn.Module, + learned_round_zeta: float = 1.1, + learned_round_gamma: float = -0.1) -> None: + floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale) + delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight + value = -torch.log((learned_round_zeta - learned_round_gamma) / + (delta - learned_round_gamma) - 1) + layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( + float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND, + learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID, + learned_round_gamma=learned_round_gamma, + learned_round_zeta=learned_round_zeta, + learned_round_init=value) + + def _instantiate_loss( + self, block: nn.Module, learned_round_modules: List[nn.Module]) -> AdaRoundLoss: + return AdaRoundLoss(block, learned_round_modules, max_count=self.iters) + + +class AutoRound(LearnedRound): + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def _is_learned_round_module(self, module: nn.Module) -> bool: + return isinstance(module, AutoRoundSte) + + def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: + value = torch.zeros_like(layer.weight.data) + layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( + float_to_int_impl_type=FloatToIntImplType.AUTO_ROUND, + learned_round_init=value, + ) + + def _instantiate_loss( + self, block: nn.Module, learned_round_modules: List[nn.Module]) -> AutoRoundLoss: + return AutoRoundLoss() + + +# TODO: Remove after validation class Loss: def __init__( @@ -134,6 +306,7 @@ def __call__(self, pred, tgt): return total_loss, rec_loss, round_loss, b +# TODO: Remove after validation def find_learned_round_module(module): for submodule in module.modules(): if isinstance(submodule, LearnedRoundSte): @@ -141,6 +314,7 @@ def find_learned_round_module(module): return False +# TODO: Remove after validation def insert_learned_round_quantizer(layer, learned_round_zeta=1.1, learned_round_gamma=-0.1): if isinstance(layer, QuantWBIOL): if not find_learned_round_module(layer): @@ -157,6 +331,7 @@ def insert_learned_round_quantizer(layer, learned_round_zeta=1.1, learned_round_ layer.weight_quant.init_tensor_quant(preserve_state_dict=True) +# TODO: Remove after validation def split_layers(model, layers): for module in model.children(): if isinstance(module, QuantWBIOL): @@ -165,6 +340,7 @@ def split_layers(model, layers): split_layers(module, layers) +# TODO: Remove after validation def insert_auto_round_quantizer_block_layers(block: nn.Module) -> None: # Iterate over the layers and insert AutoRound quantizer in QuantWBIOL layers for module in block.modules(): @@ -177,6 +353,7 @@ def insert_auto_round_quantizer_block_layers(block: nn.Module) -> None: module.weight_quant.init_tensor_quant(preserve_state_dict=True) +# TODO: Remove after validation def find_round_modules_in_block(block: nn.Module, class_round_quant=AutoRoundSte) -> List[nn.Module]: round_modules = [] @@ -186,6 +363,7 @@ def find_round_modules_in_block(block: nn.Module, return round_modules +# Remove after validation def auto_round_iterator(blocks: List[nn.Module]): for block in blocks: # Insert AutoRound quantizer in the block layers @@ -204,6 +382,7 @@ def auto_round_iterator(blocks: List[nn.Module]): block.eval() +# TODO: Remove after validation def learned_round_iterator(layers, iters=1000): for layer in layers: insert_learned_round_quantizer(layer) @@ -218,6 +397,20 @@ def learned_round_iterator(layers, iters=1000): layer.eval() +# TODO: Remove, fast-experimentation code +def auto_round_layerwise_iterator(layers): + for layer in layers: + insert_auto_round_quantizer_block_layers(layer) + + for p in layer.parameters(): + p.requires_grad = False + + learned_round_module = find_round_modules_in_block(layer)[0] + learned_round_module.value.requires_grad = True + yield layer, learned_round_module + layer.eval() + + def save_inp_out_data( model, module, @@ -252,9 +445,9 @@ def save_inp_out_data( else: cached[1].append(data_saver.output_store.detach().cpu()) if store_inp: - cached[0] = torch.cat([x for x in cached[0]]) + cached[0] = torch.cat([x for x in cached[0]], dim=0) if store_out: - cached[1] = torch.cat([x for x in cached[1]]) + cached[1] = torch.cat([x for x in cached[1]], dim=0) handle.remove() if disable_quant: disable_quant_class.enable_act_quantization(model, False) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index d2b2ab5b1..24542f40d 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -33,6 +33,7 @@ from brevitas.graph.target.flexml import quantize_flexml from brevitas.inject import value import brevitas.nn as qnn +from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.optim.sign_sgd import SignSGD from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloatMSE @@ -83,8 +84,10 @@ from brevitas_examples.common.axe import A2GPTQ from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AdaRound from brevitas_examples.imagenet_classification.ptq.learned_round_utils import auto_round_iterator from brevitas_examples.imagenet_classification.ptq.learned_round_utils import learned_round_iterator +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import LearnedRound from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data from brevitas_examples.imagenet_classification.ptq.learned_round_utils import split_layers @@ -657,6 +660,48 @@ def apply_gpfq( gpfq.update() +# TODO: Remove after debugging +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import \ + auto_round_layerwise_iterator + + +def apply_auto_round_learning_layerwise( + model, dataloader, device, optimizer_class=torch.optim.Adam, iters=1000, optimizer_lr=1e-1): + layers = [] + split_layers(model, layers) + print(f"Total Iterations per layer {iters}") + print(f"Number of layers {len(layers)}") + + for layer_idx, (layer, + learned_round_module) in enumerate(auto_round_layerwise_iterator(layers)): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + # Start measuring time + start.record() + + optimizer = optimizer_class(learned_round_module.parameters(), lr=optimizer_lr) + _, all_fp_out = save_inp_out_data(model, layer, dataloader, store_inp=False, store_out=True, keep_gpu=True, disable_quant=True) + all_quant_inp, _ = save_inp_out_data(model, layer, dataloader, store_inp=True, store_out=True, keep_gpu=True, disable_quant=False) + max_size = len(all_fp_out) + mse_loss = nn.MSELoss() + pbar = tqdm(range(iters), desc='') + for _ in pbar: + idx = torch.randint(0, max_size, (dataloader.batch_size,)) + quant_inp, fp_out = all_quant_inp[idx], all_fp_out[idx] + layer.train() + + optimizer.zero_grad() + quant_out = layer(quant_inp) + loss = mse_loss(quant_out, fp_out) + + loss.backward() + optimizer.step() + # Update progress bar + pbar.set_description( + "block = {:d}/{:d}, loss = {:.4f}".format(layer_idx + 1, len(layers), loss)) + pbar.update(1) + + def apply_learned_round_learning( model, dataloader, optimizer_class=torch.optim.Adam, iters=1000, optimizer_lr=1e-1): layers = [] @@ -664,7 +709,8 @@ def apply_learned_round_learning( print(f"Total Iterations per layer {iters}") print(f"Number of layers {len(layers)}") - for layer, layer_loss, learned_round_module in learned_round_iterator(layers, iters=iters): + for layer_idx, (layer, layer_loss, + learned_round_module) in enumerate(learned_round_iterator(layers, iters=iters)): optimizer = optimizer_class(learned_round_module.parameters(), lr=optimizer_lr) _, all_fp_out = save_inp_out_data(model, layer, dataloader, store_inp=False, store_out=True, keep_gpu=True, disable_quant=True) all_quant_inp, _ = save_inp_out_data(model, layer, dataloader, store_inp=True, store_out=True, keep_gpu=True, disable_quant=False) @@ -692,6 +738,56 @@ def _is_resnet_block(module: nn.Module, module_name: str) -> bool: return (re.search(r"layer\d+", module_name) is not None) +def _is_layer(module: nn.Module, module_name: str) -> bool: + return isinstance(module, QuantWBIOL) + + +def apply_learned_round_learning_generalized( + model: nn.Module, + dataloader: torch.utils.data.dataloader.DataLoader, + learned_round: LearnedRound = AdaRound, + optimizer_class=torch.optim.Adam, + iters: int = 1000, + optimizer_lr: float = 1e-1, + block_check_fn: Callable = _is_layer, +): + # Retrieve blocks using the appropiate function to check blocks + blocks = get_blocks(model, block_check_fn) + + print(f"Total Iterations per block {iters}") + print(f"Number of blocks {len(blocks)}") + + for block_idx, (block, block_loss, block_learned_round_modules) in enumerate( + learned_round.learned_round_iterator(blocks)): + optimizer = optimizer = optimizer_class( + itertools.chain( + *[ + learned_round_module.parameters() + for learned_round_module in block_learned_round_modules]), + lr=optimizer_lr) + _, all_fp_out = save_inp_out_data(model, block, dataloader, store_inp=False, store_out=True, keep_gpu=True, disable_quant=True) + all_quant_inp, _ = save_inp_out_data(model, block, dataloader, store_inp=True, store_out=True, keep_gpu=True, disable_quant=False) + max_size = len(all_fp_out) + pbar = tqdm(range(iters), desc='') + for _ in pbar: + idx = torch.randint(0, max_size, (dataloader.batch_size,)) + quant_inp, fp_out = all_quant_inp[idx], all_fp_out[idx] + block.train() + + optimizer.zero_grad() + quant_out = block(quant_inp) + loss, loss_components = block_loss(quant_out, fp_out) + + loss.backward() + optimizer.step() + # Update progress bar + pbar.set_description( + "block = {:d}/{:d}, {}".format( + block_idx + 1, len(blocks), + block_loss.format_loss_components(*loss_components))) + pbar.update(1) + + def get_blocks(model: nn.Module, block_check_fn: Callable[[nn.Module, str], bool]) -> List[nn.Module]: blocks = [] @@ -711,7 +807,7 @@ def _get_blocks(module: nn.Module): return blocks -def apply_auto_round_learning_debug( +def apply_auto_round_learning( model, dataloader, device, optimizer_class=SignSGD, iters=1000, optimizer_lr=1e-1): # Add message in case the range can be surpassed if optimizer_class == SignSGD and iters * optimizer_lr > 0.5: diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 86e4f48fe..59d53c336 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -22,14 +22,22 @@ from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas.optim.sign_sgd import SignSGD +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AdaRound +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AutoRound +from brevitas_examples.imagenet_classification.ptq.ptq_common import _is_layer +from brevitas_examples.imagenet_classification.ptq.ptq_common import _is_resnet_block from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_auto_round_learning from brevitas_examples.imagenet_classification.ptq.ptq_common import \ apply_auto_round_learning_efficient +from brevitas_examples.imagenet_classification.ptq.ptq_common import \ + apply_auto_round_learning_layerwise from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning +from brevitas_examples.imagenet_classification.ptq.ptq_common import \ + apply_learned_round_learning_generalized from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate_bn from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model @@ -167,6 +175,11 @@ def validate_args(args): default='none', choices=['none', 'ada_round', 'auto_round'], help='Learned round type (default: none)') +parser.add_argument( + '--learned-round-mode', + default='layerwise', + choices=['layerwise', 'blockwise'], + help='Learned round mode (default: none)') parser.add_argument( '--learned-round-iters', default=1000, @@ -499,6 +512,33 @@ def main(): else: raise ValueError(f"{args.optimizer} is not a valid optimizer.") + if args.learned_round_mode == "layerwise": + block_check_fn = _is_layer + elif args.learned_round_mode == "blockwise": + # if args.learned_round_type == "ada_round": + # raise ValueError(f"Block-wise round is not available with AdaRound.") + block_check_fn = _is_resnet_block + + if args.learned_round_type != "none": + if args.learned_round_type =="auto_round": + learned_round = AutoRound(iters=args.learned_round_iters) + elif args.learned_round_type == "ada_round": + learned_round = AdaRound(iters=args.learned_round_iters) + + """ + apply_learned_round_learning_generalized( + model=quant_model, + dataloader=calib_loader, + learned_round=learned_round, + optimizer_class=optimizer_class, + iters=args.learned_round_iters, + optimizer_lr=args.learned_round_lr, + block_check_fn=block_check_fn + ) + """ + + # TODO: Remove after validation + if args.learned_round_type == "auto_round": print("Applying Auto Round:") apply_auto_round_learning( diff --git a/tests/brevitas_examples/test_learned_round_utils.py b/tests/brevitas_examples/test_learned_round_utils.py new file mode 100644 index 000000000..6194b7c00 --- /dev/null +++ b/tests/brevitas_examples/test_learned_round_utils.py @@ -0,0 +1,285 @@ +from hypothesis import given +import pytest +import pytest_cases +from pytest_cases import fixture +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.utils.data import Dataset + +from brevitas import config +from brevitas.core.function_wrapper.auto_round import AutoRoundSte +from brevitas.core.function_wrapper.learned_round import LearnedRoundSte +from brevitas.inject.enum import FloatToIntImplType +import brevitas.nn as qnn +from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL +from brevitas.quant_tensor.base_quant_tensor import QuantTensor +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AdaRound +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AdaRoundLoss +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AutoRound +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AutoRoundLoss +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data +from brevitas_examples.imagenet_classification.ptq.ptq_common import get_blocks + +config.IGNORE_MISSING_KEYS = True + +# TODO: Include some integration test +class TestLearnedRound: + + @fixture + def quant_model(): + + class QuantBlock(nn.Module): + def __init__(self, in_features: int, hidden_dim: int, out_features: int) -> None: + super().__init__() + self.layer1 = qnn.QuantLinear(in_features=in_features, out_features=hidden_dim) + self.layer2 = qnn.QuantLinear(in_features=hidden_dim, out_features=out_features) + self.relu = qnn.QuantReLU(return_quant_tensor=True) + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + out = self.layer1(x) + out = self.relu(out) + out = self.layer2(out) + return self.relu(out) + + class TestQuantModel(nn.Module): + def __init__(self, in_features: int, out_features: int, hidden_dim: int) -> None: + super().__init__() + self.in_proj_mlp = QuantBlock(in_features=in_features, hidden_dim=hidden_dim, out_features=hidden_dim) + self.hidden_mlp = QuantBlock(in_features=hidden_dim, hidden_dim=hidden_dim, out_features=hidden_dim) + self.out_proj_mlp = QuantBlock(in_features=hidden_dim, hidden_dim=hidden_dim, out_features=out_features) + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + out = self.in_proj_mlp(x) + out = self.hidden_mlp(out) + return self.out_proj_mlp(out) + + return TestQuantModel(in_features=2, out_features=1, hidden_dim=4) + + @fixture + def model(): + + class Block(nn.Module): + def __init__(self, in_features: int, hidden_dim: int, out_features: int) -> None: + super().__init__() + self.layer1 = nn.Linear(in_features=in_features, out_features=hidden_dim) + self.layer2 = nn.Linear(in_features=hidden_dim, out_features=out_features) + self.relu = F.relu + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + out = self.layer1(x) + out = self.relu(out) + out = self.layer2(out) + return self.relu(out) + + class TestModel(nn.Module): + def __init__(self, in_features: int, out_features: int, hidden_dim: int) -> None: + super().__init__() + self.in_proj_mlp = Block(in_features=in_features, hidden_dim=hidden_dim, out_features=hidden_dim) + self.hidden_mlp = Block(in_features=hidden_dim, hidden_dim=hidden_dim, out_features=hidden_dim) + self.out_proj_mlp = Block(in_features=hidden_dim, hidden_dim=hidden_dim, out_features=out_features) + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + out = self.in_proj_mlp(x) + out = self.hidden_mlp(out) + return self.out_proj_mlp(out) + + return TestModel(in_features=2, out_features=1, hidden_dim=4) + + @fixture + def data_loader(): + + class TestDataset(Dataset): + def __init__(self): + self.data = torch.tensor([[1.0, 2.0]]) + self.labels = torch.tensor([0]) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx], self.labels[idx] + + return DataLoader(TestDataset(), batch_size=1, shuffle=False) + + def test_get_blocks(self, quant_model: nn.Module): + + def _is_block(module: nn.Module, module_name: str) -> bool: + return module_name in ["hidden_mlp"] + + expected_blocks = [quant_model.hidden_mlp] + blocks = get_blocks(quant_model, _is_block) + + assert expected_blocks == blocks + + def test_get_layers(self, quant_model: nn.Module): + + def _is_layer(module: nn.Module, module_name: str) -> bool: + return isinstance(module, QuantWBIOL) + + expected_layers = [ + quant_model.in_proj_mlp.layer1, quant_model.in_proj_mlp.layer2, + quant_model.hidden_mlp.layer1, quant_model.hidden_mlp.layer2, + quant_model.out_proj_mlp.layer1, quant_model.out_proj_mlp.layer2 + ] + layers = get_blocks(quant_model, _is_layer) + + assert expected_layers == layers + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") + # NOTE: DataSaverHook always returns a torch.Tensor for the input, while for the output it can be either a torch.Tensor or + # a QuantTensor. Is this expected behaviour? For that reason, the argument _assert_type is included in _aux_check_tensors. + # Also, returning an empty list for the save_inp_out_data does not seem very natural, considering a tensors if the appropiate + # store option is activated. + @pytest.mark.parametrize("store_input", [True, False]) + @pytest.mark.parametrize("store_out", [True, False]) + @pytest.mark.parametrize("keep_gpu", [True, False]) + @pytest.mark.parametrize("disable_quant", [True, False]) + def test_save_inp_out_data(self, model, quant_model, data_loader, store_input, store_out, keep_gpu, disable_quant): + # Make sure that the quant and FP models share the same weights + quant_model.load_state_dict(model.state_dict()) + + model.eval() + model = model.cuda() + + quant_model.eval() + quant_model = quant_model.cuda() + + # Retrieve module from quant_model + module = quant_model.hidden_mlp + + cache_quant_partial_input = [] + cache_quant_partial_output = [] + + cache_fp_partial_input = [] + cache_fp_partial_output = [] + + def _aux_check_tensors(result_tensor, expected_tensor, keep_gpu, disable_quant, assert_type=False): + # Verify that tensor is of the appropiate type + if assert_type: + assert isinstance(result_tensor, torch.Tensor if disable_quant else QuantTensor) + # Extract value tensors + if isinstance(result_tensor, QuantTensor): + result_tensor, expected_tensor = result_tensor.value, expected_tensor.value + # Verify that tensor is in appropiate device + assert result_tensor.is_cuda == keep_gpu + # Make sure tensors are in the same device before comparison + if not keep_gpu: + expected_tensor = expected_tensor.cpu() + + assert torch.allclose(result_tensor, expected_tensor) + + # Compute ground truths inputs and outputs + with torch.no_grad(): + for batch_data, _ in data_loader: + batch_data = batch_data.cuda() + # Compute quant inputs to module + quant_partial_input = quant_model.in_proj_mlp(batch_data) + cache_quant_partial_input.append(quant_partial_input) + # Compute quant outputs of module + quant_partial_output = quant_model.hidden_mlp(quant_partial_input) + cache_quant_partial_output.append(quant_partial_output) + + # Compute FP inputs to module + fp_partial_input = model.in_proj_mlp(batch_data) + cache_fp_partial_input.append(fp_partial_input) + # Compute FP outputs of module + fp_partial_output = model.hidden_mlp(fp_partial_input) + cache_fp_partial_output.append(fp_partial_output) + + # Inputs and outputs are concatenated along the batch dimension. + # See https://github.com/quic/aimet/blob/7c9eded51e3d8328746e7ba4cf68c7162f841712/TrainingExtensions/torch/src/python/aimet_torch/v1/adaround/activation_sampler.py#L231 + cache_quant_partial_input = torch.cat(cache_quant_partial_input, dim=0) + cache_quant_partial_output = torch.cat(cache_quant_partial_output, dim=0) + + cache_fp_partial_input = torch.cat(cache_fp_partial_input, dim=0) + cache_fp_partial_output = torch.cat(cache_fp_partial_output, dim=0) + + # Retrieve input and output data + input_data, out_data = save_inp_out_data(quant_model, module, data_loader, store_input, store_out, keep_gpu, disable_quant) + # Verify that empty lists are returned + if store_input: + if disable_quant: + _aux_check_tensors(input_data, fp_partial_input, keep_gpu, disable_quant, assert_type=True) + else: + _aux_check_tensors(input_data, quant_partial_input, keep_gpu, disable_quant) + else: + assert len(input_data) == 0 + + if store_out: + if disable_quant: + _aux_check_tensors(out_data, fp_partial_output, keep_gpu, disable_quant) + else: + _aux_check_tensors(out_data, quant_partial_output, keep_gpu, disable_quant, assert_type=True) + else: + assert len(out_data) == 0 + + @pytest.mark.parametrize("learned_round_class, rounding_mode, float_to_int_impl", [(AutoRound, "AUTO_ROUND", AutoRoundSte), (AdaRound, "LEARNED_ROUND", LearnedRoundSte)]) + def test_insert_learned_round_quantizer(self, quant_model, learned_round_class, rounding_mode, float_to_int_impl): + block = quant_model.in_proj_mlp + learned_round = learned_round_class(iters=100) + learned_round._insert_learned_round_quantizer(block) + + for module in block.modules(): + if hasattr(module, "weight_quant"): + assert module.weight_quant.rounding_mode == rounding_mode + assert isinstance(module.weight_quant.tensor_quant.int_quant.float_to_int_impl, float_to_int_impl) + + @pytest.mark.parametrize("learned_round_class", [AutoRound, AdaRound]) + @pytest.mark.parametrize("block_strs, num_round_modules", [([], 0), (["hidden_mlp"], 2), (["in_proj_mlp", "out_proj_mlp"], 4)]) + def test_find_learned_round_modules(self, quant_model, learned_round_class, block_strs, num_round_modules): + learned_round = learned_round_class(iters=100) + # Inject quantizers in quant model + for block_str in block_strs: + block = getattr(quant_model, block_str) + learned_round._insert_learned_round_quantizer(block) + learned_round_modules = learned_round._find_learned_round_modules(quant_model) + assert len(learned_round_modules) == num_round_modules + + @pytest.mark.parametrize("learned_round_class, learned_round_loss_class", [(AutoRound, AutoRoundLoss)]) + @pytest.mark.parametrize("block_strs, num_round_modules", [([], 0), (["hidden_mlp"], 2), (["in_proj_mlp", "out_proj_mlp"], 4)]) + def test_learned_round_iter_blockwise(self, quant_model, learned_round_class, learned_round_loss_class, block_strs, num_round_modules): + # Retrieve blocks from quant model + blocks = [getattr(quant_model, block_str) for block_str in block_strs] + learned_round = learned_round_class(iters=100) + + # Counters to verify that the generators returns the appropiate number of items + blocks_count = 0 + learned_round_modules_count = 0 + + for (block, block_loss, block_learned_round_modules) in learned_round.learned_round_iterator(blocks): + assert isinstance(block_loss, learned_round_loss_class) + + for learned_round_module in block_learned_round_modules: + for params in learned_round_module.parameters(): + assert params.requires_grad + + blocks_count += 1 + learned_round_modules_count += len(block_learned_round_modules) + + assert blocks_count == len(blocks) + assert learned_round_modules_count == num_round_modules + + @pytest.mark.parametrize("learned_round_class, learned_round_loss_class", [(AutoRound, AutoRoundLoss), (AdaRound, AdaRoundLoss)]) + def test_learned_round_iter_layerwise(self, quant_model, learned_round_class, learned_round_loss_class): + # Retrieve blocks from quant model + blocks = [module for module in quant_model.modules() if isinstance(module, QuantWBIOL)] + learned_round = learned_round_class(iters=100) + + # Counters to verify that the generators returns the appropiate number of items + blocks_count = 0 + learned_round_modules_count = 0 + + for (block, block_loss, block_learned_round_modules) in learned_round.learned_round_iterator(blocks): + assert isinstance(block_loss, learned_round_loss_class) + + for learned_round_module in block_learned_round_modules: + for params in learned_round_module.parameters(): + assert params.requires_grad + + blocks_count += 1 + learned_round_modules_count += len(block_learned_round_modules) + + assert blocks_count == len(blocks) + assert learned_round_modules_count == len(blocks) From f60c295bbc5f1e454a73c1b119da43ad04945838 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Wed, 23 Oct 2024 17:29:00 +0100 Subject: [PATCH 04/48] Remove legacy code --- .../ptq/learned_round_utils.py | 148 +--------- .../imagenet_classification/ptq/ptq_common.py | 261 +----------------- .../ptq/ptq_evaluate.py | 32 +-- 3 files changed, 9 insertions(+), 432 deletions(-) diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 45a77fb4f..607534bb6 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -173,7 +173,8 @@ def __init__(self, iters: int = 100) -> None: def _insert_learned_round_quantizer(self, block: nn.Module) -> None: for module in block.modules(): - if isinstance(module, QuantWBIOL) and not find_learned_round_module(module): + if isinstance(module, QuantWBIOL) and len( + self._find_learned_round_modules(module)) == 0: self._insert_learned_round_quantizer_to_layer(module) module.weight_quant.init_tensor_quant(preserve_state_dict=True) @@ -266,151 +267,6 @@ def _instantiate_loss( return AutoRoundLoss() -# TODO: Remove after validation -class Loss: - - def __init__( - self, - module, - learned_round_module, - weight=0.01, - max_count=1000, - b_range=(20, 2), - warmup=0.2, - decay_start=0.0): - self.weight = weight - self.module = module - self.loss_start = max_count * warmup - self.temp_decay = LinearTempDecay( - max_count, - start_b=b_range[0], - end_b=b_range[1], - rel_start_decay=warmup + (1.0 - warmup) * decay_start) - self.iter = 0 - self.learned_round_module = learned_round_module - - def __call__(self, pred, tgt): - self.iter += 1 - - rec_loss = F.mse_loss(pred, tgt, reduction='none').sum(1).mean() - - if self.iter < self.loss_start: - b = self.temp_decay(self.iter) - round_loss = 0 - else: # 1 - |(h-0.5)*2|**b - b = self.temp_decay(self.iter) - round_vals = self.learned_round_module.p_forward() - round_loss = self.weight * (1 - ((round_vals - 0.5).abs() * 2).pow(b)).sum() - - total_loss = rec_loss + round_loss - return total_loss, rec_loss, round_loss, b - - -# TODO: Remove after validation -def find_learned_round_module(module): - for submodule in module.modules(): - if isinstance(submodule, LearnedRoundSte): - return submodule - return False - - -# TODO: Remove after validation -def insert_learned_round_quantizer(layer, learned_round_zeta=1.1, learned_round_gamma=-0.1): - if isinstance(layer, QuantWBIOL): - if not find_learned_round_module(layer): - floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale) - delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight - value = -torch.log((learned_round_zeta - learned_round_gamma) / - (delta - learned_round_gamma) - 1) - layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( - float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND, - learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID, - learned_round_gamma=learned_round_gamma, - learned_round_zeta=learned_round_zeta, - learned_round_init=value) - layer.weight_quant.init_tensor_quant(preserve_state_dict=True) - - -# TODO: Remove after validation -def split_layers(model, layers): - for module in model.children(): - if isinstance(module, QuantWBIOL): - layers.append(module) - else: - split_layers(module, layers) - - -# TODO: Remove after validation -def insert_auto_round_quantizer_block_layers(block: nn.Module) -> None: - # Iterate over the layers and insert AutoRound quantizer in QuantWBIOL layers - for module in block.modules(): - if isinstance(module, QuantWBIOL): - value = torch.zeros_like(module.weight.data) - module.weight_quant.quant_injector = module.weight_quant.quant_injector.let( - float_to_int_impl_type=FloatToIntImplType.AUTO_ROUND, - learned_round_init=value, - ) - module.weight_quant.init_tensor_quant(preserve_state_dict=True) - - -# TODO: Remove after validation -def find_round_modules_in_block(block: nn.Module, - class_round_quant=AutoRoundSte) -> List[nn.Module]: - round_modules = [] - for module in block.modules(): - if isinstance(module, class_round_quant): - round_modules.append(module) - return round_modules - - -# Remove after validation -def auto_round_iterator(blocks: List[nn.Module]): - for block in blocks: - # Insert AutoRound quantizer in the block layers - insert_auto_round_quantizer_block_layers(block) - # Freeze block parameters - for params in block.parameters(): - params.requires_grad = False - # Retrieve learned round modules - learned_round_modules = find_round_modules_in_block(block, AutoRoundSte) - # Enable gradient tracking in learned round modules - for round_module in learned_round_modules: - for params in round_module.parameters(): - params.requires_grad = True - - yield block, learned_round_modules - block.eval() - - -# TODO: Remove after validation -def learned_round_iterator(layers, iters=1000): - for layer in layers: - insert_learned_round_quantizer(layer) - - for p in layer.parameters(): - p.requires_grad = False - - learned_round_module = find_learned_round_module(layer) - learned_round_module.value.requires_grad = True - layer_loss = Loss(module=layer, learned_round_module=learned_round_module, max_count=iters) - yield layer, layer_loss, learned_round_module - layer.eval() - - -# TODO: Remove, fast-experimentation code -def auto_round_layerwise_iterator(layers): - for layer in layers: - insert_auto_round_quantizer_block_layers(layer) - - for p in layer.parameters(): - p.requires_grad = False - - learned_round_module = find_round_modules_in_block(layer)[0] - learned_round_module.value.requires_grad = True - yield layer, learned_round_module - layer.eval() - - def save_inp_out_data( model, module, diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 24542f40d..45541fd51 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -4,13 +4,15 @@ from functools import partial import itertools import math +import re from typing import Callable, List from warnings import warn -from accelerate.utils.operations import send_to_device import torch from torch import nn import torch.backends.cudnn as cudnn +from torch.optim.optimizer import Optimizer +from torch.utils.data.dataloader import DataLoader from tqdm import tqdm from brevitas.core.function_wrapper.shape import OverBatchOverTensorView @@ -85,11 +87,8 @@ from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AdaRound -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import auto_round_iterator -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import learned_round_iterator from brevitas_examples.imagenet_classification.ptq.learned_round_utils import LearnedRound from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import split_layers # Every element of the Batch will have its own scale factor and zero point @@ -660,81 +659,7 @@ def apply_gpfq( gpfq.update() -# TODO: Remove after debugging -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import \ - auto_round_layerwise_iterator - - -def apply_auto_round_learning_layerwise( - model, dataloader, device, optimizer_class=torch.optim.Adam, iters=1000, optimizer_lr=1e-1): - layers = [] - split_layers(model, layers) - print(f"Total Iterations per layer {iters}") - print(f"Number of layers {len(layers)}") - - for layer_idx, (layer, - learned_round_module) in enumerate(auto_round_layerwise_iterator(layers)): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - # Start measuring time - start.record() - - optimizer = optimizer_class(learned_round_module.parameters(), lr=optimizer_lr) - _, all_fp_out = save_inp_out_data(model, layer, dataloader, store_inp=False, store_out=True, keep_gpu=True, disable_quant=True) - all_quant_inp, _ = save_inp_out_data(model, layer, dataloader, store_inp=True, store_out=True, keep_gpu=True, disable_quant=False) - max_size = len(all_fp_out) - mse_loss = nn.MSELoss() - pbar = tqdm(range(iters), desc='') - for _ in pbar: - idx = torch.randint(0, max_size, (dataloader.batch_size,)) - quant_inp, fp_out = all_quant_inp[idx], all_fp_out[idx] - layer.train() - - optimizer.zero_grad() - quant_out = layer(quant_inp) - loss = mse_loss(quant_out, fp_out) - - loss.backward() - optimizer.step() - # Update progress bar - pbar.set_description( - "block = {:d}/{:d}, loss = {:.4f}".format(layer_idx + 1, len(layers), loss)) - pbar.update(1) - - -def apply_learned_round_learning( - model, dataloader, optimizer_class=torch.optim.Adam, iters=1000, optimizer_lr=1e-1): - layers = [] - split_layers(model, layers) - print(f"Total Iterations per layer {iters}") - print(f"Number of layers {len(layers)}") - - for layer_idx, (layer, layer_loss, - learned_round_module) in enumerate(learned_round_iterator(layers, iters=iters)): - optimizer = optimizer_class(learned_round_module.parameters(), lr=optimizer_lr) - _, all_fp_out = save_inp_out_data(model, layer, dataloader, store_inp=False, store_out=True, keep_gpu=True, disable_quant=True) - all_quant_inp, _ = save_inp_out_data(model, layer, dataloader, store_inp=True, store_out=True, keep_gpu=True, disable_quant=False) - max_size = len(all_fp_out) - pbar = tqdm(range(iters), desc='') - for i in pbar: - idx = torch.randint(0, max_size, (dataloader.batch_size,)) - quant_inp, fp_out = all_quant_inp[idx], all_fp_out[idx] - layer.train() - - optimizer.zero_grad() - quant_out = layer(quant_inp) - loss, rec_loss, round_loss, b = layer_loss(quant_out, fp_out) - - loss.backward() - optimizer.step() - pbar.set_description( - "loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format( - loss, rec_loss, round_loss, b)) - - -# TODO: Replace by an actual function. Remove def _is_resnet_block(module: nn.Module, module_name: str) -> bool: - import re return (re.search(r"layer\d+", module_name) is not None) @@ -742,11 +667,11 @@ def _is_layer(module: nn.Module, module_name: str) -> bool: return isinstance(module, QuantWBIOL) -def apply_learned_round_learning_generalized( +def apply_learned_round_learning( model: nn.Module, - dataloader: torch.utils.data.dataloader.DataLoader, + dataloader: DataLoader, learned_round: LearnedRound = AdaRound, - optimizer_class=torch.optim.Adam, + optimizer_class: Optimizer = torch.optim.Adam, iters: int = 1000, optimizer_lr: float = 1e-1, block_check_fn: Callable = _is_layer, @@ -807,180 +732,6 @@ def _get_blocks(module: nn.Module): return blocks -def apply_auto_round_learning( - model, dataloader, device, optimizer_class=SignSGD, iters=1000, optimizer_lr=1e-1): - # Add message in case the range can be surpassed - if optimizer_class == SignSGD and iters * optimizer_lr > 0.5: - warn("It is possible that the weights are not rounded to their floor or ceil.") - - blocks = get_blocks(model, _is_resnet_block) - print(f"Total Iterations per block {iters}") - print(f"Number of blocks {len(blocks)}") - - # The following code reuses most of the code in auto_learned_round_learning. This implementation - # requires a forward pass through the whole net to capture inputs/outputs of a single block, - # while for auto_learned_round_learning_efficient a single forward is required. - - # Iterate over the blocks, keeping track of outputs to pass them to next block - for block_idx, (block, learned_round_modules) in enumerate(auto_round_iterator(blocks)): - # Instantiate optimizer with the parameters of the learned round modules - optimizer = optimizer_class( - itertools.chain( - *[ - learned_round_module.parameters() - for learned_round_module in learned_round_modules]), - lr=optimizer_lr) - # Use MSE loss to measure the discrepancy between quantised and unquantised outputs - mse_loss_fn = nn.MSELoss() - # Save inputs and outputs - _, all_fp_out = save_inp_out_data(model, block, dataloader, store_inp=False, store_out=True, keep_gpu=True, disable_quant=True) - all_quant_inp, _ = save_inp_out_data(model, block, dataloader, store_inp=True, store_out=True, keep_gpu=True, disable_quant=False) - - max_size = len(all_fp_out) - with tqdm(total=iters) as pbar: - for _ in range(iters): - idx = torch.randint(0, max_size, (dataloader.batch_size,)) - quant_inp, fp_out = all_quant_inp[idx], all_fp_out[idx] - block.train() - - optimizer.zero_grad() - quant_out = block(quant_inp) - loss = mse_loss_fn(quant_out, fp_out) - - loss.backward() - optimizer.step() - # Update progress bar - pbar.set_description( - "block = {:d}/{:d}, loss = {:.4f}".format(block_idx + 1, len(blocks), loss)) - pbar.update(1) - - -# TODO: Investigate performance drop -def apply_auto_round_learning_efficient( - model, dataloader, device, optimizer_class=SignSGD, iters=1000, optimizer_lr=1e-1): - # Add message in case the range can be surpassed - if optimizer_class == SignSGD and iters * optimizer_lr > 0.5: - warn("It is possible that the weights are not rounded to their floor or ceil.") - - blocks = get_blocks(model, _is_resnet_block) - print(f"Total Iterations per block {iters}") - print(f"Number of blocks {len(blocks)}") - - # NOTE: Note that we are storing the output for each batch, thus - # resulting in a memory cost proportional to the number of samples - # in the calibration set. It might be desirable to consider - # alternatives that enable rounding in a batch-wise manner. - cached_args, cached_kwargs = [], [] - - # Method to intercept the input to the first block - def intercept_input(module: nn.Module, args, kwargs): - args = send_to_device(args, 'cpu') - kwargs = send_to_device(kwargs, 'cpu') - cached_args.append(args) - cached_kwargs.append(kwargs) - raise StopFwdException - - # Method to intercept output of each block - def intercept_output(module: nn.Module, args, kwargs, output, cache: List): - if isinstance(output, tuple): - output = output[0] - output = send_to_device(output, 'cpu') - cache.append((output,)) - raise StopFwdException - - # Assumptions of the following code: - # 1: The order in the list blocks corresponds to the order in which - # each block forward is run, when the model forward is executed. - # 2: The input to each block is the output of the previous block. - - # Disable quantisation for retrieving FP inputs - # TODO: Check value for call_act_quantizer_impl - toggle_quant_inference = DisableEnableQuantization(call_act_quantizer_impl=False) - toggle_quant_inference.disable_param_quantization(model, is_training=True) - toggle_quant_inference.disable_bias_quantization(model, is_training=True) - return_model_quant_tensor_state = disable_return_quant_tensor(model) - - # Capture inputs to the first block and store them in cached_args, cached_kwargs - hook = blocks[0].register_forward_pre_hook(intercept_input, with_kwargs=True) - with torch.no_grad(): - for img_batch, _ in dataloader: - try: - img_batch = img_batch.to(device) - model(img_batch) - except StopFwdException: - pass - hook.remove() - - # Enable quantisation for consistency - toggle_quant_inference.enable_param_quantization(model, is_training=True) - toggle_quant_inference.enable_bias_quantization(model, is_training=True) - restore_return_quant_tensor(model, return_model_quant_tensor_state) - - # Iterate over the blocks, keeping track of outputs to pass them to next block - for block_idx, (block, learned_round_modules) in enumerate(auto_round_iterator(blocks)): - # Instantiate optimizer with the parameters of the learned round modules - optimizer = optimizer_class( - itertools.chain( - *[ - learned_round_module.parameters() - for learned_round_module in learned_round_modules]), - lr=optimizer_lr) - # Use MSE loss to measure the discrepancy between quantised and unquantised outputs - mse_loss_fn = nn.MSELoss() - # Prevent needing to perform a deep copy - past_cached_args = cached_args - cached_args = [] - # Process each batch storing outputs - hook = block.register_forward_hook( - partial(intercept_output, cache=cached_args), with_kwargs=True) - # Disable quantisation to get FP outputs - toggle_quant_inference.disable_param_quantization(block, is_training=True) - toggle_quant_inference.disable_bias_quantization(block, is_training=True) - return_block_quant_tensor_state = disable_return_quant_tensor(block) - # Retrieve the FP outputs - with torch.no_grad(): - for args, kwargs in zip(past_cached_args, cached_kwargs): - try: - args = send_to_device(args, device) - kwargs = send_to_device(kwargs, device) - block(*args, **kwargs) - except StopFwdException: - pass - hook.remove() - # Enable quantisation to get Quant outputs - toggle_quant_inference.enable_param_quantization(block, is_training=True) - toggle_quant_inference.enable_bias_quantization(block, is_training=True) - restore_return_quant_tensor(block, return_block_quant_tensor_state) - - # Concatenate inputs and outputs along the batch dimension - past_cached_args_opt = tuple( - torch.cat(tensors, dim=0) for tensors in zip(*past_cached_args)) - cached_args_opt = tuple(torch.cat(tensors, dim=0) for tensors in zip(*cached_args)) - - # TODO: Verify with the implementation of AutoRound - with tqdm(total=iters) as pbar: - block.train() - for _ in range(iters): - # Subsample in calibration subset - idx = torch.randint(0, cached_args_opt[0].shape[0], (dataloader.batch_size,)) - fp_input, fp_out = past_cached_args_opt[0][idx], cached_args_opt[0][idx] - # Move tensor to appropiate devices - fp_input = send_to_device(fp_input, device) - fp_out = send_to_device(fp_out, device) - - quant_out = block(fp_input) - loss = mse_loss_fn(fp_out, quant_out) - - loss.backward() - optimizer.step() - optimizer.zero_grad() - - # Update progress bar - pbar.set_description( - "block = {:d}/{:d}, loss = {:.4f}".format(block_idx + 1, len(blocks), loss)) - pbar.update(1) - - def check_positive_int(*args): """ We check that every inputted value is positive, and an integer. diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 59d53c336..afd37cee6 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -27,17 +27,10 @@ from brevitas_examples.imagenet_classification.ptq.ptq_common import _is_layer from brevitas_examples.imagenet_classification.ptq.ptq_common import _is_resnet_block from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization -from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_auto_round_learning -from brevitas_examples.imagenet_classification.ptq.ptq_common import \ - apply_auto_round_learning_efficient -from brevitas_examples.imagenet_classification.ptq.ptq_common import \ - apply_auto_round_learning_layerwise from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning -from brevitas_examples.imagenet_classification.ptq.ptq_common import \ - apply_learned_round_learning_generalized from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate_bn from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model @@ -525,8 +518,7 @@ def main(): elif args.learned_round_type == "ada_round": learned_round = AdaRound(iters=args.learned_round_iters) - """ - apply_learned_round_learning_generalized( + apply_learned_round_learning( model=quant_model, dataloader=calib_loader, learned_round=learned_round, @@ -535,28 +527,6 @@ def main(): optimizer_lr=args.learned_round_lr, block_check_fn=block_check_fn ) - """ - - # TODO: Remove after validation - - if args.learned_round_type == "auto_round": - print("Applying Auto Round:") - apply_auto_round_learning( - quant_model, - calib_loader, - device, - optimizer_class=optimizer_class, - iters=args.learned_round_iters, - optimizer_lr=args.learned_round_lr) - - if args.learned_round_type == "ada_round": - print("Applying Learned Round (AdaRound):") - apply_learned_round_learning( - quant_model, - calib_loader, - optimizer_class=optimizer_class, - iters=args.learned_round_iters, - optimizer_lr=args.learned_round_lr) if args.calibrate_bn: print("Calibrate BN:") From d3d1d553ab1b0fa6faebf99c33799857af4457fa Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 24 Oct 2024 13:58:20 +0100 Subject: [PATCH 05/48] LLM learned round --- .../core/function_wrapper/auto_round.py | 2 +- .../ptq/learned_round_utils.py | 204 ++++++++ .../imagenet_classification/ptq/ptq_common.py | 22 +- .../llm/benchmark/llm_benchmark.py | 466 ++++++++++++++++++ .../llm/benchmark/parallel.sh | 1 + .../llm/benchmark/post_processing.py | 32 ++ src/brevitas_examples/llm/main.py | 12 + .../test_learned_round_utils.py | 93 +++- 8 files changed, 784 insertions(+), 48 deletions(-) create mode 100644 src/brevitas_examples/llm/benchmark/llm_benchmark.py create mode 100644 src/brevitas_examples/llm/benchmark/parallel.sh create mode 100644 src/brevitas_examples/llm/benchmark/post_processing.py diff --git a/src/brevitas/core/function_wrapper/auto_round.py b/src/brevitas/core/function_wrapper/auto_round.py index 7e7f40d6b..7f1688291 100644 --- a/src/brevitas/core/function_wrapper/auto_round.py +++ b/src/brevitas/core/function_wrapper/auto_round.py @@ -36,7 +36,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # p should be between [-0.5, 0.5], so this learnable parameter decides whether to "ceil" or "floor" p = self.value p = self.tensor_slicer(p) - return round_ste(x + p.to(x.dtype)) + return round_ste(x + (p.to(x.dtype)).view_as(x)) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 607534bb6..b20e68397 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -30,6 +30,7 @@ from abc import abstractmethod from typing import Callable, Generator, List, Optional, Tuple +from accelerate.utils.operations import send_to_device import numpy as np import torch from torch import nn @@ -310,3 +311,206 @@ def save_inp_out_data( disable_quant_class.enable_param_quantization(model, False) restore_return_quant_tensor(model, return_quant_tensor_state) return cached + + +class DataSaverHookLLM: + + def __init__( + self, + cache_args: List, + cache_kwargs: List, + cache_outs: List, + store_args: bool = True, + store_kwargs: bool = True, + store_outs: bool = True, + keep_gpu: bool = True): + self.cache_args = cache_args + self.cache_kwargs = cache_kwargs + self.cache_outs = cache_outs + + self.store_args = store_args + self.store_kwargs = store_kwargs + self.store_outs = store_outs + + self.keep_gpu = keep_gpu + + def __call__(self, module, args, kwargs, output): + # NOTE: If args/kwargs are QuantTensors, should include logic to unpack their values + if isinstance(output, (tuple, list)): + output = output[0] + + # Store each element in the appropiate cache + for element_to_cache, should_cache, cache in zip( + [args, kwargs, output], + [self.store_args, self.store_kwargs, self.store_outs], + [self.cache_args, self.cache_kwargs, self.cache_outs] + ): + if should_cache: + if not self.keep_gpu: + element_to_cache = send_to_device(element_to_cache, 'cpu') + cache.append(element_to_cache) + + raise StopFwdException + + +def save_inp_out_data_llm( + model, + module, + dataloader: torch.utils.data.DataLoader, + cache_args: List, + cache_kwargs: List, + cache_outs: List, + store_args: bool = True, + store_kwargs: bool = False, + store_outs: bool = True, + keep_gpu: bool = True, + disable_quant=False) -> None: + if disable_quant: + disable_quant_class = DisableEnableQuantization() + disable_quant_class.disable_act_quantization(model, False) + disable_quant_class.disable_param_quantization(model, False) + return_quant_tensor_state = disable_return_quant_tensor(model) + + device = next(model.parameters()).device + data_saver = DataSaverHookLLM( + cache_args, cache_kwargs, cache_outs, store_args, store_kwargs, store_outs, keep_gpu) + handle = module.register_forward_hook(data_saver, with_kwargs=True) + with torch.no_grad(): + for inps in dataloader: + try: + inps = send_to_device(inps, device) + model(**inps) + except StopFwdException: + pass + handle.remove() + if disable_quant: + disable_quant_class.enable_act_quantization(model, False) + disable_quant_class.enable_param_quantization(model, False) + restore_return_quant_tensor(model, return_quant_tensor_state) + + +# TODO: Move imports to their appropiate place +import itertools + +from torch.optim.optimizer import Optimizer +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from transformers.models.opt.modeling_opt import OPTDecoderLayer + +from brevitas.optim.sign_sgd import SignSGD + + +def get_blocks(model: nn.Module, block_check_fn: Callable[[nn.Module, str], + bool]) -> List[nn.Module]: + blocks = [] + + # Iterating over .modules() might have been more readable but + # with this recursive implementation, once a block is reached, + # its subtree of modules is not expanded. + def _get_blocks(module: nn.Module): + for module_name, module_child in module.named_children(): + if block_check_fn(module_child, module_name): + blocks.append(module_child) + else: + _get_blocks(module_child) + + # Run recursive function that updates the list blocks + _get_blocks(model) + return blocks + + +def _is_block_llm(module: nn.Module, module_name: str) -> bool: + return isinstance(module, LlamaDecoderLayer) or isinstance(module, OPTDecoderLayer) + + +def apply_learned_round_learning_llm( + model: nn.Module, + dataloader: DataLoader, + learned_round: LearnedRound = AutoRound(iters=100), + optimizer_class: Optimizer = SignSGD, + iters: int = 100, + optimizer_lr: float = 5e-3, + block_check_fn: Callable = _is_block_llm, +): + # Disable the cache to prevent memory buildup + cache_state = model.config.use_cache + model.config.use_cache = False + # NOTE: Can be problematic is more than one GPU is used. + device = next(model.parameters()).device + # Retrieve blocks using the appropiate function to check blocks + blocks = get_blocks(model, block_check_fn) + + print(f"Total Iterations per block {iters}") + print(f"Number of blocks {len(blocks)}") + + cache_args, cache_kwargs, cache_outs = [], [], [] + + for block_idx, (block, block_loss, block_learned_round_modules) in enumerate( + learned_round.learned_round_iterator(blocks)): + optimizer = optimizer = optimizer_class( + itertools.chain( + *[ + learned_round_module.parameters() + for learned_round_module in block_learned_round_modules]), + lr=optimizer_lr) + # Cache needs to be cleaned between blocks. No need to clear the + # kwargs cache, as this is only updates for the first block. + cache_args = [] + cache_outs = [] + # Save FP output + save_inp_out_data_llm( + model, + block, + dataloader, + cache_args, + cache_kwargs, + cache_outs, + store_args=False, + store_kwargs=False, + store_outs=True, + keep_gpu=True, + disable_quant=True) + # Save Quant input + save_inp_out_data_llm( + model, + block, + dataloader, + cache_args, + cache_kwargs, + cache_outs, + store_args=True, + store_kwargs=len(cache_kwargs) == 0, + store_outs=False, + keep_gpu=True, + disable_quant=False) + + pbar = tqdm(range(iters), desc='') + for _ in pbar: + idx = torch.randint(0, len(cache_args), (1,)) + args, kwargs, fp_out = cache_args[idx], cache_kwargs[idx], cache_outs[idx] + block.train() + + optimizer.zero_grad() + + args = send_to_device(args, device) + kwargs = send_to_device(kwargs, device) + fp_out = send_to_device(fp_out, device) + + quant_out = block(*args, **kwargs) + if isinstance(quant_out, tuple): + quant_out = quant_out[0] + + loss, loss_components = block_loss(quant_out, fp_out) + + loss.backward() + optimizer.step() + # Update progress bar + pbar.set_description( + "block = {:d}/{:d}, {}".format( + block_idx + 1, len(blocks), + block_loss.format_loss_components(*loss_components))) + pbar.update(1) + + # Restore cache state + model.config.use_cache = cache_state diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 45541fd51..4a6161279 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -87,6 +87,7 @@ from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AdaRound +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import get_blocks from brevitas_examples.imagenet_classification.ptq.learned_round_utils import LearnedRound from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data @@ -670,7 +671,7 @@ def _is_layer(module: nn.Module, module_name: str) -> bool: def apply_learned_round_learning( model: nn.Module, dataloader: DataLoader, - learned_round: LearnedRound = AdaRound, + learned_round: LearnedRound = AdaRound(iters=1000), optimizer_class: Optimizer = torch.optim.Adam, iters: int = 1000, optimizer_lr: float = 1e-1, @@ -713,25 +714,6 @@ def apply_learned_round_learning( pbar.update(1) -def get_blocks(model: nn.Module, block_check_fn: Callable[[nn.Module, str], - bool]) -> List[nn.Module]: - blocks = [] - - # Iterating over .modules() might have been more readable but - # with this recursive implementation, once a block is reached, - # its subtree of modules is not expanded. - def _get_blocks(module: nn.Module): - for module_name, module_child in module.named_children(): - if block_check_fn(module_child, module_name): - blocks.append(module_child) - else: - _get_blocks(module_child) - - # Run recursive function that updates the list blocks - _get_blocks(model) - return blocks - - def check_positive_int(*args): """ We check that every inputted value is positive, and an integer. diff --git a/src/brevitas_examples/llm/benchmark/llm_benchmark.py b/src/brevitas_examples/llm/benchmark/llm_benchmark.py new file mode 100644 index 000000000..711aa7dff --- /dev/null +++ b/src/brevitas_examples/llm/benchmark/llm_benchmark.py @@ -0,0 +1,466 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import argparse +from functools import partial +from itertools import product +import os +import random +from types import SimpleNamespace + +import numpy as np +from optimum.amd.brevitas.accelerate_utils import offload_model +from optimum.amd.brevitas.accelerate_utils import remove_hooks +from optimum.amd.brevitas.data_utils import compute_perplexity +from optimum.exporters.onnx import onnx_export_from_model +import pandas as pd +import torch +import torch.backends.cudnn as cudnn +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer + +from brevitas import __version__ as brevitas_version +from brevitas import config +from brevitas import torch_version +from brevitas.export import export_torch_qcdq +from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager +from brevitas.graph.quantize import layerwise_quantize +from brevitas_examples.common.generative.quantize import generate_quant_maps +from brevitas_examples.common.generative.quantize import generate_quantizers +from brevitas_examples.common.parse_utils import quant_format_validator +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import \ + apply_learned_round_learning_llm +from brevitas_examples.imagenet_classification.ptq.utils import get_gpu_index +from brevitas_examples.imagenet_classification.ptq.utils import get_next_available_gpu +from brevitas_examples.imagenet_classification.utils import SEED +from brevitas_examples.imagenet_classification.utils import validate +from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction +from brevitas_examples.llm.llm_quant.calibrate import apply_calibration +from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model +from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization +from brevitas_examples.llm.llm_quant.equalize import apply_weight_equalization +from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager +from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode +from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq +from brevitas_examples.llm.llm_quant.gpxq import apply_gptq +from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge +from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear +from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers +from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 +from brevitas_examples.llm.llm_quant.run_utils import get_fx + +config.IGNORE_MISSING_KEYS = True + + +def parse_type(v, default_type): + if v == 'None': + return None + else: + return default_type(v) + + +def parse_bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y'): + return True + elif v.lower() in ('no', 'false', 'f', 'n'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +class hashabledict(dict): + + def __hash__(self): + return hash(tuple(sorted(self.items()))) + + +def unique(sequence): + seen = set() + return [x for x in sequence if not (x in seen or seen.add(x))] + + +# Torchvision models with top1 accuracy +LLM_TOP1_MAP = { + 'facebook/opt-125m': None, + 'meta-llama/Llama-2-7b-hf': None,} + +OPTIONS_DEFAULT = { + 'model': list(LLM_TOP1_MAP.keys()), # HF model name. Default: facebook/opt-125m. + 'seed': [0], # Seed for sampling the calibration data. Default: 0. + 'nsamples': [128], # Number of calibration data samples. Default: 128. + 'seqlen': [2048], # Sequence length. Default: 2048. + 'eval': [True], # Eval model PPL on the chosen Dataset. + 'dataset': ['c4'], # Dataset to use for quantization (default: wikitext2) + 'weight_bit_width': [8], # Weight bit width. Default: 8. + 'weight_param_method': ['stats'], # How scales/zero-point are determined. Default: stats. + 'weight_scale_precision': ['float_scale' + ], # Whether scale is a float value or a po2. Default: po2. + 'weight_quant_type': ['sym'], # Weight quantization type. Default: asym. + 'weight_quant_format': ['int'], # Weight quantization type. Default: int. + 'weight_quant_granularity': [ + 'per_group'], # Granularity for scales/zero-point of weights. Default: per_group. + 'weight_group_dim': [ + None], # Override default group_dim for groupsize quantization. Default: layer-dependant + 'weight_group_size': [128], # Group size for per_group weight quantization. Default: 128. + 'quantize_weight_zero_point': [False], # Quantize weight zero-point. + 'input_bit_width': [None], # Input bit width. Default: None (disables input quantization). + 'input_quant_format': ['int'], # Input quantization type. Default: int. + 'input_param_method': ['stats'], # How scales/zero-point are determined. Default: stats. + 'input_scale_precision': ['float_scale' + ], # Whether input scale is a float value or a po2. Default: float. + 'input_scale_type': ['static'], # Whether input scale is a static value or a dynamic value. + 'input_quant_type': ['asym'], # Input quantization type. Default: asym. + 'input_quant_granularity': [ + 'per_tensor'], # Granularity for scales/zero-point of inputs. Default: per_tensor. + 'input_group_size': [64], # Group size for per_group input quantization. Default: 64. + 'quantize_input_zero_point': [False], # Quantize input zero-point. + 'quantize_last_layer': [False], # Quantize last nn.Linear layer. + 'gptq': [False], # Apply GPTQ. + 'gpfq': [False], # Apply GPFQ. + 'gpxq_act_order': [False], # Apply GPXQ activation ordering. + 'gpxq_use_quant_activations': [False], # Use quantized activations in GPXQ. + 'gpxq_create_weight_orig': [False], # Create weight_orig in GPXQ. + 'act_calibration': [False], # Apply activation calibration. + 'bias_corr': [False], # Apply bias correction. + 'ln_affine_merge': [False], # Merge LN affine params. + 'no_quantize': [False], # Disable quantization. + 'no_float16': [False], # Disable float16 as base datatype and switch to float32. + 'replace_mha': [False], # Replace HuggingFace Attention with a quantizable version + 'weight_equalization': [ + False], # Apply weight equalization. Relevant to ReLU based models (e.g. OPT). + 'act_equalization': [None], # Apply activation equalization (SmoothQuant). + 'load_awq': [None], # Load the awq search results. + 'export_target': [None], # Model export. + 'export_prefix': [None], # Path prefix to use for the various export flows. + 'checkpoint_name': [None], # Filename to save checkpoint. + 'fuse_sequences': [False], # Whether to merge the dataset sequences. + 'learned_round': [None, "auto_round"] # Whether to use learned round. If `None`, RTN is used. +} + +parser = argparse.ArgumentParser(description='PyTorch LLM PTQ Validation') +parser.add_argument('idx', type=int) +for option_name, option_value in OPTIONS_DEFAULT.items(): + if isinstance(option_value[0], bool): + type_args = parse_bool + else: + type_args = partial(parse_type, default_type=type(option_value[0])) + parser.add_argument(f'--{option_name}', default=option_value, nargs="+", type=type_args) + + +def main(): + args = parser.parse_args() + random.seed(SEED) + np.random.seed(SEED) + torch.manual_seed(SEED) + + args.gpu = get_gpu_index(args.idx) + print("Iter {}, GPU {}".format(args.idx, args.gpu)) + + try: + ptq_llm_models(args) + except Exception as E: + print("Exception at index {}: {}".format(args.idx, E)) + + +def ptq_llm_models(args): + # Generate all possible combinations, including invalid ones + + options = {k: getattr(args, k) for k, _ in OPTIONS_DEFAULT.items()} + + combinations = list(product(*options.values())) + + configs = [] + for combination in combinations: + config_namespace = SimpleNamespace( + **{k: v for k, v in zip(OPTIONS_DEFAULT.keys(), combination)}) + config_namespace = validate_config(config_namespace) + if config_namespace.is_valid: + configs.append(hashabledict(**config_namespace.__dict__)) + + configs = unique(configs) + + if args.idx > len(configs) - 1: + return + + config_namespace = SimpleNamespace(**configs[args.idx]) + print(config_namespace) + + if config_namespace.export_prefix is None: + config_namespace.export_prefix = f"{config_namespace.model.replace('/', '--')}" + + if config_namespace.no_float16: + dtype = torch.float32 + else: + dtype = torch.float16 + + kwargs = {"torch_dtype": dtype} + + if config_namespace.export_target == 'torch_qcdq': + kwargs['torchscript'] = True + + print("Model loading...") + model = AutoModelForCausalLM.from_pretrained(config_namespace.model, **kwargs) + print("Model loaded.") + model.eval() + tokenizer = AutoTokenizer.from_pretrained(config_namespace.model) + float_ppl = None + quant_ppl = None + + if config_namespace.load_awq: + from brevitas_examples.llm.llm_quant.awq.pre_quant import apply_awq + awq_results = torch.load(config_namespace.load_awq, map_location="cpu") + with CastFloat16ToFloat32(): + apply_awq(model, awq_results) + + require_fx = True if config_namespace.weight_equalization or config_namespace.act_equalization == 'fx' or config_namespace.ln_affine_merge else False + + # Load the data for calibration and evaluation. + calibration_loader = get_dataset_for_model( + config_namespace.model, + dataset_name=config_namespace.dataset, + tokenizer=tokenizer, + nsamples=config_namespace.nsamples, + seqlen=config_namespace.seqlen, + split="train", + seed=config_namespace.seed, + require_fx=require_fx, + device=None, + fuse_sequences=config_namespace.fuse_sequences, + ) + + validation_loader = get_dataset_for_model( + config_namespace.model, + dataset_name=config_namespace.dataset, + tokenizer=tokenizer, + nsamples=config_namespace.nsamples, + seqlen=config_namespace.seqlen, + split="validation", + seed=config_namespace.seed, + require_fx=require_fx, + device=None, + fuse_sequences=config_namespace.fuse_sequences, + ) + + device = next(iter(model.parameters())).device + print("Data loaded.") + + if config_namespace.eval: + assert config_namespace.export_target != 'torch_qcdq', "TorchScript QCDQ export and Evaluation simultaneously" + print("Float model eval...") + model = offload_model(model) + float_ppl = compute_perplexity( + model, + validation_loader, + context_length=config_namespace.seqlen // 2, + tokenizer=tokenizer) + remove_hooks(model) + print(f"Float perplexity ({config_namespace.dataset}): {float_ppl:.3f}") + + if require_fx: + model = get_fx(model) + + # Apply LN affine merging before inserting MHA layers + # since currently there is support only for merging into Linear + if config_namespace.ln_affine_merge: + print("Apply LN affine merge...") + apply_layernorm_affine_merge(model, dtype) + print("LN affine merge applied.") + + # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing + # with all the variability in HF implementations + if config_namespace.replace_mha: + print("Replace HF MHA with quantizable variants...") + model = replace_mha_with_quantizable_layers(model, dtype) + print("Replacing done.") + + if config_namespace.weight_equalization: + print("Apply weight equalization...") + # In case of float16 model, we need to offload to account for missing ops + model = offload_model(model) + apply_weight_equalization(model) + remove_hooks(model) + print("Weight equalization applied.") + + if config_namespace.act_equalization is not None: + offload_model(model) + print("Apply act equalization (SmoothQuant)...") + apply_act_equalization(model, config_namespace.act_equalization, calibration_loader) + print("Act equalization applied.") + remove_hooks(model) + + if not config_namespace.no_quantize: + name_blacklist = [] + print("Applying model quantization...") + linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant = generate_quantizers( + dtype=dtype, + weight_bit_width=config_namespace.weight_bit_width, + weight_param_method=config_namespace.weight_param_method, + weight_scale_precision=config_namespace.weight_scale_precision, + weight_quant_type=config_namespace.weight_quant_type, + weight_quant_granularity=config_namespace.weight_quant_granularity, + weight_group_size=config_namespace.weight_group_size, + weight_group_dim=config_namespace.weight_group_dim, + quantize_weight_zero_point=config_namespace.quantize_weight_zero_point, + weight_quant_format=config_namespace.weight_quant_format, + input_bit_width=config_namespace.input_bit_width, + input_quant_format=config_namespace.input_quant_format, + input_scale_precision=config_namespace.input_scale_precision, + input_scale_type=config_namespace.input_scale_type, + input_param_method=config_namespace.input_param_method, + input_quant_type=config_namespace.input_quant_type, + input_quant_granularity=config_namespace.input_quant_granularity, + input_group_size=config_namespace.input_group_size, + quantize_input_zero_point=config_namespace.quantize_input_zero_point, + device=device) + layer_map = generate_quant_maps( + linear_input_quant=linear_input_quant, + weight_quant=weight_quant, + input_quant=input_quant, + q_scaled_quant=q_scaled_quant, + k_transposed_quant=k_transposed_quant, + v_quant=v_quant, + attn_output_weights_quant=attn_output_weights_quant, + dtype=dtype, + device=device, + input_quant_format=config_namespace.input_quant_format, + quantize_embedding=False) + if not config_namespace.quantize_last_layer: + name_blacklist += ["lm_head", "embed_out"] + model = layerwise_quantize( + model=model, compute_layer_map=layer_map, name_blacklist=name_blacklist) + # Tie back first/last layer weights in case they got untied + print("Model quantization applied.") + + # If any equalization has taken places, the embedding layer and the fully connected one are + # not tied anymore, and they need to be treated as standalone, separate layers. + # In all other cases we can tie them back so to preserve memory. + if config_namespace.act_equalization is None and not require_fx: + model.tie_weights() + + if config_namespace.bias_corr: + model = add_zero_bias_to_linear(model) + + model = offload_model(model) + + if config_namespace.learned_round: + print("Applying learned round...") + apply_learned_round_learning_llm(model, calibration_loader) + print("Learned round applied.") + + if config_namespace.act_calibration: + print("Apply act calibration...") + apply_calibration(model, calibration_loader) + print("Act calibration applied.") + + if config_namespace.gptq: + print("Applying GPTQ...") + apply_gptq( + model, + calibration_loader, + act_order=config_namespace.gpxq_act_order, + use_quant_activations=config_namespace.gpxq_use_quant_activations, + create_weight_orig=config_namespace.gpxq_create_weight_orig) + print("GPTQ applied.") + + if config_namespace.gpfq: + print("Applying GPFQ...") + apply_gpfq(model, calibration_loader, act_order=config_namespace.gpxq_act_order) + print("GPFQ applied.") + + if config_namespace.bias_corr: + print("Applying bias correction...") + apply_bias_correction(model, calibration_loader) + print("Bias correction applied.") + + if config_namespace.eval: + print("Model eval...") + quant_ppl = compute_perplexity( + model, + validation_loader, + context_length=config_namespace.seqlen // 2, + tokenizer=tokenizer) + print(f"Quantized perplexity ({config_namespace.dataset}): {quant_ppl:.3f}") + remove_hooks(model) + + # Validate the quant_model on the validation dataloader + print("Starting validation") + + column_names = [k.replace('_', ' ').capitalize() for k in config_namespace.__dict__.keys()] + [ + 'FP perplexity', 'Quant perplexity', 'Torch version', 'Brevitas version'] + values = [v for _, v in config_namespace.__dict__.items()] + [ + float_ppl, quant_ppl, torch_version, brevitas_version] + torchvision_df = pd.DataFrame([values], columns=column_names) + + folder = './multirun/' + str(args.idx) + os.makedirs(folder, exist_ok=True) + torchvision_df.to_csv(os.path.join(folder, 'RESULTS_LLM.csv'), index=False) + + +def validate_config(config_namespace): + is_valid = True + + if not config_namespace.no_quantize: + if config_namespace.gptq and config_namespace.gpfq: + is_valid = False + if config_namespace.export_target is not None: + if config_namespace.input_quant_format != 'int': + is_valid = False + if config_namespace.export_target is not None and config_namespace.input_bit_width is not None: + if config_namespace.input_scale_type != 'static': + is_valid = False + if config_namespace.export_target == 'sharded_torchmlir_group_weight': + if config_namespace.weight_quant_granularity != 'per_group': + is_valid = False + if config_namespace.input_bit_width is not None: + is_valid = False + if config_namespace.quantize_weight_zero_point: + is_valid = False + if config_namespace.export_target == 'sharded_packed_torchmlir_group_weight': + if config_namespace.weight_quant_granularity != 'per_group': + is_valid = False + if config_namespace.input_bit_width is not None: + is_valid = False + if config_namespace.quantize_weight_zero_point: + is_valid = False + if config_namespace.export_target == 'onnx_qcdq': + if config_namespace.weight_quant_granularity == 'per_group': + if config_namespace.input_bit_width is not None: + is_valid = False + if config_namespace.weight_quant_type == 'asym': + if not config_namespace.quantize_weight_zero_point: + is_valid = False + if config_namespace.input_bit_width is not None and config_namespace.input_quant_type == 'asym': + if not config_namespace.quantize_input_zero_point: + is_valid = False + if config_namespace.export_target == 'torch_qcdq': + if config_namespace.weight_quant_granularity == 'per_group': + is_valid = False + if config_namespace.weight_quant_type == 'asym': + if not config_namespace.quantize_weight_zero_point: + is_valid = False + if config_namespace.input_bit_width is not None and config_namespace.input_quant_type == 'asym': + if not config_namespace.quantize_input_zero_point: + is_valid = False + if config_namespace.input_bit_width and config_namespace.input_scale_type == 'static': + if not config_namespace.act_calibration: + is_valid = False + if (config_namespace.weight_equalization or config_namespace.act_equalization == 'fx'): + if config_namespace.replace_mha: + if config_namespace.export_target == 'onnx_qcdq': + is_valid = False + else: + if config_namespace.export_target == 'torch_qcdq': + is_valid = False + + config_namespace.is_valid = is_valid + return config_namespace + + +if __name__ == '__main__': + main() diff --git a/src/brevitas_examples/llm/benchmark/parallel.sh b/src/brevitas_examples/llm/benchmark/parallel.sh new file mode 100644 index 000000000..32071014b --- /dev/null +++ b/src/brevitas_examples/llm/benchmark/parallel.sh @@ -0,0 +1 @@ +seq 0 10 | xargs -n1 -P3 -I{} sh -c 'HF_HUB_CACHE=/scratch/hf_models/ python llm_benchmark.py "$@"' _ {} diff --git a/src/brevitas_examples/llm/benchmark/post_processing.py b/src/brevitas_examples/llm/benchmark/post_processing.py new file mode 100644 index 000000000..ab33b15dd --- /dev/null +++ b/src/brevitas_examples/llm/benchmark/post_processing.py @@ -0,0 +1,32 @@ +import os + +import pandas as pd + + +def main(): + main_dir = './multirun' + + evals = next(os.walk(main_dir))[1] + df = None + for eval in evals: + full_path = os.path.join(main_dir, eval, 'RESULTS_LLM.csv') + if not os.path.exists(full_path): + continue + if df is None: + df = pd.read_csv(full_path) + else: + single_df = pd.read_csv(full_path) + df = pd.concat([df, single_df]) + df = df.sort_values(by=list(df.columns)) + df.to_csv('RESULTS_LLM.csv', index=False, mode='w') + + grouped_df = df.groupby([ + 'Model', 'Weight bit width', 'Weight quant granularity', 'Learned round']) + idx = grouped_df['Quant perplexity'].transform(max) == df['Quant perplexity'] + best_config_df = df[idx] + best_config_df = best_config_df.sort_values(by=['Model', 'Quant perplexity']) + best_config_df.to_csv('RESULTS_LLM_BEST_CONFIGS.csv', index=False, mode='w') + + +if __name__ == '__main__': + main() diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 8b3ae4888..93544a1e5 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -24,6 +24,8 @@ from brevitas_examples.common.generative.quantize import generate_quant_maps from brevitas_examples.common.generative.quantize import generate_quantizers from brevitas_examples.common.parse_utils import quant_format_validator +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import \ + apply_learned_round_learning_llm from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction from brevitas_examples.llm.llm_quant.calibrate import apply_calibration from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model @@ -367,6 +369,11 @@ def main(args): with torch.no_grad(): model(**calibration_loader[0]) + if args.learned_round: + print("Applying learned round...") + apply_learned_round_learning_llm(model, calibration_loader) + print("Learned round applied.") + if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader) @@ -658,6 +665,11 @@ def parse_args(args): help= "Whether to merge the dataset sequences in case they are shorter than the requested number of samples per sequence. This is useful in case you would like to quantize or evaluate on long sequences (default: %(default)s).", ) + parser.add_argument( + '--learned-round', + default=None, + choices=[None, 'auto_round'], + help='Whether to use learned round. If `None`, RTN is used (default: %(default)s)') return parser.parse_args(args) diff --git a/tests/brevitas_examples/test_learned_round_utils.py b/tests/brevitas_examples/test_learned_round_utils.py index 6194b7c00..d02585ba2 100644 --- a/tests/brevitas_examples/test_learned_round_utils.py +++ b/tests/brevitas_examples/test_learned_round_utils.py @@ -19,11 +19,12 @@ from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AdaRoundLoss from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AutoRound from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AutoRoundLoss +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import get_blocks from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data -from brevitas_examples.imagenet_classification.ptq.ptq_common import get_blocks config.IGNORE_MISSING_KEYS = True + # TODO: Include some integration test class TestLearnedRound: @@ -31,6 +32,7 @@ class TestLearnedRound: def quant_model(): class QuantBlock(nn.Module): + def __init__(self, in_features: int, hidden_dim: int, out_features: int) -> None: super().__init__() self.layer1 = qnn.QuantLinear(in_features=in_features, out_features=hidden_dim) @@ -44,11 +46,15 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: return self.relu(out) class TestQuantModel(nn.Module): + def __init__(self, in_features: int, out_features: int, hidden_dim: int) -> None: super().__init__() - self.in_proj_mlp = QuantBlock(in_features=in_features, hidden_dim=hidden_dim, out_features=hidden_dim) - self.hidden_mlp = QuantBlock(in_features=hidden_dim, hidden_dim=hidden_dim, out_features=hidden_dim) - self.out_proj_mlp = QuantBlock(in_features=hidden_dim, hidden_dim=hidden_dim, out_features=out_features) + self.in_proj_mlp = QuantBlock( + in_features=in_features, hidden_dim=hidden_dim, out_features=hidden_dim) + self.hidden_mlp = QuantBlock( + in_features=hidden_dim, hidden_dim=hidden_dim, out_features=hidden_dim) + self.out_proj_mlp = QuantBlock( + in_features=hidden_dim, hidden_dim=hidden_dim, out_features=out_features) def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: out = self.in_proj_mlp(x) @@ -61,6 +67,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: def model(): class Block(nn.Module): + def __init__(self, in_features: int, hidden_dim: int, out_features: int) -> None: super().__init__() self.layer1 = nn.Linear(in_features=in_features, out_features=hidden_dim) @@ -74,11 +81,15 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: return self.relu(out) class TestModel(nn.Module): + def __init__(self, in_features: int, out_features: int, hidden_dim: int) -> None: super().__init__() - self.in_proj_mlp = Block(in_features=in_features, hidden_dim=hidden_dim, out_features=hidden_dim) - self.hidden_mlp = Block(in_features=hidden_dim, hidden_dim=hidden_dim, out_features=hidden_dim) - self.out_proj_mlp = Block(in_features=hidden_dim, hidden_dim=hidden_dim, out_features=out_features) + self.in_proj_mlp = Block( + in_features=in_features, hidden_dim=hidden_dim, out_features=hidden_dim) + self.hidden_mlp = Block( + in_features=hidden_dim, hidden_dim=hidden_dim, out_features=hidden_dim) + self.out_proj_mlp = Block( + in_features=hidden_dim, hidden_dim=hidden_dim, out_features=out_features) def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: out = self.in_proj_mlp(x) @@ -91,6 +102,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: def data_loader(): class TestDataset(Dataset): + def __init__(self): self.data = torch.tensor([[1.0, 2.0]]) self.labels = torch.tensor([0]) @@ -119,10 +131,12 @@ def _is_layer(module: nn.Module, module_name: str) -> bool: return isinstance(module, QuantWBIOL) expected_layers = [ - quant_model.in_proj_mlp.layer1, quant_model.in_proj_mlp.layer2, - quant_model.hidden_mlp.layer1, quant_model.hidden_mlp.layer2, - quant_model.out_proj_mlp.layer1, quant_model.out_proj_mlp.layer2 - ] + quant_model.in_proj_mlp.layer1, + quant_model.in_proj_mlp.layer2, + quant_model.hidden_mlp.layer1, + quant_model.hidden_mlp.layer2, + quant_model.out_proj_mlp.layer1, + quant_model.out_proj_mlp.layer2] layers = get_blocks(quant_model, _is_layer) assert expected_layers == layers @@ -136,7 +150,8 @@ def _is_layer(module: nn.Module, module_name: str) -> bool: @pytest.mark.parametrize("store_out", [True, False]) @pytest.mark.parametrize("keep_gpu", [True, False]) @pytest.mark.parametrize("disable_quant", [True, False]) - def test_save_inp_out_data(self, model, quant_model, data_loader, store_input, store_out, keep_gpu, disable_quant): + def test_save_inp_out_data( + self, model, quant_model, data_loader, store_input, store_out, keep_gpu, disable_quant): # Make sure that the quant and FP models share the same weights quant_model.load_state_dict(model.state_dict()) @@ -155,7 +170,8 @@ def test_save_inp_out_data(self, model, quant_model, data_loader, store_input, s cache_fp_partial_input = [] cache_fp_partial_output = [] - def _aux_check_tensors(result_tensor, expected_tensor, keep_gpu, disable_quant, assert_type=False): + def _aux_check_tensors( + result_tensor, expected_tensor, keep_gpu, disable_quant, assert_type=False): # Verify that tensor is of the appropiate type if assert_type: assert isinstance(result_tensor, torch.Tensor if disable_quant else QuantTensor) @@ -201,7 +217,8 @@ def _aux_check_tensors(result_tensor, expected_tensor, keep_gpu, disable_quant, # Verify that empty lists are returned if store_input: if disable_quant: - _aux_check_tensors(input_data, fp_partial_input, keep_gpu, disable_quant, assert_type=True) + _aux_check_tensors( + input_data, fp_partial_input, keep_gpu, disable_quant, assert_type=True) else: _aux_check_tensors(input_data, quant_partial_input, keep_gpu, disable_quant) else: @@ -211,12 +228,16 @@ def _aux_check_tensors(result_tensor, expected_tensor, keep_gpu, disable_quant, if disable_quant: _aux_check_tensors(out_data, fp_partial_output, keep_gpu, disable_quant) else: - _aux_check_tensors(out_data, quant_partial_output, keep_gpu, disable_quant, assert_type=True) + _aux_check_tensors( + out_data, quant_partial_output, keep_gpu, disable_quant, assert_type=True) else: assert len(out_data) == 0 - @pytest.mark.parametrize("learned_round_class, rounding_mode, float_to_int_impl", [(AutoRound, "AUTO_ROUND", AutoRoundSte), (AdaRound, "LEARNED_ROUND", LearnedRoundSte)]) - def test_insert_learned_round_quantizer(self, quant_model, learned_round_class, rounding_mode, float_to_int_impl): + @pytest.mark.parametrize( + "learned_round_class, rounding_mode, float_to_int_impl", + [(AutoRound, "AUTO_ROUND", AutoRoundSte), (AdaRound, "LEARNED_ROUND", LearnedRoundSte)]) + def test_insert_learned_round_quantizer( + self, quant_model, learned_round_class, rounding_mode, float_to_int_impl): block = quant_model.in_proj_mlp learned_round = learned_round_class(iters=100) learned_round._insert_learned_round_quantizer(block) @@ -224,11 +245,15 @@ def test_insert_learned_round_quantizer(self, quant_model, learned_round_class, for module in block.modules(): if hasattr(module, "weight_quant"): assert module.weight_quant.rounding_mode == rounding_mode - assert isinstance(module.weight_quant.tensor_quant.int_quant.float_to_int_impl, float_to_int_impl) + assert isinstance( + module.weight_quant.tensor_quant.int_quant.float_to_int_impl, float_to_int_impl) @pytest.mark.parametrize("learned_round_class", [AutoRound, AdaRound]) - @pytest.mark.parametrize("block_strs, num_round_modules", [([], 0), (["hidden_mlp"], 2), (["in_proj_mlp", "out_proj_mlp"], 4)]) - def test_find_learned_round_modules(self, quant_model, learned_round_class, block_strs, num_round_modules): + @pytest.mark.parametrize( + "block_strs, num_round_modules", [([], 0), (["hidden_mlp"], 2), + (["in_proj_mlp", "out_proj_mlp"], 4)]) + def test_find_learned_round_modules( + self, quant_model, learned_round_class, block_strs, num_round_modules): learned_round = learned_round_class(iters=100) # Inject quantizers in quant model for block_str in block_strs: @@ -237,9 +262,18 @@ def test_find_learned_round_modules(self, quant_model, learned_round_class, bloc learned_round_modules = learned_round._find_learned_round_modules(quant_model) assert len(learned_round_modules) == num_round_modules - @pytest.mark.parametrize("learned_round_class, learned_round_loss_class", [(AutoRound, AutoRoundLoss)]) - @pytest.mark.parametrize("block_strs, num_round_modules", [([], 0), (["hidden_mlp"], 2), (["in_proj_mlp", "out_proj_mlp"], 4)]) - def test_learned_round_iter_blockwise(self, quant_model, learned_round_class, learned_round_loss_class, block_strs, num_round_modules): + @pytest.mark.parametrize( + "learned_round_class, learned_round_loss_class", [(AutoRound, AutoRoundLoss)]) + @pytest.mark.parametrize( + "block_strs, num_round_modules", [([], 0), (["hidden_mlp"], 2), + (["in_proj_mlp", "out_proj_mlp"], 4)]) + def test_learned_round_iter_blockwise( + self, + quant_model, + learned_round_class, + learned_round_loss_class, + block_strs, + num_round_modules): # Retrieve blocks from quant model blocks = [getattr(quant_model, block_str) for block_str in block_strs] learned_round = learned_round_class(iters=100) @@ -248,7 +282,8 @@ def test_learned_round_iter_blockwise(self, quant_model, learned_round_class, le blocks_count = 0 learned_round_modules_count = 0 - for (block, block_loss, block_learned_round_modules) in learned_round.learned_round_iterator(blocks): + for (block, block_loss, + block_learned_round_modules) in learned_round.learned_round_iterator(blocks): assert isinstance(block_loss, learned_round_loss_class) for learned_round_module in block_learned_round_modules: @@ -261,8 +296,11 @@ def test_learned_round_iter_blockwise(self, quant_model, learned_round_class, le assert blocks_count == len(blocks) assert learned_round_modules_count == num_round_modules - @pytest.mark.parametrize("learned_round_class, learned_round_loss_class", [(AutoRound, AutoRoundLoss), (AdaRound, AdaRoundLoss)]) - def test_learned_round_iter_layerwise(self, quant_model, learned_round_class, learned_round_loss_class): + @pytest.mark.parametrize( + "learned_round_class, learned_round_loss_class", [(AutoRound, AutoRoundLoss), + (AdaRound, AdaRoundLoss)]) + def test_learned_round_iter_layerwise( + self, quant_model, learned_round_class, learned_round_loss_class): # Retrieve blocks from quant model blocks = [module for module in quant_model.modules() if isinstance(module, QuantWBIOL)] learned_round = learned_round_class(iters=100) @@ -271,7 +309,8 @@ def test_learned_round_iter_layerwise(self, quant_model, learned_round_class, le blocks_count = 0 learned_round_modules_count = 0 - for (block, block_loss, block_learned_round_modules) in learned_round.learned_round_iterator(blocks): + for (block, block_loss, + block_learned_round_modules) in learned_round.learned_round_iterator(blocks): assert isinstance(block_loss, learned_round_loss_class) for learned_round_module in block_learned_round_modules: From 49c724b16011899f0f7558052aaaf0b10bb6d9d1 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 4 Nov 2024 11:07:22 +0000 Subject: [PATCH 06/48] Refactoring round methos --- .../learned_round/learned_round_method.py | 216 +++++++++ .../learned_round/learned_round_optimizer.py | 275 +++++++++++ .../ptq/learned_round_utils.py | 433 ++++++------------ .../imagenet_classification/ptq/ptq_common.py | 50 -- .../ptq/ptq_evaluate.py | 56 ++- .../llm/benchmark/llm_benchmark.py | 13 +- .../llm/benchmark/parallel.sh | 1 - .../llm/llm_quant/learned_round_utils.py | 217 +++++++++ src/brevitas_examples/llm/main.py | 13 +- .../test_learned_round_utils.py | 18 +- 10 files changed, 906 insertions(+), 386 deletions(-) create mode 100644 src/brevitas_examples/common/learned_round/learned_round_method.py create mode 100644 src/brevitas_examples/common/learned_round/learned_round_optimizer.py delete mode 100644 src/brevitas_examples/llm/benchmark/parallel.sh create mode 100644 src/brevitas_examples/llm/llm_quant/learned_round_utils.py diff --git a/src/brevitas_examples/common/learned_round/learned_round_method.py b/src/brevitas_examples/common/learned_round/learned_round_method.py new file mode 100644 index 000000000..07316ef72 --- /dev/null +++ b/src/brevitas_examples/common/learned_round/learned_round_method.py @@ -0,0 +1,216 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from abc import ABC +from abc import abstractmethod +from typing import Generator, List, Tuple + +import torch +from torch import nn +import torch.nn.functional as F + +from brevitas.core.function_wrapper.auto_round import AutoRoundSte +from brevitas.core.function_wrapper.learned_round import LearnedRoundSte +from brevitas.inject.enum import FloatToIntImplType +from brevitas.inject.enum import LearnedRoundImplType +from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL + + +class LearnedRoundLoss(ABC): + + @abstractmethod + def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: + pass + + @abstractmethod + def format_loss_components(self, *args) -> str: + pass + +class LearnedRound(ABC): + + def __init__(self, iters: int = 200, **kwargs) -> None: + self.iters = iters + + def _insert_learned_round_quantizer(self, block: nn.Module) -> None: + for module in block.modules(): + if isinstance(module, QuantWBIOL) and len( + self._find_learned_round_modules(module)) == 0: + self._insert_learned_round_quantizer_to_layer(module) + module.weight_quant.init_tensor_quant(preserve_state_dict=True) + + @abstractmethod + def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: + pass + + @abstractmethod + def _is_learned_round_module(self, module: nn.Module) -> bool: + pass + + @abstractmethod + def _instantiate_loss( + self, block: nn.Module, learned_round_modules: List[nn.Module]) -> LearnedRoundLoss: + pass + + def _find_learned_round_modules(self, block: nn.Module) -> List[nn.Module]: + round_modules = [] + for module in block.modules(): + if self._is_learned_round_module(module): + round_modules.append(module) + return round_modules + + def learned_round_iterator( + self, + blocks: List[nn.Module]) -> Generator[nn.Module, LearnedRoundLoss, List[nn.Module]]: + for block in blocks: + # Insert learned round quantizers into the appropiate submodules + self._insert_learned_round_quantizer(block) + # Freeze block parameters + for params in block.parameters(): + params.requires_grad = False + # Retrieve learned round modules + learned_round_modules = self._find_learned_round_modules(block) + # Enable gradient tracking in learned round modules + for round_module in learned_round_modules: + for params in round_module.parameters(): + params.requires_grad = True + block_loss = self._instantiate_loss(block, learned_round_modules) + yield block, block_loss, learned_round_modules + +class LinearTempDecay: + + def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2): + self.t_max = t_max + self.start_decay = rel_start_decay * t_max + self.start_b = start_b + self.end_b = end_b + + def __call__(self, t): + if t < self.start_decay: + return self.start_b + else: + rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) + return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t)) + +class AdaRoundLoss(LearnedRoundLoss): + + def __init__( + self, + module: nn.Module, + learned_round_modules: List[nn.Module], + weight: float = 0.01, + max_count: int = 1000, + b_range: Tuple = (20, 2), + warmup: float = 0.2, + decay_start: float = 0.0 + ) -> None: + super().__init__() + # AdaRound operates in a layer-wise manner, so integrity needs to be checked + assert isinstance(module, QuantWBIOL), "AdaRound can only accept a single QuantWBIOL layer." + assert len(learned_round_modules) == 1, "AdaRound can only accept a single learned round module." + + self.weight = weight + self.module = module + self.loss_start = max_count * warmup + self.temp_decay = LinearTempDecay( + max_count, + start_b=b_range[0], + end_b=b_range[1], + rel_start_decay=warmup + (1.0 - warmup) * decay_start) + self.iter = 0 + self.learned_round_module = learned_round_modules[0] + + def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: + self.iter += 1 + + rec_loss = F.mse_loss(pred, tgt, reduction='none').sum(1).mean() + + if self.iter < self.loss_start: + b = self.temp_decay(self.iter) + round_loss = 0 + else: # 1 - |(h-0.5)*2|**b + b = self.temp_decay(self.iter) + round_vals = self.learned_round_module.p_forward() + round_loss = self.weight * (1 - ((round_vals - 0.5).abs() * 2).pow(b)).sum() + + total_loss = rec_loss + round_loss + return total_loss, (total_loss, rec_loss, round_loss, b) + + def format_loss_components(self, loss: float, rec_loss: float, round_loss: float, b) -> str: + return "loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format( + loss, rec_loss, round_loss, b) + +class AdaRound(LearnedRound): + + def __init__( + self, + iters: int = 200, + weight: float = 0.01, + b_range: Tuple = (20, 2), + warmup: float = 0.2, + decay_start: float = 0.0, + **kwargs, + ) -> None: + super().__init__(iters, **kwargs) + # Loss-related configuration + self.weight = weight + self.b_range = b_range + self.warmup = warmup + self.decay_start = decay_start + + def _is_learned_round_module(self, module: nn.Module) -> bool: + return isinstance(module, LearnedRoundSte) + + def _insert_learned_round_quantizer_to_layer( + self, + layer: nn.Module, + learned_round_zeta: float = 1.1, + learned_round_gamma: float = -0.1) -> None: + floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale) + delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight + value = -torch.log((learned_round_zeta - learned_round_gamma) / + (delta - learned_round_gamma) - 1) + layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( + float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND, + learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID, + learned_round_gamma=learned_round_gamma, + learned_round_zeta=learned_round_zeta, + learned_round_init=value) + + def _instantiate_loss( + self, block: nn.Module, learned_round_modules: List[nn.Module]) -> AdaRoundLoss: + return AdaRoundLoss( + block, + learned_round_modules, + max_count=self.iters, + weight=self.weight, + warmup=self.warmup, + decay_start=self.decay_start, + ) + +class AutoRoundLoss(LearnedRoundLoss): + + def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: + loss = F.mse_loss(pred, tgt) + return loss, (loss,) + + def format_loss_components(self, loss: float) -> str: + return "loss = {:.4f}".format(loss) + +class AutoRound(LearnedRound): + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def _is_learned_round_module(self, module: nn.Module) -> bool: + return isinstance(module, AutoRoundSte) + + def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: + value = torch.zeros_like(layer.weight.data) + layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( + float_to_int_impl_type=FloatToIntImplType.AUTO_ROUND, + learned_round_init=value, + ) + + def _instantiate_loss( + self, block: nn.Module, learned_round_modules: List[nn.Module]) -> AutoRoundLoss: + return AutoRoundLoss() diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py new file mode 100644 index 000000000..f80cedb67 --- /dev/null +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -0,0 +1,275 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from abc import ABC +from abc import abstractmethod +import copy +import itertools +from typing import Any, Callable, Dict, List, Tuple +import warnings + +import torch +from torch import autocast +from torch import nn +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import LRScheduler +from torch.optim.optimizer import Optimizer +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from brevitas import config +from brevitas.optim.sign_sgd import SignSGD +from brevitas_examples.common.learned_round.learned_round_method import LearnedRound + +config.IGNORE_MISSING_KEYS = True + +def get_blocks(model: nn.Module, block_check_fn: Callable[[nn.Module, str], + bool]) -> List[nn.Module]: + blocks = [] + + # Iterating over .modules() might have been more readable but + # with this recursive implementation, once a block is reached, + # its subtree of modules is not expanded. + def _get_blocks(module: nn.Module): + for module_name, module_child in module.named_children(): + if block_check_fn(module_child, module_name): + blocks.append(module_child) + else: + _get_blocks(module_child) + + # Run recursive function that updates the list blocks + _get_blocks(model) + return blocks + +class LearnedRoundModelUtils(ABC): + + def __init__(self) -> None: + pass + + @abstractmethod + def default_block_check_fn(self, module: nn.Module, module_name: str) -> bool: + pass + + @abstractmethod + def init_model_learned_round(self, model: nn.Module) -> None: + pass + + @abstractmethod + def finish_model_learned_round(self, model: nn.Module) -> None: + pass + + @abstractmethod + def init_cache(self) -> Any: + pass + + @abstractmethod + def populate_cache( + self, + cache: Any, + model: nn.Module, + block: nn.Module, + data_loader: DataLoader, + keep_gpu: bool = True, + **kwargs, + ) -> int: + pass + + @abstractmethod + def sample_cache( + self, + block: nn.Module, + cache: Any, + indices: torch.Tensor, + **kwargs, + ) -> Tuple[Any, torch.Tensor]: + pass + + @abstractmethod + def run_forward( + self, + block: nn.Module, + inputs: Any, + ) -> torch.Tensor: + pass + + @abstractmethod + def loss_scaler( + self, + loss: torch.Tensor, + ) -> torch.Tensor: + pass + +class LearnedRoundOptimizer: + + def __init__( + self, + learned_round: LearnedRound, + learned_round_utils: LearnedRoundModelUtils, + optimizer_class: Optimizer = SignSGD, + lr_scheduler_class: LRScheduler = LinearLR, + optimizer_lr: float = 5e-3, + batch_size: float = 8, + iters: int = 200, + use_best_model: bool = True, + use_amp: bool = True, + amp_dtype: torch.dtype = torch.float16, + optimizer_kwargs: Dict = {}, + lr_scheduler_kwargs : Dict = { + "start_factor": 1.0, + "end_factor": 0.0, + "verbose": False, + } + ) -> None: + if learned_round.iters != iters: + warnings.warn( + "The number of iterations passed to the learned round optimiser is different " + "to that of the learned round method, which might lead to unexpected behaviour." + ) + self.learned_round = learned_round + self.learned_round_utils = learned_round_utils + self.optimizer_class = optimizer_class + self.lr_scheduler_class = lr_scheduler_class + self.optimizer_lr = optimizer_lr + self.batch_size = batch_size + self.iters = iters + self.use_best_model = use_best_model + self.use_amp = use_amp + self.amp_dtype = amp_dtype + self.optimizer_kwargs = optimizer_kwargs + + self.lr_scheduler_kwargs = lr_scheduler_kwargs + self.lr_scheduler_kwargs["total_iters"] = self.iters + + @torch.no_grad() + def _load_round_params(self, block: nn.Module, round_params: Dict) -> None: + for n, m in block.named_modules(): + if n in round_params: + m.load_state_dict(round_params[n]) + + @torch.no_grad() + def _collect_round_params(self, block: nn.Module) -> Dict: + params = {} + for n, m in block.named_modules(): + if self.learned_round._is_learned_round_module(m): + params[n] = copy.deepcopy(m.state_dict()) + return params + + def _scale_loss_and_backward(self, loss: torch.Tensor) -> torch.Tensor: + scaled_loss = self.learned_round_utils.loss_scaler(loss) + scaled_loss.backward() + return scaled_loss + + def _step(self, optimizer: Optimizer, lr_scheduler: LRScheduler) -> None: + optimizer.step() + optimizer.zero_grad() + if lr_scheduler: + lr_scheduler.step() + + def apply_learned_round( + self, + model: nn.Module, + data_loader: DataLoader, + block_check_fn: Callable = None, + keep_gpu: bool = True + ) -> None: + # Prepare model for optimization + self.learned_round_utils.init_model_learned_round(model) + + block_check_fn = block_check_fn if block_check_fn else self.learned_round_utils.default_block_check_fn + # Retrieve blocks using the appropiate function to check blocks + blocks = get_blocks(model, block_check_fn) + + print(f"Total Iterations per block {self.iters}") + print(f"Number of blocks {len(blocks)}") + + # Initialise cache to store partial inputs and outputs for each block + cache = self.learned_round_utils.init_cache() + + # Loop across blocks to optimise rounding within each + for block_idx, (block, block_loss, block_learned_round_modules) in enumerate( + self.learned_round.learned_round_iterator(blocks)): + # Block needs to be in eval mode while the rounding is optimised + block.eval() + + # Initialise optimiser and LR scheduler + optimizer = self.optimizer_class( + itertools.chain( + *[ + learned_round_module.parameters() + for learned_round_module in block_learned_round_modules + ] + ), + lr=self.optimizer_lr, + **self.optimizer_kwargs, + ) + lr_scheduler = ( + self.lr_scheduler_class(optimizer, **self.lr_scheduler_kwargs) + if self.lr_scheduler_class + else None + ) + + # Variables needed for printing + best_loss = torch.finfo(torch.float).max + init_loss = -1.0 + last_best_iter = self.iters + + optimal_rounding_params = {} + + torch.cuda.empty_cache() + # Populate cache for the given block + n_samples = self.learned_round_utils.populate_cache( + cache, + model, + block, + data_loader, + keep_gpu=keep_gpu, + ) + + pbar = tqdm(range(self.iters), desc='') + for i in pbar: + # Sample mini-batch from cache + idxs = torch.randperm(n_samples)[:self.batch_size] + inputs, fp_outs = self.learned_round_utils.sample_cache(block, cache, idxs) + + # Run block forward to obtain quant outputs + quant_outs = self.learned_round_utils.run_forward(block, inputs) + + if self.use_amp: + with autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=self.amp_dtype): + loss, loss_components = block_loss(quant_outs, fp_outs) + else: + loss, loss_components = block_loss(quant_outs.to(torch.float32), fp_outs.to(torch.float32)) + + init_loss = loss.item() if i == 0 else init_loss + + if loss < best_loss: + best_loss = loss.item() + last_best_iter = i + 1 + if self.use_best_model: + optimal_rounding_params = self._collect_round_params(block) + + # Scale loss and perform gradient step + self._scale_loss_and_backward(loss) + self._step(optimizer, lr_scheduler) + + # Update progress bar + pbar.set_description( + "Block = {:d}/{:d}, {}".format( + block_idx + 1, len(blocks), + block_loss.format_loss_components(*loss_components))) + pbar.update(1) + + if self.use_best_model: + self._load_round_params(block, optimal_rounding_params) + else: + # Override if the model with the lowest training error is not used + best_loss = loss.item() + last_best_iter = self.iters + + print( + f"Quantized block {block_idx+1}/{len(blocks)}, " + f"loss iter 0: {init_loss:.6f} -> iter {last_best_iter}: {best_loss:.6f}" + ) + + # Finish optimisation + self.learned_round_utils.finish_model_learned_round(model) diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index b20e68397..9fb51a17a 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -26,26 +26,21 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from abc import ABC -from abc import abstractmethod -from typing import Callable, Generator, List, Optional, Tuple +import re +from typing import Any, Tuple from accelerate.utils.operations import send_to_device -import numpy as np import torch from torch import nn -import torch.nn.functional as F +from torch.utils.data.dataloader import DataLoader from brevitas import config -from brevitas.core.function_wrapper.auto_round import AutoRoundSte -from brevitas.core.function_wrapper.learned_round import LearnedRoundSte from brevitas.graph.calibrate import disable_return_quant_tensor from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.calibrate import restore_return_quant_tensor -from brevitas.inject.enum import FloatToIntImplType -from brevitas.inject.enum import LearnedRoundImplType from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.quant_tensor import QuantTensor +from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundModelUtils config.IGNORE_MISSING_KEYS = True @@ -55,64 +50,49 @@ class StopFwdException(Exception): pass -class DataSaverHook: - - def __init__(self, store_output: False): - self.store_output = store_output - self.input_store = None - self.output_store = None - - def __call__(self, module, input_batch, output_batch): - input_batch = input_batch[0] - if isinstance(input_batch, QuantTensor): - input_batch = input_batch.value +class LearnedRoundVisionUtils(LearnedRoundModelUtils): - if hasattr(input_batch, 'names') and 'N' in input_batch.names: - batch_dim = input_batch.names.index('N') + def __init__(self) -> None: + pass - input_batch.rename_(None) - input_batch = input_batch.transpose(0, batch_dim) - if self.store_output: - output_batch.rename_(None) - output_batch = output_batch.transpose(0, batch_dim) + def init_model_learned_round(self, model: nn.Module) -> None: + pass - if self.store_output: - self.output_store = output_batch - self.input_store = input_batch - raise StopFwdException + def finish_model_learned_round(self, model: nn.Module) -> None: + pass + def default_block_check_fn(self, module: nn.Module, module_name: str) -> bool: + return (re.search(r"layer\d+", module_name) is not None) -class LinearTempDecay: + class _DataSaverHook: - def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2): - self.t_max = t_max - self.start_decay = rel_start_decay * t_max - self.start_b = start_b - self.end_b = end_b + def __init__(self, store_output: False): + self.store_output = store_output + self.input_store = None + self.output_store = None - def __call__(self, t): - if t < self.start_decay: - return self.start_b - else: - rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) - return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t)) + def __call__(self, module, input_batch, output_batch): + input_batch = input_batch[0] + if isinstance(input_batch, QuantTensor): + input_batch = input_batch.value + if hasattr(input_batch, 'names') and 'N' in input_batch.names: + batch_dim = input_batch.names.index('N') -class LearnedRoundLoss(ABC): + input_batch.rename_(None) + input_batch = input_batch.transpose(0, batch_dim) + if self.store_output: + output_batch.rename_(None) + output_batch = output_batch.transpose(0, batch_dim) - @abstractmethod - def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: - pass - - @abstractmethod - def format_loss_components(self, *args) -> str: - pass - - -class AdaRoundLoss(LearnedRoundLoss): + if self.store_output: + self.output_store = output_batch + self.input_store = input_batch + raise StopFwdException - def __init__( + def _save_inp_out_data( self, + model: nn.Module, module: nn.Module, learned_round_modules: List[nn.Module], weight: float = 0.01, @@ -268,249 +248,104 @@ def _instantiate_loss( return AutoRoundLoss() -def save_inp_out_data( - model, - module, - dataloader: torch.utils.data.DataLoader, - store_inp=False, - store_out=False, - keep_gpu: bool = True, - disable_quant=False): - if disable_quant: - disable_quant_class = DisableEnableQuantization() - disable_quant_class.disable_act_quantization(model, False) - disable_quant_class.disable_param_quantization(model, False) - return_quant_tensor_state = disable_return_quant_tensor(model) - device = next(model.parameters()).device - data_saver = DataSaverHook(store_output=store_out) - handle = module.register_forward_hook(data_saver) - cached = [[], []] - with torch.no_grad(): - for img, t in dataloader: - try: - _ = model(img.to(device)) - except StopFwdException: - pass - if store_inp: - if keep_gpu: - cached[0].append(data_saver.input_store.detach()) - else: - cached[0].append(data_saver.input_store.detach().cpu()) - if store_out: - if keep_gpu: - cached[1].append(data_saver.output_store.detach()) - else: - cached[1].append(data_saver.output_store.detach().cpu()) - if store_inp: - cached[0] = torch.cat([x for x in cached[0]], dim=0) - if store_out: - cached[1] = torch.cat([x for x in cached[1]], dim=0) - handle.remove() - if disable_quant: - disable_quant_class.enable_act_quantization(model, False) - disable_quant_class.enable_param_quantization(model, False) - restore_return_quant_tensor(model, return_quant_tensor_state) - return cached - - -class DataSaverHookLLM: - - def __init__( + def _save_inp_out_data( self, - cache_args: List, - cache_kwargs: List, - cache_outs: List, - store_args: bool = True, - store_kwargs: bool = True, - store_outs: bool = True, - keep_gpu: bool = True): - self.cache_args = cache_args - self.cache_kwargs = cache_kwargs - self.cache_outs = cache_outs - - self.store_args = store_args - self.store_kwargs = store_kwargs - self.store_outs = store_outs - - self.keep_gpu = keep_gpu - - def __call__(self, module, args, kwargs, output): - # NOTE: If args/kwargs are QuantTensors, should include logic to unpack their values - if isinstance(output, (tuple, list)): - output = output[0] - - # Store each element in the appropiate cache - for element_to_cache, should_cache, cache in zip( - [args, kwargs, output], - [self.store_args, self.store_kwargs, self.store_outs], - [self.cache_args, self.cache_kwargs, self.cache_outs] - ): - if should_cache: - if not self.keep_gpu: - element_to_cache = send_to_device(element_to_cache, 'cpu') - cache.append(element_to_cache) - - raise StopFwdException - - -def save_inp_out_data_llm( - model, - module, - dataloader: torch.utils.data.DataLoader, - cache_args: List, - cache_kwargs: List, - cache_outs: List, - store_args: bool = True, - store_kwargs: bool = False, - store_outs: bool = True, + model: nn.Module, + module: nn.Module, + dataloader: DataLoader, + store_inp: bool = False, + store_out: bool = False, + keep_gpu: bool = True, + disable_quant: bool = False): + if disable_quant: + disable_quant_class = DisableEnableQuantization() + disable_quant_class.disable_act_quantization(model, False) + disable_quant_class.disable_param_quantization(model, False) + return_quant_tensor_state = disable_return_quant_tensor(model) + + device = next(model.parameters()).device + data_saver = LearnedRoundVisionUtils._DataSaverHook(store_output=store_out) + handle = module.register_forward_hook(data_saver) + cached = [[], []] + with torch.no_grad(): + for img, t in dataloader: + try: + _ = model(img.to(device)) + except StopFwdException: + pass + if store_inp: + if keep_gpu: + cached[0].append(data_saver.input_store.detach()) + else: + cached[0].append(data_saver.input_store.detach().cpu()) + if store_out: + if keep_gpu: + cached[1].append(data_saver.output_store.detach()) + else: + cached[1].append(data_saver.output_store.detach().cpu()) + if store_inp: + cached[0] = torch.cat([x for x in cached[0]], dim=0) + if store_out: + cached[1] = torch.cat([x for x in cached[1]], dim=0) + handle.remove() + if disable_quant: + disable_quant_class.enable_act_quantization(model, False) + disable_quant_class.enable_param_quantization(model, False) + restore_return_quant_tensor(model, return_quant_tensor_state) + return cached + + def init_cache(self) -> Any: + return [], [] + + def populate_cache( + self, + cache: Any, + model: nn.Module, + block: nn.Module, + data_loader: DataLoader, keep_gpu: bool = True, - disable_quant=False) -> None: - if disable_quant: - disable_quant_class = DisableEnableQuantization() - disable_quant_class.disable_act_quantization(model, False) - disable_quant_class.disable_param_quantization(model, False) - return_quant_tensor_state = disable_return_quant_tensor(model) - - device = next(model.parameters()).device - data_saver = DataSaverHookLLM( - cache_args, cache_kwargs, cache_outs, store_args, store_kwargs, store_outs, keep_gpu) - handle = module.register_forward_hook(data_saver, with_kwargs=True) - with torch.no_grad(): - for inps in dataloader: - try: - inps = send_to_device(inps, device) - model(**inps) - except StopFwdException: - pass - handle.remove() - if disable_quant: - disable_quant_class.enable_act_quantization(model, False) - disable_quant_class.enable_param_quantization(model, False) - restore_return_quant_tensor(model, return_quant_tensor_state) - - -# TODO: Move imports to their appropiate place -import itertools - -from torch.optim.optimizer import Optimizer -from torch.utils.data.dataloader import DataLoader -from tqdm import tqdm -from transformers.models.llama.modeling_llama import LlamaDecoderLayer -from transformers.models.opt.modeling_opt import OPTDecoderLayer - -from brevitas.optim.sign_sgd import SignSGD - - -def get_blocks(model: nn.Module, block_check_fn: Callable[[nn.Module, str], - bool]) -> List[nn.Module]: - blocks = [] - - # Iterating over .modules() might have been more readable but - # with this recursive implementation, once a block is reached, - # its subtree of modules is not expanded. - def _get_blocks(module: nn.Module): - for module_name, module_child in module.named_children(): - if block_check_fn(module_child, module_name): - blocks.append(module_child) - else: - _get_blocks(module_child) - - # Run recursive function that updates the list blocks - _get_blocks(model) - return blocks - - -def _is_block_llm(module: nn.Module, module_name: str) -> bool: - return isinstance(module, LlamaDecoderLayer) or isinstance(module, OPTDecoderLayer) - - -def apply_learned_round_learning_llm( - model: nn.Module, - dataloader: DataLoader, - learned_round: LearnedRound = AutoRound(iters=100), - optimizer_class: Optimizer = SignSGD, - iters: int = 100, - optimizer_lr: float = 5e-3, - block_check_fn: Callable = _is_block_llm, -): - # Disable the cache to prevent memory buildup - cache_state = model.config.use_cache - model.config.use_cache = False - # NOTE: Can be problematic is more than one GPU is used. - device = next(model.parameters()).device - # Retrieve blocks using the appropiate function to check blocks - blocks = get_blocks(model, block_check_fn) - - print(f"Total Iterations per block {iters}") - print(f"Number of blocks {len(blocks)}") - - cache_args, cache_kwargs, cache_outs = [], [], [] - - for block_idx, (block, block_loss, block_learned_round_modules) in enumerate( - learned_round.learned_round_iterator(blocks)): - optimizer = optimizer = optimizer_class( - itertools.chain( - *[ - learned_round_module.parameters() - for learned_round_module in block_learned_round_modules]), - lr=optimizer_lr) - # Cache needs to be cleaned between blocks. No need to clear the - # kwargs cache, as this is only updates for the first block. - cache_args = [] - cache_outs = [] - # Save FP output - save_inp_out_data_llm( - model, - block, - dataloader, - cache_args, - cache_kwargs, - cache_outs, - store_args=False, - store_kwargs=False, - store_outs=True, - keep_gpu=True, - disable_quant=True) - # Save Quant input - save_inp_out_data_llm( - model, - block, - dataloader, - cache_args, - cache_kwargs, - cache_outs, - store_args=True, - store_kwargs=len(cache_kwargs) == 0, - store_outs=False, - keep_gpu=True, - disable_quant=False) - - pbar = tqdm(range(iters), desc='') - for _ in pbar: - idx = torch.randint(0, len(cache_args), (1,)) - args, kwargs, fp_out = cache_args[idx], cache_kwargs[idx], cache_outs[idx] - block.train() - - optimizer.zero_grad() - - args = send_to_device(args, device) - kwargs = send_to_device(kwargs, device) - fp_out = send_to_device(fp_out, device) - - quant_out = block(*args, **kwargs) - if isinstance(quant_out, tuple): - quant_out = quant_out[0] - - loss, loss_components = block_loss(quant_out, fp_out) - - loss.backward() - optimizer.step() - # Update progress bar - pbar.set_description( - "block = {:d}/{:d}, {}".format( - block_idx + 1, len(blocks), - block_loss.format_loss_components(*loss_components))) - pbar.update(1) - - # Restore cache state - model.config.use_cache = cache_state + **kwargs, + ) -> int: + cache_input, cache_output = cache + # Clear caches + cache_input.clear() + cache_output.clear() + + _, all_fp_out = self._save_inp_out_data(model, block, data_loader, store_inp=False, store_out=True, keep_gpu=keep_gpu, disable_quant=True) + all_quant_inp, _ = self._save_inp_out_data(model, block, data_loader, store_inp=True, store_out=True, keep_gpu=keep_gpu, disable_quant=False) + + # Add elements to the caches + cache_input.append(all_quant_inp) + cache_output.append(all_fp_out) + + # Number of samples + return all_fp_out.shape[0] + + def sample_cache( + self, + block: nn.Module, + cache: Any, + indices: torch.Tensor, + **kwargs, + ) -> Tuple[Any, torch.Tensor]: + cache_input, cache_output = cache + device = next(block.parameters()).device + + input, output = cache_input[0][indices], cache_output[0][indices] + input = send_to_device(input, device) + output = send_to_device(output, device) + + return input, output + + def run_forward( + self, + block: nn.Module, + inputs: Any, + ) -> torch.Tensor: + return block(inputs) + + def loss_scaler( + self, + loss: torch.Tensor, + ) -> torch.Tensor: + return loss diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 4a6161279..2cd44443b 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -86,10 +86,6 @@ from brevitas_examples.common.axe import A2GPTQ from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AdaRound -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import get_blocks -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import LearnedRound -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data # Every element of the Batch will have its own scale factor and zero point @@ -668,52 +664,6 @@ def _is_layer(module: nn.Module, module_name: str) -> bool: return isinstance(module, QuantWBIOL) -def apply_learned_round_learning( - model: nn.Module, - dataloader: DataLoader, - learned_round: LearnedRound = AdaRound(iters=1000), - optimizer_class: Optimizer = torch.optim.Adam, - iters: int = 1000, - optimizer_lr: float = 1e-1, - block_check_fn: Callable = _is_layer, -): - # Retrieve blocks using the appropiate function to check blocks - blocks = get_blocks(model, block_check_fn) - - print(f"Total Iterations per block {iters}") - print(f"Number of blocks {len(blocks)}") - - for block_idx, (block, block_loss, block_learned_round_modules) in enumerate( - learned_round.learned_round_iterator(blocks)): - optimizer = optimizer = optimizer_class( - itertools.chain( - *[ - learned_round_module.parameters() - for learned_round_module in block_learned_round_modules]), - lr=optimizer_lr) - _, all_fp_out = save_inp_out_data(model, block, dataloader, store_inp=False, store_out=True, keep_gpu=True, disable_quant=True) - all_quant_inp, _ = save_inp_out_data(model, block, dataloader, store_inp=True, store_out=True, keep_gpu=True, disable_quant=False) - max_size = len(all_fp_out) - pbar = tqdm(range(iters), desc='') - for _ in pbar: - idx = torch.randint(0, max_size, (dataloader.batch_size,)) - quant_inp, fp_out = all_quant_inp[idx], all_fp_out[idx] - block.train() - - optimizer.zero_grad() - quant_out = block(quant_inp) - loss, loss_components = block_loss(quant_out, fp_out) - - loss.backward() - optimizer.step() - # Update progress bar - pbar.set_description( - "block = {:d}/{:d}, {}".format( - block_idx + 1, len(blocks), - block_loss.format_loss_components(*loss_components))) - pbar.update(1) - - def check_positive_int(*args): """ We check that every inputted value is positive, and an integer. diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index afd37cee6..240d08756 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -22,15 +22,17 @@ from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas.optim.sign_sgd import SignSGD -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AdaRound -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AutoRound +from brevitas_examples.common.learned_round.learned_round_method import AdaRound +from brevitas_examples.common.learned_round.learned_round_method import AutoRound +from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import \ + LearnedRoundVisionUtils from brevitas_examples.imagenet_classification.ptq.ptq_common import _is_layer from brevitas_examples.imagenet_classification.ptq.ptq_common import _is_resnet_block from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq -from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate_bn from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model @@ -183,6 +185,11 @@ def validate_args(args): default=1e-3, type=float, help='Learning rate for learned round (default: 1e-3)') +parser.add_argument( + '--learned-round-batch-size', + default=1, + type=int, + help='Learning rate for learned round (default: %(default)d)') parser.add_argument( '--act-quant-percentile', default=99.999, @@ -498,33 +505,38 @@ def main(): max_accumulator_bit_width=args.gpxq_accumulator_bit_width, max_accumulator_tile_size=args.gpxq_accumulator_tile_size) - if args.optimizer == "adam": - optimizer_class = torch.optim.Adam - elif args.optimizer == "sign_sgd": - optimizer_class = SignSGD - else: - raise ValueError(f"{args.optimizer} is not a valid optimizer.") - - if args.learned_round_mode == "layerwise": - block_check_fn = _is_layer - elif args.learned_round_mode == "blockwise": - # if args.learned_round_type == "ada_round": - # raise ValueError(f"Block-wise round is not available with AdaRound.") - block_check_fn = _is_resnet_block - if args.learned_round_type != "none": + # Initialisation of rounding method if args.learned_round_type =="auto_round": learned_round = AutoRound(iters=args.learned_round_iters) elif args.learned_round_type == "ada_round": learned_round = AdaRound(iters=args.learned_round_iters) - - apply_learned_round_learning( - model=quant_model, - dataloader=calib_loader, + # Optimizer to tune the + if args.optimizer == "adam": + optimizer_class = torch.optim.Adam + elif args.optimizer == "sign_sgd": + optimizer_class = SignSGD + else: + raise ValueError(f"{args.optimizer} is not a valid optimizer.") + # Granularity of the rounding blocks + if args.learned_round_mode == "layerwise": + block_check_fn = _is_layer + elif args.learned_round_mode == "blockwise": + block_check_fn = _is_resnet_block + + learned_round_vision_utils = LearnedRoundVisionUtils() + learned_round_optimiser = LearnedRoundOptimizer( learned_round=learned_round, + learned_round_utils=learned_round_vision_utils, optimizer_class=optimizer_class, - iters=args.learned_round_iters, + lr_scheduler_class= None if args.optimizer == "adam" else torch.optim.lr_scheduler.LinearLR, optimizer_lr=args.learned_round_lr, + batch_size=args.learned_round_batch_size, + iters=args.learned_round_iters, + ) + learned_round_optimiser.apply_learned_round( + model, + data_loader=calib_loader, block_check_fn=block_check_fn ) diff --git a/src/brevitas_examples/llm/benchmark/llm_benchmark.py b/src/brevitas_examples/llm/benchmark/llm_benchmark.py index 711aa7dff..e8aba1d73 100644 --- a/src/brevitas_examples/llm/benchmark/llm_benchmark.py +++ b/src/brevitas_examples/llm/benchmark/llm_benchmark.py @@ -31,9 +31,9 @@ from brevitas.graph.quantize import layerwise_quantize from brevitas_examples.common.generative.quantize import generate_quant_maps from brevitas_examples.common.generative.quantize import generate_quantizers +from brevitas_examples.common.learned_round.learned_round_method import AutoRound +from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer from brevitas_examples.common.parse_utils import quant_format_validator -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import \ - apply_learned_round_learning_llm from brevitas_examples.imagenet_classification.ptq.utils import get_gpu_index from brevitas_examples.imagenet_classification.ptq.utils import get_next_available_gpu from brevitas_examples.imagenet_classification.utils import SEED @@ -47,6 +47,7 @@ from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq from brevitas_examples.llm.llm_quant.gpxq import apply_gptq +from brevitas_examples.llm.llm_quant.learned_round_utils import LearnedRoundLLMUtils from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers @@ -350,7 +351,13 @@ def ptq_llm_models(args): if config_namespace.learned_round: print("Applying learned round...") - apply_learned_round_learning_llm(model, calibration_loader) + learned_round_llm_utils = LearnedRoundLLMUtils() + learned_round = AutoRound() + learned_round_optimiser = LearnedRoundOptimizer( + learned_round=learned_round, + learned_round_utils=learned_round_llm_utils + ) + learned_round_optimiser.apply_learned_round(model, calibration_loader) print("Learned round applied.") if config_namespace.act_calibration: diff --git a/src/brevitas_examples/llm/benchmark/parallel.sh b/src/brevitas_examples/llm/benchmark/parallel.sh deleted file mode 100644 index 32071014b..000000000 --- a/src/brevitas_examples/llm/benchmark/parallel.sh +++ /dev/null @@ -1 +0,0 @@ -seq 0 10 | xargs -n1 -P3 -I{} sh -c 'HF_HUB_CACHE=/scratch/hf_models/ python llm_benchmark.py "$@"' _ {} diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py new file mode 100644 index 000000000..05ab8c191 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -0,0 +1,217 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, List, Tuple + +from accelerate.utils.operations import send_to_device +import torch +from torch import nn +from torch.utils.data.dataloader import DataLoader +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from transformers.models.opt.modeling_opt import OPTDecoderLayer + +from brevitas.graph.calibrate import disable_return_quant_tensor +from brevitas.graph.calibrate import DisableEnableQuantization +from brevitas.graph.calibrate import restore_return_quant_tensor +from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundModelUtils + + +class StopFwdException(Exception): + """Used to throw and catch an exception to stop traversing the graph.""" + pass + + +class LearnedRoundLLMUtils(LearnedRoundModelUtils): + + def __init__(self) -> None: + super(LearnedRoundLLMUtils, self).__init__() + self.llm_cache_state = None + + def default_block_check_fn(self, module: nn.Module, module_name: str) -> bool: + return isinstance(module, LlamaDecoderLayer) or isinstance(module, OPTDecoderLayer) + + class _DataSaverHookLLM: + + def __init__( + self, + cache_args: List, + cache_kwargs: List, + cache_outs: List, + store_args: bool = True, + store_kwargs: bool = True, + store_outs: bool = True, + keep_gpu: bool = True): + self.cache_args = cache_args + self.cache_kwargs = cache_kwargs + self.cache_outs = cache_outs + + self.store_args = store_args + self.store_kwargs = store_kwargs + self.store_outs = store_outs + + self.keep_gpu = keep_gpu + + def __call__(self, module, args, kwargs, output): + # NOTE: If args/kwargs are QuantTensors, should include logic to unpack their values + if isinstance(output, (tuple, list)): + output = output[0] + + # Store each element in the appropiate cache + for element_to_cache, should_cache, cache in zip( + [args, kwargs, output], + [self.store_args, self.store_kwargs, self.store_outs], + [self.cache_args, self.cache_kwargs, self.cache_outs] + ): + if should_cache: + if not self.keep_gpu: + element_to_cache = send_to_device(element_to_cache, 'cpu') + cache.append(element_to_cache) + + raise StopFwdException + + def _save_inp_out_data( + self, + model: nn.Module, + module: nn.Module, + dataloader: DataLoader, + cache_args: List, + cache_kwargs: List, + cache_outs: List, + store_args: bool = True, + store_kwargs: bool = False, + store_outs: bool = True, + keep_gpu: bool = True, + disable_quant=False) -> None: + if disable_quant: + disable_quant_class = DisableEnableQuantization() + disable_quant_class.disable_act_quantization(model, False) + disable_quant_class.disable_param_quantization(model, False) + return_quant_tensor_state = disable_return_quant_tensor(model) + + device = next(module.parameters()).device + data_saver = LearnedRoundLLMUtils._DataSaverHookLLM( + cache_args, cache_kwargs, cache_outs, store_args, store_kwargs, store_outs, keep_gpu) + handle = module.register_forward_hook(data_saver, with_kwargs=True) + with torch.no_grad(): + for inps in dataloader: + try: + inps = send_to_device(inps, device) + model(**inps) + except StopFwdException: + pass + handle.remove() + if disable_quant: + disable_quant_class.enable_act_quantization(model, False) + disable_quant_class.enable_param_quantization(model, False) + restore_return_quant_tensor(model, return_quant_tensor_state) + + def init_model_learned_round(self, model: nn.Module) -> None: + self.llm_cache_state = model.config.use_cache + model.config.use_cache = False + + def finish_model_learned_round(self, model: nn.Module) -> None: + model.config.use_cache = self.llm_cache_state + self.llm_cache_state = None + + def init_cache(self) -> Any: + # cache_args, cache_kwargs, cache_outs + return [], [], [] + + def populate_cache( + self, + cache: Any, + model: nn.Module, + block: nn.Module, + data_loader: DataLoader, + keep_gpu: bool = True, + **kwargs, + ) -> int: + # Unpack cache + cache_args, cache_kwargs, cache_outs = cache + # Cache needs to be cleaned between blocks. No need to clear the + # kwargs cache, as this is only updated for the first block. + cache_args.clear() + cache_outs.clear() + # Save FP output + self._save_inp_out_data( + model, + block, + data_loader, + cache_args, + cache_kwargs, + cache_outs, + store_args=False, + store_kwargs=False, + store_outs=True, + keep_gpu=keep_gpu, + disable_quant=True) + # Save Quant input + self._save_inp_out_data( + model, + block, + data_loader, + cache_args, + cache_kwargs, + cache_outs, + store_args=True, + store_kwargs=len(cache_kwargs) == 0, + store_outs=False, + keep_gpu=keep_gpu, + disable_quant=False) + # Return number of samples in calibration set + return len(cache_args) + + def sample_cache( + self, + block: nn.Module, + cache: Any, + indices: torch.Tensor, + input_dim: int = 0, + **kwargs_fn, + ) -> Tuple[Any, torch.Tensor]: + cache_args, cache_kwargs, cache_outs = cache + device = next(block.parameters()).device + # Positional arguments + args = [cache_args[i] for i in indices] + args = tuple(torch.cat(arg_tensor, dim=0) for arg_tensor in zip(*args)) + # Keyword arguments + kwargs_dict = [cache_kwargs[i] for i in indices] + kwargs = {} + for curr_dict in kwargs_dict: + for key, value in curr_dict.items(): + if isinstance(value, torch.Tensor): + if key not in kwargs: + kwargs[key] = [] + kwargs[key].append(value) + else: + if key not in kwargs: + kwargs[key] = value + for key, value in kwargs.items(): + if isinstance(value, list) and len(value) > 0: + kwargs[key] = torch.cat(kwargs[key], dim=input_dim) + # FP outputs + outs = torch.cat([cache_outs[i] for i in indices], dim=input_dim) + # Make sure that the inputs and outputs are in the same device as block, + # before running its forward pass. + args = send_to_device(args, device) + kwargs = send_to_device(kwargs, device) + outs = send_to_device(outs, device) + + return (args, kwargs), outs + + def run_forward( + self, + block: nn.Module, + inputs: Any, + ) -> torch.Tensor: + args, kwargs = inputs + quant_outs = block(*args, **kwargs) + if isinstance(quant_outs, tuple): + quant_outs = quant_outs[0] + return quant_outs + + def loss_scaler( + self, + loss: torch.Tensor, + ) -> torch.Tensor: + return loss * 1000 diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 93544a1e5..d62212e07 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -23,9 +23,9 @@ from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks from brevitas_examples.common.generative.quantize import generate_quant_maps from brevitas_examples.common.generative.quantize import generate_quantizers +from brevitas_examples.common.learned_round.learned_round_method import AutoRound +from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer from brevitas_examples.common.parse_utils import quant_format_validator -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import \ - apply_learned_round_learning_llm from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction from brevitas_examples.llm.llm_quant.calibrate import apply_calibration from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model @@ -36,6 +36,7 @@ from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq from brevitas_examples.llm.llm_quant.gpxq import apply_gptq +from brevitas_examples.llm.llm_quant.learned_round_utils import LearnedRoundLLMUtils from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_to_rmsnorm from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch @@ -371,7 +372,13 @@ def main(args): if args.learned_round: print("Applying learned round...") - apply_learned_round_learning_llm(model, calibration_loader) + + learned_round_llm_utils = LearnedRoundLLMUtils() + learned_round = AutoRound() + learned_round_optimiser = LearnedRoundOptimizer( + learned_round=learned_round, learned_round_utils=learned_round_llm_utils) + learned_round_optimiser.apply_learned_round(model, calibration_loader) + print("Learned round applied.") if args.act_calibration: diff --git a/tests/brevitas_examples/test_learned_round_utils.py b/tests/brevitas_examples/test_learned_round_utils.py index d02585ba2..8999d4125 100644 --- a/tests/brevitas_examples/test_learned_round_utils.py +++ b/tests/brevitas_examples/test_learned_round_utils.py @@ -15,17 +15,17 @@ import brevitas.nn as qnn from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.quant_tensor.base_quant_tensor import QuantTensor -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AdaRound -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AdaRoundLoss -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AutoRound -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import AutoRoundLoss -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import get_blocks -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data +from brevitas_examples.common.learned_round.learned_round_method import AdaRound +from brevitas_examples.common.learned_round.learned_round_method import AdaRoundLoss +from brevitas_examples.common.learned_round.learned_round_method import AutoRound +from brevitas_examples.common.learned_round.learned_round_method import AutoRoundLoss +from brevitas_examples.common.learned_round.learned_round_optimizer import get_blocks +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import \ + LearnedRoundVisionUtils config.IGNORE_MISSING_KEYS = True -# TODO: Include some integration test class TestLearnedRound: @fixture @@ -152,6 +152,8 @@ def _is_layer(module: nn.Module, module_name: str) -> bool: @pytest.mark.parametrize("disable_quant", [True, False]) def test_save_inp_out_data( self, model, quant_model, data_loader, store_input, store_out, keep_gpu, disable_quant): + # Initialise utils to save tensors + learned_round_vision_utils = LearnedRoundVisionUtils() # Make sure that the quant and FP models share the same weights quant_model.load_state_dict(model.state_dict()) @@ -213,7 +215,7 @@ def _aux_check_tensors( cache_fp_partial_output = torch.cat(cache_fp_partial_output, dim=0) # Retrieve input and output data - input_data, out_data = save_inp_out_data(quant_model, module, data_loader, store_input, store_out, keep_gpu, disable_quant) + input_data, out_data = learned_round_vision_utils._save_inp_out_data(quant_model, module, data_loader, store_input, store_out, keep_gpu, disable_quant) # Verify that empty lists are returned if store_input: if disable_quant: From 179357561c9f7bdec4e74315eac97055a0dc379e Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 4 Nov 2024 19:25:05 +0000 Subject: [PATCH 07/48] Remove unused import --- src/brevitas/optim/sign_sgd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/brevitas/optim/sign_sgd.py b/src/brevitas/optim/sign_sgd.py index dd5d62365..c34279b81 100644 --- a/src/brevitas/optim/sign_sgd.py +++ b/src/brevitas/optim/sign_sgd.py @@ -9,7 +9,6 @@ from torch.optim.optimizer import _fused_doc from torch.optim.optimizer import _maximize_doc from torch.optim.optimizer import _use_grad_for_differentiable -from torch.optim.optimizer import DeviceDict from torch.optim.optimizer import Optimizer from torch.utils._foreach_utils import _get_fused_kernels_supported_devices From 1c0720c5b43fe941502c1ae37d442667ada24622 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 4 Nov 2024 20:13:54 +0000 Subject: [PATCH 08/48] Fix license and refactor benchmark --- .../llm/benchmark/llm_benchmark.py | 413 ++++-------------- tests/brevitas/optim/test_sign_sgd.py | 43 +- 2 files changed, 118 insertions(+), 338 deletions(-) diff --git a/src/brevitas_examples/llm/benchmark/llm_benchmark.py b/src/brevitas_examples/llm/benchmark/llm_benchmark.py index e8aba1d73..19efbafc6 100644 --- a/src/brevitas_examples/llm/benchmark/llm_benchmark.py +++ b/src/brevitas_examples/llm/benchmark/llm_benchmark.py @@ -1,58 +1,65 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU, + NEC Laboratories America and IDIAP Research Institute nor the names + of its contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. +""" import argparse from functools import partial from itertools import product import os -import random from types import SimpleNamespace -import numpy as np -from optimum.amd.brevitas.accelerate_utils import offload_model -from optimum.amd.brevitas.accelerate_utils import remove_hooks -from optimum.amd.brevitas.data_utils import compute_perplexity -from optimum.exporters.onnx import onnx_export_from_model import pandas as pd -import torch import torch.backends.cudnn as cudnn import torch.nn.parallel import torch.optim import torch.utils.data import torch.utils.data.distributed -from transformers import AutoModelForCausalLM -from transformers import AutoTokenizer from brevitas import __version__ as brevitas_version from brevitas import config from brevitas import torch_version -from brevitas.export import export_torch_qcdq -from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager -from brevitas.graph.quantize import layerwise_quantize -from brevitas_examples.common.generative.quantize import generate_quant_maps -from brevitas_examples.common.generative.quantize import generate_quantizers -from brevitas_examples.common.learned_round.learned_round_method import AutoRound -from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer -from brevitas_examples.common.parse_utils import quant_format_validator from brevitas_examples.imagenet_classification.ptq.utils import get_gpu_index -from brevitas_examples.imagenet_classification.ptq.utils import get_next_available_gpu -from brevitas_examples.imagenet_classification.utils import SEED -from brevitas_examples.imagenet_classification.utils import validate -from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction -from brevitas_examples.llm.llm_quant.calibrate import apply_calibration -from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model -from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization -from brevitas_examples.llm.llm_quant.equalize import apply_weight_equalization -from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager -from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode -from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq -from brevitas_examples.llm.llm_quant.gpxq import apply_gptq -from brevitas_examples.llm.llm_quant.learned_round_utils import LearnedRoundLLMUtils -from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge -from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear -from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers -from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 -from brevitas_examples.llm.llm_quant.run_utils import get_fx +# LLM example depends on optimum-amd, which requires PyTorch>=2.2 +from brevitas_examples.llm.main import main as main_llm +from brevitas_examples.llm.main import validate config.IGNORE_MISSING_KEYS = True @@ -97,7 +104,8 @@ def unique(sequence): 'nsamples': [128], # Number of calibration data samples. Default: 128. 'seqlen': [2048], # Sequence length. Default: 2048. 'eval': [True], # Eval model PPL on the chosen Dataset. - 'dataset': ['c4'], # Dataset to use for quantization (default: wikitext2) + 'dataset': ['wikitext2'], # Dataset to use for quantization (default: wikitext2) + 'gpxq_block_name': [None], # Block name for faster GPxQ optimization. Default: None 'weight_bit_width': [8], # Weight bit width. Default: 8. 'weight_param_method': ['stats'], # How scales/zero-point are determined. Default: stats. 'weight_scale_precision': ['float_scale' @@ -106,6 +114,7 @@ def unique(sequence): 'weight_quant_format': ['int'], # Weight quantization type. Default: int. 'weight_quant_granularity': [ 'per_group'], # Granularity for scales/zero-point of weights. Default: per_group. + 'scale_rounding_func_type': [None], # Rounding function to use with Po2 scale. Default: None. 'weight_group_dim': [ None], # Override default group_dim for groupsize quantization. Default: layer-dependant 'weight_group_size': [128], # Group size for per_group weight quantization. Default: 128. @@ -124,9 +133,11 @@ def unique(sequence): 'quantize_last_layer': [False], # Quantize last nn.Linear layer. 'gptq': [False], # Apply GPTQ. 'gpfq': [False], # Apply GPFQ. - 'gpxq_act_order': [False], # Apply GPXQ activation ordering. - 'gpxq_use_quant_activations': [False], # Use quantized activations in GPXQ. - 'gpxq_create_weight_orig': [False], # Create weight_orig in GPXQ. + 'gpxq_act_order': [False], # Apply GPxQ activation ordering. + 'gpxq_use_quant_activations': [False], # Use quantized activations in GPxQ. + 'gpxq_create_weight_orig': [False], # Create weight_orig in GPxQ. + 'gpxq_max_accumulator_bit_width': [None], # Maximum accumulator bit width for GPxQ using AXE. + 'gpxq_max_accumulator_tile_size': [None], # Maximum accumulator tile size for GPxQ using AXE. 'act_calibration': [False], # Apply activation calibration. 'bias_corr': [False], # Apply bias correction. 'ln_affine_merge': [False], # Merge LN affine params. @@ -141,7 +152,7 @@ def unique(sequence): 'export_prefix': [None], # Path prefix to use for the various export flows. 'checkpoint_name': [None], # Filename to save checkpoint. 'fuse_sequences': [False], # Whether to merge the dataset sequences. - 'learned_round': [None, "auto_round"] # Whether to use learned round. If `None`, RTN is used. + 'learned_round': [None, "auto_round"], # Whether to use learned round. If `None`, RTN is used. } parser = argparse.ArgumentParser(description='PyTorch LLM PTQ Validation') @@ -156,33 +167,20 @@ def unique(sequence): def main(): args = parser.parse_args() - random.seed(SEED) - np.random.seed(SEED) - torch.manual_seed(SEED) - - args.gpu = get_gpu_index(args.idx) - print("Iter {}, GPU {}".format(args.idx, args.gpu)) - - try: - ptq_llm_models(args) - except Exception as E: - print("Exception at index {}: {}".format(args.idx, E)) - - -def ptq_llm_models(args): - # Generate all possible combinations, including invalid ones + # Generate all possible configurations, including invalid ones options = {k: getattr(args, k) for k, _ in OPTIONS_DEFAULT.items()} - combinations = list(product(*options.values())) - configs = [] for combination in combinations: config_namespace = SimpleNamespace( **{k: v for k, v in zip(OPTIONS_DEFAULT.keys(), combination)}) - config_namespace = validate_config(config_namespace) - if config_namespace.is_valid: + try: + validate(config_namespace) configs.append(hashabledict(**config_namespace.__dict__)) + except AssertionError: + # Invalid configuration + pass configs = unique(configs) @@ -190,283 +188,26 @@ def ptq_llm_models(args): return config_namespace = SimpleNamespace(**configs[args.idx]) - print(config_namespace) + args.gpu = get_gpu_index(args.idx) + print("Iter {}, GPU {}".format(args.idx, args.gpu)) - if config_namespace.export_prefix is None: - config_namespace.export_prefix = f"{config_namespace.model.replace('/', '--')}" + try: + float_ppl, quant_ppl, _ = main_llm(config_namespace) - if config_namespace.no_float16: - dtype = torch.float32 - else: - dtype = torch.float16 - - kwargs = {"torch_dtype": dtype} - - if config_namespace.export_target == 'torch_qcdq': - kwargs['torchscript'] = True - - print("Model loading...") - model = AutoModelForCausalLM.from_pretrained(config_namespace.model, **kwargs) - print("Model loaded.") - model.eval() - tokenizer = AutoTokenizer.from_pretrained(config_namespace.model) - float_ppl = None - quant_ppl = None - - if config_namespace.load_awq: - from brevitas_examples.llm.llm_quant.awq.pre_quant import apply_awq - awq_results = torch.load(config_namespace.load_awq, map_location="cpu") - with CastFloat16ToFloat32(): - apply_awq(model, awq_results) - - require_fx = True if config_namespace.weight_equalization or config_namespace.act_equalization == 'fx' or config_namespace.ln_affine_merge else False - - # Load the data for calibration and evaluation. - calibration_loader = get_dataset_for_model( - config_namespace.model, - dataset_name=config_namespace.dataset, - tokenizer=tokenizer, - nsamples=config_namespace.nsamples, - seqlen=config_namespace.seqlen, - split="train", - seed=config_namespace.seed, - require_fx=require_fx, - device=None, - fuse_sequences=config_namespace.fuse_sequences, - ) - - validation_loader = get_dataset_for_model( - config_namespace.model, - dataset_name=config_namespace.dataset, - tokenizer=tokenizer, - nsamples=config_namespace.nsamples, - seqlen=config_namespace.seqlen, - split="validation", - seed=config_namespace.seed, - require_fx=require_fx, - device=None, - fuse_sequences=config_namespace.fuse_sequences, - ) - - device = next(iter(model.parameters())).device - print("Data loaded.") - - if config_namespace.eval: - assert config_namespace.export_target != 'torch_qcdq', "TorchScript QCDQ export and Evaluation simultaneously" - print("Float model eval...") - model = offload_model(model) - float_ppl = compute_perplexity( - model, - validation_loader, - context_length=config_namespace.seqlen // 2, - tokenizer=tokenizer) - remove_hooks(model) - print(f"Float perplexity ({config_namespace.dataset}): {float_ppl:.3f}") - - if require_fx: - model = get_fx(model) - - # Apply LN affine merging before inserting MHA layers - # since currently there is support only for merging into Linear - if config_namespace.ln_affine_merge: - print("Apply LN affine merge...") - apply_layernorm_affine_merge(model, dtype) - print("LN affine merge applied.") - - # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing - # with all the variability in HF implementations - if config_namespace.replace_mha: - print("Replace HF MHA with quantizable variants...") - model = replace_mha_with_quantizable_layers(model, dtype) - print("Replacing done.") - - if config_namespace.weight_equalization: - print("Apply weight equalization...") - # In case of float16 model, we need to offload to account for missing ops - model = offload_model(model) - apply_weight_equalization(model) - remove_hooks(model) - print("Weight equalization applied.") - - if config_namespace.act_equalization is not None: - offload_model(model) - print("Apply act equalization (SmoothQuant)...") - apply_act_equalization(model, config_namespace.act_equalization, calibration_loader) - print("Act equalization applied.") - remove_hooks(model) - - if not config_namespace.no_quantize: - name_blacklist = [] - print("Applying model quantization...") - linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant = generate_quantizers( - dtype=dtype, - weight_bit_width=config_namespace.weight_bit_width, - weight_param_method=config_namespace.weight_param_method, - weight_scale_precision=config_namespace.weight_scale_precision, - weight_quant_type=config_namespace.weight_quant_type, - weight_quant_granularity=config_namespace.weight_quant_granularity, - weight_group_size=config_namespace.weight_group_size, - weight_group_dim=config_namespace.weight_group_dim, - quantize_weight_zero_point=config_namespace.quantize_weight_zero_point, - weight_quant_format=config_namespace.weight_quant_format, - input_bit_width=config_namespace.input_bit_width, - input_quant_format=config_namespace.input_quant_format, - input_scale_precision=config_namespace.input_scale_precision, - input_scale_type=config_namespace.input_scale_type, - input_param_method=config_namespace.input_param_method, - input_quant_type=config_namespace.input_quant_type, - input_quant_granularity=config_namespace.input_quant_granularity, - input_group_size=config_namespace.input_group_size, - quantize_input_zero_point=config_namespace.quantize_input_zero_point, - device=device) - layer_map = generate_quant_maps( - linear_input_quant=linear_input_quant, - weight_quant=weight_quant, - input_quant=input_quant, - q_scaled_quant=q_scaled_quant, - k_transposed_quant=k_transposed_quant, - v_quant=v_quant, - attn_output_weights_quant=attn_output_weights_quant, - dtype=dtype, - device=device, - input_quant_format=config_namespace.input_quant_format, - quantize_embedding=False) - if not config_namespace.quantize_last_layer: - name_blacklist += ["lm_head", "embed_out"] - model = layerwise_quantize( - model=model, compute_layer_map=layer_map, name_blacklist=name_blacklist) - # Tie back first/last layer weights in case they got untied - print("Model quantization applied.") - - # If any equalization has taken places, the embedding layer and the fully connected one are - # not tied anymore, and they need to be treated as standalone, separate layers. - # In all other cases we can tie them back so to preserve memory. - if config_namespace.act_equalization is None and not require_fx: - model.tie_weights() - - if config_namespace.bias_corr: - model = add_zero_bias_to_linear(model) - - model = offload_model(model) - - if config_namespace.learned_round: - print("Applying learned round...") - learned_round_llm_utils = LearnedRoundLLMUtils() - learned_round = AutoRound() - learned_round_optimiser = LearnedRoundOptimizer( - learned_round=learned_round, - learned_round_utils=learned_round_llm_utils - ) - learned_round_optimiser.apply_learned_round(model, calibration_loader) - print("Learned round applied.") - - if config_namespace.act_calibration: - print("Apply act calibration...") - apply_calibration(model, calibration_loader) - print("Act calibration applied.") - - if config_namespace.gptq: - print("Applying GPTQ...") - apply_gptq( - model, - calibration_loader, - act_order=config_namespace.gpxq_act_order, - use_quant_activations=config_namespace.gpxq_use_quant_activations, - create_weight_orig=config_namespace.gpxq_create_weight_orig) - print("GPTQ applied.") - - if config_namespace.gpfq: - print("Applying GPFQ...") - apply_gpfq(model, calibration_loader, act_order=config_namespace.gpxq_act_order) - print("GPFQ applied.") - - if config_namespace.bias_corr: - print("Applying bias correction...") - apply_bias_correction(model, calibration_loader) - print("Bias correction applied.") - - if config_namespace.eval: - print("Model eval...") - quant_ppl = compute_perplexity( - model, - validation_loader, - context_length=config_namespace.seqlen // 2, - tokenizer=tokenizer) - print(f"Quantized perplexity ({config_namespace.dataset}): {quant_ppl:.3f}") - remove_hooks(model) - - # Validate the quant_model on the validation dataloader - print("Starting validation") - - column_names = [k.replace('_', ' ').capitalize() for k in config_namespace.__dict__.keys()] + [ - 'FP perplexity', 'Quant perplexity', 'Torch version', 'Brevitas version'] - values = [v for _, v in config_namespace.__dict__.items()] + [ - float_ppl, quant_ppl, torch_version, brevitas_version] - torchvision_df = pd.DataFrame([values], columns=column_names) - - folder = './multirun/' + str(args.idx) - os.makedirs(folder, exist_ok=True) - torchvision_df.to_csv(os.path.join(folder, 'RESULTS_LLM.csv'), index=False) - - -def validate_config(config_namespace): - is_valid = True - - if not config_namespace.no_quantize: - if config_namespace.gptq and config_namespace.gpfq: - is_valid = False - if config_namespace.export_target is not None: - if config_namespace.input_quant_format != 'int': - is_valid = False - if config_namespace.export_target is not None and config_namespace.input_bit_width is not None: - if config_namespace.input_scale_type != 'static': - is_valid = False - if config_namespace.export_target == 'sharded_torchmlir_group_weight': - if config_namespace.weight_quant_granularity != 'per_group': - is_valid = False - if config_namespace.input_bit_width is not None: - is_valid = False - if config_namespace.quantize_weight_zero_point: - is_valid = False - if config_namespace.export_target == 'sharded_packed_torchmlir_group_weight': - if config_namespace.weight_quant_granularity != 'per_group': - is_valid = False - if config_namespace.input_bit_width is not None: - is_valid = False - if config_namespace.quantize_weight_zero_point: - is_valid = False - if config_namespace.export_target == 'onnx_qcdq': - if config_namespace.weight_quant_granularity == 'per_group': - if config_namespace.input_bit_width is not None: - is_valid = False - if config_namespace.weight_quant_type == 'asym': - if not config_namespace.quantize_weight_zero_point: - is_valid = False - if config_namespace.input_bit_width is not None and config_namespace.input_quant_type == 'asym': - if not config_namespace.quantize_input_zero_point: - is_valid = False - if config_namespace.export_target == 'torch_qcdq': - if config_namespace.weight_quant_granularity == 'per_group': - is_valid = False - if config_namespace.weight_quant_type == 'asym': - if not config_namespace.quantize_weight_zero_point: - is_valid = False - if config_namespace.input_bit_width is not None and config_namespace.input_quant_type == 'asym': - if not config_namespace.quantize_input_zero_point: - is_valid = False - if config_namespace.input_bit_width and config_namespace.input_scale_type == 'static': - if not config_namespace.act_calibration: - is_valid = False - if (config_namespace.weight_equalization or config_namespace.act_equalization == 'fx'): - if config_namespace.replace_mha: - if config_namespace.export_target == 'onnx_qcdq': - is_valid = False - else: - if config_namespace.export_target == 'torch_qcdq': - is_valid = False - - config_namespace.is_valid = is_valid - return config_namespace + # Results are saved in CSV + column_names = [k.replace('_', ' ').capitalize() for k in config_namespace.__dict__.keys() + ] + [ + 'FP perplexity', 'Quant perplexity', 'Torch version', 'Brevitas version'] + values = [v for _, v in config_namespace.__dict__.items()] + [ + float_ppl, quant_ppl, torch_version, brevitas_version] + llm_df = pd.DataFrame([values], columns=column_names) + + folder = './multirun/' + str(args.idx) + os.makedirs(folder, exist_ok=True) + llm_df.to_csv(os.path.join(folder, 'RESULTS_LLM.csv'), index=False) + + except Exception as E: + print("Exception at index {}: {}".format(args.idx, E)) if __name__ == '__main__': diff --git a/tests/brevitas/optim/test_sign_sgd.py b/tests/brevitas/optim/test_sign_sgd.py index 0bba58151..d4a7a6424 100644 --- a/tests/brevitas/optim/test_sign_sgd.py +++ b/tests/brevitas/optim/test_sign_sgd.py @@ -1,5 +1,44 @@ -# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU, + NEC Laboratories America and IDIAP Research Institute nor the names + of its contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. +""" import math import sys From d6ac8fd4681fe549be2d295efb1ff31d5f5092e7 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 5 Nov 2024 08:49:46 +0000 Subject: [PATCH 09/48] Minor license change --- src/brevitas/optim/sign_sgd.py | 42 ++++++++++++++++++ .../llm/benchmark/llm_benchmark.py | 43 +------------------ 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/src/brevitas/optim/sign_sgd.py b/src/brevitas/optim/sign_sgd.py index c34279b81..91a11476f 100644 --- a/src/brevitas/optim/sign_sgd.py +++ b/src/brevitas/optim/sign_sgd.py @@ -1,3 +1,45 @@ +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU, + NEC Laboratories America and IDIAP Research Institute nor the names + of its contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. +""" + # mypy: allow-untyped-defs from typing import List, Optional diff --git a/src/brevitas_examples/llm/benchmark/llm_benchmark.py b/src/brevitas_examples/llm/benchmark/llm_benchmark.py index 19efbafc6..043e9a683 100644 --- a/src/brevitas_examples/llm/benchmark/llm_benchmark.py +++ b/src/brevitas_examples/llm/benchmark/llm_benchmark.py @@ -1,44 +1,5 @@ -""" -Copyright (C) 2024, Advanced Micro Devices, Inc. -Copyright (c) 2016- Facebook, Inc (Adam Paszke) -Copyright (c) 2014- Facebook, Inc (Soumith Chintala) -Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) -Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) -Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) -Copyright (c) 2011-2013 NYU (Clement Farabet) -Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) -Copyright (c) 2006 Idiap Research Institute (Samy Bengio) -Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) - -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU, - NEC Laboratories America and IDIAP Research Institute nor the names - of its contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. -""" +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause import argparse from functools import partial From 983eef29ba59c4790f92d042e099a47bbf48d858 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 5 Nov 2024 09:55:19 +0000 Subject: [PATCH 10/48] Minor refactor in SignSGD --- src/brevitas/optim/sign_sgd.py | 98 ++++------------------------------ 1 file changed, 10 insertions(+), 88 deletions(-) diff --git a/src/brevitas/optim/sign_sgd.py b/src/brevitas/optim/sign_sgd.py index 91a11476f..e9a0a1223 100644 --- a/src/brevitas/optim/sign_sgd.py +++ b/src/brevitas/optim/sign_sgd.py @@ -52,12 +52,13 @@ from torch.optim.optimizer import _maximize_doc from torch.optim.optimizer import _use_grad_for_differentiable from torch.optim.optimizer import Optimizer +from torch.optim.sgd import SGD from torch.utils._foreach_utils import _get_fused_kernels_supported_devices __all__ = ["SignSGD", "sign_sgd"] -class SignSGD(Optimizer): +class SignSGD(SGD): def __init__( self, @@ -73,14 +74,8 @@ def __init__( differentiable: bool = False, fused: Optional[bool] = None, ): - if lr < 0.0: - raise ValueError(f"Invalid learning rate: {lr}") - if momentum < 0.0: - raise ValueError(f"Invalid momentum value: {momentum}") - if weight_decay < 0.0: - raise ValueError(f"Invalid weight_decay value: {weight_decay}") - - defaults = dict( + super().__init__( + params=params, lr=lr, momentum=momentum, dampening=dampening, @@ -91,49 +86,6 @@ def __init__( differentiable=differentiable, fused=fused, ) - if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError("Nesterov momentum requires a momentum and zero dampening") - super().__init__(params, defaults) - - if fused: - self._step_supports_amp_scaling = True - - fused_supported_devices = _get_fused_kernels_supported_devices() - if not all(p.device.type in fused_supported_devices and torch.is_floating_point(p) - for pg in self.param_groups - for p in pg["params"]): - raise RuntimeError( - "`fused=True` requires all the params to be floating point Tensors of " - f"supported devices: {fused_supported_devices}.") - if differentiable: - raise RuntimeError("`fused` does not support `differentiable`") - if foreach: - raise RuntimeError("`fused` and `foreach` cannot be `True` together.") - - def __setstate__(self, state): - super().__setstate__(state) - for group in self.param_groups: - group.setdefault("nesterov", False) - group.setdefault("maximize", False) - group.setdefault("foreach", None) - group.setdefault("differentiable", False) - group.setdefault("fused", False) - - def _init_group(self, group, params, grads, momentum_buffer_list): - has_sparse_grad = False - - for p in group["params"]: - if p.grad is not None: - params.append(p) - grads.append(p.grad) - if p.grad.is_sparse: - has_sparse_grad = True - - if group["momentum"] != 0: - state = self.state[p] - momentum_buffer_list.append(state.get("momentum_buffer")) - - return has_sparse_grad @_use_grad_for_differentiable def step(self, closure=None): @@ -182,7 +134,7 @@ def step(self, closure=None): SignSGD.__doc__ = ( - r"""Implements stochastic gradient descent (optionally with momentum). + r"""Implements signed stochastic gradient descent (optionally with momentum). .. math:: \begin{aligned} @@ -206,9 +158,9 @@ def step(self, closure=None): &\hspace{10mm}\textbf{else} \\[-1.ex] &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\ &\hspace{5mm}\textbf{if} \: \textit{maximize} \\ - &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex] + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma \text{sign}(g_t) \\[-1.ex] &\hspace{5mm}\textbf{else} \\[-1.ex] - &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex] + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \text{sign}(g_t) \\[-1.ex] &\rule{110mm}{0.4pt} \\[-1.ex] &\bf{return} \: \theta_t \\[-1.ex] &\rule{110mm}{0.4pt} \\[-1.ex] @@ -233,43 +185,13 @@ def step(self, closure=None): Example: >>> # xdoctest: +SKIP - >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> optimizer = torch.optim.SignSGD(model.parameters(), lr=0.1, momentum=0.9) >>> optimizer.zero_grad() >>> loss_fn(model(input), target).backward() >>> optimizer.step() __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf - .. note:: - The implementation of SGD with Momentum/Nesterov subtly differs from - Sutskever et al. and implementations in some other frameworks. - - Considering the specific case of Momentum, the update can be written as - - .. math:: - \begin{aligned} - v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ - p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, - \end{aligned} - - where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the - parameters, gradient, velocity, and momentum respectively. - - This is in contrast to Sutskever et al. and - other frameworks which employ an update of the form - - .. math:: - \begin{aligned} - v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ - p_{t+1} & = p_{t} - v_{t+1}. - \end{aligned} - - The Nesterov version is analogously modified. - - Moreover, the initial value of the momentum buffer is set to the - gradient value at the first step. This is in contrast to some other - frameworks that initialize it to all zeros. - """) @@ -292,9 +214,9 @@ def sign_sgd( nesterov: bool, maximize: bool, ): - r"""Functional API that performs SGD algorithm computation. + r"""Functional API that performs Sign SGD algorithm computation. - See :class:`~torch.optim.SGD` for details. + See :class:`~torch.optim.SignSGD` for details. """ # Respect when the user inputs False/True for foreach or fused. We only want to change From 401f176d055e4824f3bfc6eba61b94758b1f9841 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 5 Nov 2024 10:18:05 +0000 Subject: [PATCH 11/48] Include appropiate licensing --- .../learned_round/learned_round_optimizer.py | 92 +++++++++++-------- 1 file changed, 54 insertions(+), 38 deletions(-) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index f80cedb67..2ba90fded 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -1,5 +1,23 @@ -# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Adapted from https://github.com/intel/auto-round, released under the following LICENSE: + +Copyright (c) 2023 Intel Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from abc import ABC from abc import abstractmethod @@ -23,23 +41,25 @@ config.IGNORE_MISSING_KEYS = True + def get_blocks(model: nn.Module, block_check_fn: Callable[[nn.Module, str], bool]) -> List[nn.Module]: - blocks = [] - - # Iterating over .modules() might have been more readable but - # with this recursive implementation, once a block is reached, - # its subtree of modules is not expanded. - def _get_blocks(module: nn.Module): - for module_name, module_child in module.named_children(): - if block_check_fn(module_child, module_name): - blocks.append(module_child) - else: - _get_blocks(module_child) + blocks = [] + + # Iterating over .modules() might have been more readable but + # with this recursive implementation, once a block is reached, + # its subtree of modules is not expanded. + def _get_blocks(module: nn.Module): + for module_name, module_child in module.named_children(): + if block_check_fn(module_child, module_name): + blocks.append(module_child) + else: + _get_blocks(module_child) + + # Run recursive function that updates the list blocks + _get_blocks(model) + return blocks - # Run recursive function that updates the list blocks - _get_blocks(model) - return blocks class LearnedRoundModelUtils(ABC): @@ -99,6 +119,7 @@ def loss_scaler( ) -> torch.Tensor: pass + class LearnedRoundOptimizer: def __init__( @@ -114,17 +135,15 @@ def __init__( use_amp: bool = True, amp_dtype: torch.dtype = torch.float16, optimizer_kwargs: Dict = {}, - lr_scheduler_kwargs : Dict = { + lr_scheduler_kwargs: Dict = { "start_factor": 1.0, "end_factor": 0.0, - "verbose": False, - } + "verbose": False,} ) -> None: if learned_round.iters != iters: warnings.warn( "The number of iterations passed to the learned round optimiser is different " - "to that of the learned round method, which might lead to unexpected behaviour." - ) + "to that of the learned round method, which might lead to unexpected behaviour.") self.learned_round = learned_round self.learned_round_utils = learned_round_utils self.optimizer_class = optimizer_class @@ -166,12 +185,11 @@ def _step(self, optimizer: Optimizer, lr_scheduler: LRScheduler) -> None: lr_scheduler.step() def apply_learned_round( - self, - model: nn.Module, - data_loader: DataLoader, - block_check_fn: Callable = None, - keep_gpu: bool = True - ) -> None: + self, + model: nn.Module, + data_loader: DataLoader, + block_check_fn: Callable = None, + keep_gpu: bool = True) -> None: # Prepare model for optimization self.learned_round_utils.init_model_learned_round(model) @@ -187,7 +205,7 @@ def apply_learned_round( # Loop across blocks to optimise rounding within each for block_idx, (block, block_loss, block_learned_round_modules) in enumerate( - self.learned_round.learned_round_iterator(blocks)): + self.learned_round.learned_round_iterator(blocks)): # Block needs to be in eval mode while the rounding is optimised block.eval() @@ -196,17 +214,13 @@ def apply_learned_round( itertools.chain( *[ learned_round_module.parameters() - for learned_round_module in block_learned_round_modules - ] - ), + for learned_round_module in block_learned_round_modules]), lr=self.optimizer_lr, **self.optimizer_kwargs, ) lr_scheduler = ( self.lr_scheduler_class(optimizer, **self.lr_scheduler_kwargs) - if self.lr_scheduler_class - else None - ) + if self.lr_scheduler_class else None) # Variables needed for printing best_loss = torch.finfo(torch.float).max @@ -235,7 +249,8 @@ def apply_learned_round( quant_outs = self.learned_round_utils.run_forward(block, inputs) if self.use_amp: - with autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=self.amp_dtype): + with autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", + dtype=self.amp_dtype): loss, loss_components = block_loss(quant_outs, fp_outs) else: loss, loss_components = block_loss(quant_outs.to(torch.float32), fp_outs.to(torch.float32)) @@ -252,10 +267,11 @@ def apply_learned_round( self._scale_loss_and_backward(loss) self._step(optimizer, lr_scheduler) - # Update progress bar + # Update progress bar pbar.set_description( "Block = {:d}/{:d}, {}".format( - block_idx + 1, len(blocks), + block_idx + 1, + len(blocks), block_loss.format_loss_components(*loss_components))) pbar.update(1) @@ -268,7 +284,7 @@ def apply_learned_round( print( f"Quantized block {block_idx+1}/{len(blocks)}, " - f"loss iter 0: {init_loss:.6f} -> iter {last_best_iter}: {best_loss:.6f}" + f"initial loss: {init_loss:.6f}, best loss: {best_loss:.6f}, at iteration {last_best_iter}." ) # Finish optimisation From a65058942046971b0151e9889c01b97c963c8787 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 5 Nov 2024 17:30:03 +0000 Subject: [PATCH 12/48] Address comments --- .../core/function_wrapper/__init__.py | 2 +- .../core/function_wrapper/auto_round.py | 48 ----- .../core/function_wrapper/learned_round.py | 35 ++++ src/brevitas/quant/solver/common.py | 1 + .../learned_round/learned_round_method.py | 18 +- .../learned_round/learned_round_optimizer.py | 189 ++++++++++++++++-- .../llm/benchmark/llm_benchmark.py | 5 +- .../llm/llm_quant/learned_round_utils.py | 6 +- tests/brevitas/core/test_float_to_int.py | 2 +- .../test_learned_round_utils.py | 3 +- 10 files changed, 232 insertions(+), 77 deletions(-) delete mode 100644 src/brevitas/core/function_wrapper/auto_round.py diff --git a/src/brevitas/core/function_wrapper/__init__.py b/src/brevitas/core/function_wrapper/__init__.py index 4929026f3..d9aafa978 100644 --- a/src/brevitas/core/function_wrapper/__init__.py +++ b/src/brevitas/core/function_wrapper/__init__.py @@ -1,11 +1,11 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from .auto_round import AutoRoundSte from .clamp import ClampMin from .clamp import FloatClamp from .clamp import ScalarClamp from .clamp import TensorClamp +from .learned_round import AutoRoundSte from .learned_round import LearnedRoundSte from .misc import Identity from .misc import InplaceLogTwo diff --git a/src/brevitas/core/function_wrapper/auto_round.py b/src/brevitas/core/function_wrapper/auto_round.py deleted file mode 100644 index 7f1688291..000000000 --- a/src/brevitas/core/function_wrapper/auto_round.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -""" -Implementation of AutoRound -""" - -from typing import Optional - -import torch - -import brevitas -from brevitas import config -from brevitas.core.utils import SliceTensor -from brevitas.function.ops_ste import round_ste - - -class AutoRoundSte(brevitas.jit.ScriptModule): - """ - This Module implements AutoRound representation, where each weight has a learnable parameter - that decides if "ceil" or "floor" rounding type has to be used. - """ - - def __init__( - self, - learned_round_init: torch.Tensor, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None) -> None: - super(AutoRoundSte, self).__init__() - learned_round_init = learned_round_init.to(device=device, dtype=dtype) - self.tensor_slicer = SliceTensor() - self.value = torch.nn.Parameter(learned_round_init) - - @brevitas.jit.script_method - def forward(self, x: torch.Tensor) -> torch.Tensor: - # p should be between [-0.5, 0.5], so this learnable parameter decides whether to "ceil" or "floor" - p = self.value - p = self.tensor_slicer(p) - return round_ste(x + (p.to(x.dtype)).view_as(x)) - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - super(AutoRoundSte, self)._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - value_key = prefix + 'value' - if config.IGNORE_MISSING_KEYS and value_key in missing_keys: - missing_keys.remove(value_key) diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py index 55ef86a31..2ed008929 100644 --- a/src/brevitas/core/function_wrapper/learned_round.py +++ b/src/brevitas/core/function_wrapper/learned_round.py @@ -13,6 +13,7 @@ from brevitas import config from brevitas.core.utils import SliceTensor from brevitas.function.ops_ste import floor_ste +from brevitas.function.ops_ste import round_ste class LearnedRoundHardSigmoid(brevitas.jit.ScriptModule): @@ -52,6 +53,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return p +# TODO: Change name to AdaRoundSte for consistency class LearnedRoundSte(brevitas.jit.ScriptModule): """ This Module implements LearnedRound representation, where each weight has a learnable parameter @@ -92,3 +94,36 @@ def _load_from_state_dict( value_key = prefix + 'value' if config.IGNORE_MISSING_KEYS and value_key in missing_keys: missing_keys.remove(value_key) + + +class AutoRoundSte(brevitas.jit.ScriptModule): + """ + This Module implements AutoRound representation, where each weight has a learnable parameter + that decides if "ceil" or "floor" rounding type has to be used. + """ + + def __init__( + self, + learned_round_init: torch.Tensor, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None) -> None: + super(AutoRoundSte, self).__init__() + learned_round_init = learned_round_init.to(device=device, dtype=dtype) + self.tensor_slicer = SliceTensor() + self.value = torch.nn.Parameter(learned_round_init) + + @brevitas.jit.script_method + def forward(self, x: torch.Tensor) -> torch.Tensor: + # p should be between [-0.5, 0.5], so this learnable parameter decides whether to "ceil" or "floor" + p = self.value + p = self.tensor_slicer(p) + return round_ste(x + (p.to(x.dtype)).view_as(x)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + super(AutoRoundSte, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + value_key = prefix + 'value' + if config.IGNORE_MISSING_KEYS and value_key in missing_keys: + missing_keys.remove(value_key) diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 509599764..568505b19 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -3,6 +3,7 @@ from brevitas.core.bit_width import * from brevitas.core.function_wrapper import * +from brevitas.core.function_wrapper.learned_round import AutoRoundSte from brevitas.core.function_wrapper.learned_round import LearnedRoundHardSigmoid from brevitas.core.function_wrapper.learned_round import LearnedRoundSigmoid from brevitas.core.function_wrapper.learned_round import LearnedRoundSte diff --git a/src/brevitas_examples/common/learned_round/learned_round_method.py b/src/brevitas_examples/common/learned_round/learned_round_method.py index 07316ef72..51a6c190d 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_method.py +++ b/src/brevitas_examples/common/learned_round/learned_round_method.py @@ -9,13 +9,18 @@ from torch import nn import torch.nn.functional as F -from brevitas.core.function_wrapper.auto_round import AutoRoundSte +from brevitas.core.function_wrapper.learned_round import AutoRoundSte from brevitas.core.function_wrapper.learned_round import LearnedRoundSte from brevitas.inject.enum import FloatToIntImplType from brevitas.inject.enum import LearnedRoundImplType from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL +class StopFwdException(Exception): + """Used to throw and catch an exception to stop traversing the graph.""" + pass + + class LearnedRoundLoss(ABC): @abstractmethod @@ -26,6 +31,7 @@ def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, def format_loss_components(self, *args) -> str: pass + class LearnedRound(ABC): def __init__(self, iters: int = 200, **kwargs) -> None: @@ -76,6 +82,7 @@ def learned_round_iterator( block_loss = self._instantiate_loss(block, learned_round_modules) yield block, block_loss, learned_round_modules + class LinearTempDecay: def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2): @@ -91,6 +98,7 @@ def __call__(self, t): rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t)) + class AdaRoundLoss(LearnedRoundLoss): def __init__( @@ -101,8 +109,7 @@ def __init__( max_count: int = 1000, b_range: Tuple = (20, 2), warmup: float = 0.2, - decay_start: float = 0.0 - ) -> None: + decay_start: float = 0.0) -> None: super().__init__() # AdaRound operates in a layer-wise manner, so integrity needs to be checked assert isinstance(module, QuantWBIOL), "AdaRound can only accept a single QuantWBIOL layer." @@ -139,6 +146,7 @@ def format_loss_components(self, loss: float, rec_loss: float, round_loss: float return "loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format( loss, rec_loss, round_loss, b) + class AdaRound(LearnedRound): def __init__( @@ -149,7 +157,7 @@ def __init__( warmup: float = 0.2, decay_start: float = 0.0, **kwargs, - ) -> None: + ) -> None: super().__init__(iters, **kwargs) # Loss-related configuration self.weight = weight @@ -187,6 +195,7 @@ def _instantiate_loss( decay_start=self.decay_start, ) + class AutoRoundLoss(LearnedRoundLoss): def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: @@ -196,6 +205,7 @@ def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, def format_loss_components(self, loss: float) -> str: return "loss = {:.4f}".format(loss) + class AutoRound(LearnedRound): def __init__(self, **kwargs) -> None: diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index 2ba90fded..b754784cd 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -4,19 +4,182 @@ Adapted from https://github.com/intel/auto-round, released under the following LICENSE: -Copyright (c) 2023 Intel Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS """ from abc import ABC diff --git a/src/brevitas_examples/llm/benchmark/llm_benchmark.py b/src/brevitas_examples/llm/benchmark/llm_benchmark.py index 043e9a683..dec0e81c3 100644 --- a/src/brevitas_examples/llm/benchmark/llm_benchmark.py +++ b/src/brevitas_examples/llm/benchmark/llm_benchmark.py @@ -54,13 +54,12 @@ def unique(sequence): return [x for x in sequence if not (x in seen or seen.add(x))] -# Torchvision models with top1 accuracy -LLM_TOP1_MAP = { +LLM_PPL_MAP = { 'facebook/opt-125m': None, 'meta-llama/Llama-2-7b-hf': None,} OPTIONS_DEFAULT = { - 'model': list(LLM_TOP1_MAP.keys()), # HF model name. Default: facebook/opt-125m. + 'model': list(LLM_PPL_MAP.keys()), # HF model name. Default: facebook/opt-125m. 'seed': [0], # Seed for sampling the calibration data. Default: 0. 'nsamples': [128], # Number of calibration data samples. Default: 128. 'seqlen': [2048], # Sequence length. Default: 2048. diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index 05ab8c191..cb976ce57 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -13,14 +13,10 @@ from brevitas.graph.calibrate import disable_return_quant_tensor from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.calibrate import restore_return_quant_tensor +from brevitas_examples.common.learned_round.learned_round_method import StopFwdException from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundModelUtils -class StopFwdException(Exception): - """Used to throw and catch an exception to stop traversing the graph.""" - pass - - class LearnedRoundLLMUtils(LearnedRoundModelUtils): def __init__(self) -> None: diff --git a/tests/brevitas/core/test_float_to_int.py b/tests/brevitas/core/test_float_to_int.py index 1bcb79e47..0ec7f10e6 100644 --- a/tests/brevitas/core/test_float_to_int.py +++ b/tests/brevitas/core/test_float_to_int.py @@ -6,7 +6,7 @@ import torch from brevitas import config -from brevitas.core.function_wrapper.auto_round import AutoRoundSte +from brevitas.core.function_wrapper.learned_round import AutoRoundSte from brevitas.core.function_wrapper.learned_round import LearnedRoundHardSigmoid from brevitas.core.function_wrapper.learned_round import LearnedRoundSigmoid from brevitas.core.function_wrapper.learned_round import LearnedRoundSte diff --git a/tests/brevitas_examples/test_learned_round_utils.py b/tests/brevitas_examples/test_learned_round_utils.py index 8999d4125..6c1deb143 100644 --- a/tests/brevitas_examples/test_learned_round_utils.py +++ b/tests/brevitas_examples/test_learned_round_utils.py @@ -9,9 +9,8 @@ from torch.utils.data import Dataset from brevitas import config -from brevitas.core.function_wrapper.auto_round import AutoRoundSte +from brevitas.core.function_wrapper.learned_round import AutoRoundSte from brevitas.core.function_wrapper.learned_round import LearnedRoundSte -from brevitas.inject.enum import FloatToIntImplType import brevitas.nn as qnn from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.quant_tensor.base_quant_tensor import QuantTensor From 806661df201b5b4840cade32da1590c53b64e5c0 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Wed, 6 Nov 2024 10:35:02 +0000 Subject: [PATCH 13/48] Add missing change --- .../imagenet_classification/ptq/learned_round_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 9fb51a17a..5abe912c5 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -40,16 +40,12 @@ from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.quant_tensor import QuantTensor +from brevitas_examples.common.learned_round.learned_round_method import StopFwdException from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundModelUtils config.IGNORE_MISSING_KEYS = True -class StopFwdException(Exception): - """Used to throw and catch an exception to stop traversing the graph.""" - pass - - class LearnedRoundVisionUtils(LearnedRoundModelUtils): def __init__(self) -> None: From 2de7515117a59840098c597abb96d75795a3d8bc Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Wed, 6 Nov 2024 17:15:53 +0000 Subject: [PATCH 14/48] Fix progress bar --- .../common/learned_round/learned_round_optimizer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index b754784cd..7e9200959 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -437,6 +437,8 @@ def apply_learned_round( len(blocks), block_loss.format_loss_components(*loss_components))) pbar.update(1) + # Make sure no updates are received in the progress bar + pbar.close() if self.use_best_model: self._load_round_params(block, optimal_rounding_params) From ab97290ba25421d634f026aff0b81c638977181c Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 7 Nov 2024 12:16:05 +0000 Subject: [PATCH 15/48] Minor improvements --- .../learned_round/learned_round_builder.py | 52 +++++++++++++++ .../learned_round/learned_round_method.py | 39 +++++++----- .../learned_round/learned_round_optimizer.py | 11 +++- .../ptq/ptq_evaluate.py | 63 ++++++++++--------- .../llm/llm_quant/learned_round_utils.py | 5 +- src/brevitas_examples/llm/main.py | 14 ++--- tests/brevitas_examples/test_imagenet.py | 3 + .../test_learned_round_utils.py | 3 + 8 files changed, 131 insertions(+), 59 deletions(-) create mode 100644 src/brevitas_examples/common/learned_round/learned_round_builder.py diff --git a/src/brevitas_examples/common/learned_round/learned_round_builder.py b/src/brevitas_examples/common/learned_round/learned_round_builder.py new file mode 100644 index 000000000..b8504355d --- /dev/null +++ b/src/brevitas_examples/common/learned_round/learned_round_builder.py @@ -0,0 +1,52 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Dict + +from brevitas_examples.common.learned_round.learned_round_method import AdaRound +from brevitas_examples.common.learned_round.learned_round_method import AutoRound +from brevitas_examples.common.learned_round.learned_round_method import LearnedRound +from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundModelUtils +from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import \ + LearnedRoundVisionUtils +from brevitas_examples.llm.llm_quant.learned_round_utils import LearnedRoundLLMUtils + + +def solve_learned_round_utils_cls(utils_type) -> LearnedRoundModelUtils: + if utils_type == "imagenet_classification": + return LearnedRoundVisionUtils + elif utils_type == "llm": + return LearnedRoundLLMUtils + else: + raise Exception(f"Learned round utilities for {utils_type} are not recognized.") + + +def solve_learned_round_method_cls(method_type) -> LearnedRound: + if method_type == "ada_round": + return AdaRound + elif method_type == "auto_round": + return AutoRound + else: + raise Exception(f"Learned round method {method_type} is not available.") + + +def instantiate_learned_round_optimizer( + utils_type: str, + method_type: str = "auto_round", + iters: int = 200, + method_params: Dict = {}, + optimizer_params: Dict = {}, + utils_params: Dict = {}) -> LearnedRoundOptimizer: + # Instantiate learned round utilities + learned_round_utils_cls = solve_learned_round_utils_cls(utils_type) + learned_round_utils = learned_round_utils_cls(**utils_params) + + # Instantiate learned round method + learned_round_method_cls = solve_learned_round_method_cls(method_type) + learned_round_method = learned_round_method_cls(iters, **method_params) + + # Make sure that the iterations of the learned round method and optimizer match + optimizer_params["iters"] = iters + # Instantiate optimizer + return LearnedRoundOptimizer(learned_round_method, learned_round_utils, **optimizer_params) diff --git a/src/brevitas_examples/common/learned_round/learned_round_method.py b/src/brevitas_examples/common/learned_round/learned_round_method.py index 51a6c190d..7c770b0cd 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_method.py +++ b/src/brevitas_examples/common/learned_round/learned_round_method.py @@ -37,12 +37,15 @@ class LearnedRound(ABC): def __init__(self, iters: int = 200, **kwargs) -> None: self.iters = iters - def _insert_learned_round_quantizer(self, block: nn.Module) -> None: + def _insert_and_return_learned_round_quantizers(self, block: nn.Module) -> List[nn.Module]: + round_modules = [] for module in block.modules(): if isinstance(module, QuantWBIOL) and len( self._find_learned_round_modules(module)) == 0: self._insert_learned_round_quantizer_to_layer(module) module.weight_quant.init_tensor_quant(preserve_state_dict=True) + round_modules.append(module.weight_quant.tensor_quant.int_quant.float_to_int_impl) + return round_modules @abstractmethod def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: @@ -69,17 +72,17 @@ def learned_round_iterator( blocks: List[nn.Module]) -> Generator[nn.Module, LearnedRoundLoss, List[nn.Module]]: for block in blocks: # Insert learned round quantizers into the appropiate submodules - self._insert_learned_round_quantizer(block) + learned_round_modules = self._insert_and_return_learned_round_quantizers(block) # Freeze block parameters for params in block.parameters(): params.requires_grad = False - # Retrieve learned round modules - learned_round_modules = self._find_learned_round_modules(block) # Enable gradient tracking in learned round modules for round_module in learned_round_modules: for params in round_module.parameters(): params.requires_grad = True block_loss = self._instantiate_loss(block, learned_round_modules) + # Block needs to be in eval mode while the rounding is optimised + block.eval() yield block, block_loss, learned_round_modules @@ -152,6 +155,10 @@ class AdaRound(LearnedRound): def __init__( self, iters: int = 200, + *, + learned_round_zeta: float = 1.1, + learned_round_gamma: float = -0.1, + learned_round_impl_type: LearnedRoundImplType = LearnedRoundImplType.HARD_SIGMOID, weight: float = 0.01, b_range: Tuple = (20, 2), warmup: float = 0.2, @@ -159,6 +166,10 @@ def __init__( **kwargs, ) -> None: super().__init__(iters, **kwargs) + # Quantiser-related configuration + self.learned_round_zeta = learned_round_zeta + self.learned_round_gamma = learned_round_gamma + self.learned_round_impl_type = learned_round_impl_type # Loss-related configuration self.weight = weight self.b_range = b_range @@ -168,20 +179,16 @@ def __init__( def _is_learned_round_module(self, module: nn.Module) -> bool: return isinstance(module, LearnedRoundSte) - def _insert_learned_round_quantizer_to_layer( - self, - layer: nn.Module, - learned_round_zeta: float = 1.1, - learned_round_gamma: float = -0.1) -> None: + def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale) delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight - value = -torch.log((learned_round_zeta - learned_round_gamma) / - (delta - learned_round_gamma) - 1) + value = -torch.log((self.learned_round_zeta - self.learned_round_gamma) / + (delta - self.learned_round_gamma) - 1) layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND, - learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID, - learned_round_gamma=learned_round_gamma, - learned_round_zeta=learned_round_zeta, + learned_round_impl_type=self.learned_round_impl_type, + learned_round_gamma=self.learned_round_gamma, + learned_round_zeta=self.learned_round_zeta, learned_round_init=value) def _instantiate_loss( @@ -208,8 +215,8 @@ def format_loss_components(self, loss: float) -> str: class AutoRound(LearnedRound): - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) + def __init__(self, iters: int = 200, **kwargs) -> None: + super().__init__(iters, **kwargs) def _is_learned_round_module(self, module: nn.Module) -> bool: return isinstance(module, AutoRoundSte) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index 7e9200959..16b97f806 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -289,6 +289,7 @@ def __init__( self, learned_round: LearnedRound, learned_round_utils: LearnedRoundModelUtils, + *, optimizer_class: Optimizer = SignSGD, lr_scheduler_class: LRScheduler = LinearLR, optimizer_lr: float = 5e-3, @@ -369,9 +370,6 @@ def apply_learned_round( # Loop across blocks to optimise rounding within each for block_idx, (block, block_loss, block_learned_round_modules) in enumerate( self.learned_round.learned_round_iterator(blocks)): - # Block needs to be in eval mode while the rounding is optimised - block.eval() - # Initialise optimiser and LR scheduler optimizer = self.optimizer_class( itertools.chain( @@ -401,6 +399,9 @@ def apply_learned_round( data_loader, keep_gpu=keep_gpu, ) + # Enable training model in quantizer modules + for learned_round_module in block_learned_round_modules: + learned_round_module.train() pbar = tqdm(range(self.iters), desc='') for i in pbar: @@ -440,6 +441,10 @@ def apply_learned_round( # Make sure no updates are received in the progress bar pbar.close() + # Set back quantizers to eval mode + for learned_round_module in block_learned_round_modules: + learned_round_module.eval() + if self.use_best_model: self._load_round_params(block, optimal_rounding_params) else: diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 240d08756..fb2a06e79 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -22,6 +22,8 @@ from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas.optim.sign_sgd import SignSGD +from brevitas_examples.common.learned_round.learned_round_builder import \ + instantiate_learned_round_optimizer from brevitas_examples.common.learned_round.learned_round_method import AdaRound from brevitas_examples.common.learned_round.learned_round_method import AutoRound from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer @@ -166,9 +168,10 @@ def validate_args(args): type=int, help='Numbers of iterations for graph equalization (default: 20)') parser.add_argument( - '--learned-round-type', - default='none', - choices=['none', 'ada_round', 'auto_round'], + '--learned-round', + default=None, + type=str, + choices=[None, 'ada_round', 'auto_round'], help='Learned round type (default: none)') parser.add_argument( '--learned-round-mode', @@ -347,7 +350,7 @@ def main(): f"{'gptq_' if args.gptq else ''}" f"{'gpfq_' if args.gpfq else ''}" f"{'gpxq_act_order_' if args.gpxq_act_order else ''}" - f"{'learned_round_type' if args.learned_round_type != "none" else ''}" + f"{'learned_round' if args.learned_round is not None else ''}" f"{'weight_narrow_range_' if args.weight_narrow_range else ''}" f"{args.bias_bit_width}bias_" f"{args.weight_quant_granularity}_" @@ -369,7 +372,8 @@ def main(): f"GPTQ: {args.gptq} - " f"GPFQ: {args.gpfq} - " f"GPxQ Act Order: {args.gpxq_act_order} - " - f"Learned Round type: {args.learned_round_type} - " + f"GPxQ Accumulator Bit Width: {args.gpxq_accumulator_bit_width} - " + f"Learned Round method: {args.learned_round} - " f"Weight narrow range: {args.weight_narrow_range} - " f"Bias bit width: {args.bias_bit_width} - " f"Weight scale factors type: {args.weight_quant_granularity} - " @@ -423,7 +427,7 @@ def main(): equalize_merge_bias=args.graph_eq_merge_bias, merge_bn=not args.calibrate_bn) elif args.target_backend == 'fx' or args.target_backend == 'layerwise': - if args.learned_round_type != "auto_round": + if args.learned_round != "auto_round": model = preprocess_for_quantize( model, equalize_iters=args.graph_eq_iterations, @@ -434,11 +438,7 @@ def main(): else: raise RuntimeError(f"{args.target_backend} backend not supported.") - device = ( - torch.device(f"cuda:{args.gpu}") - if args.gpu is not None - else torch.device("cpu") - ) + device = (torch.device(f"cuda:{args.gpu}") if args.gpu is not None else torch.device("cpu")) model = model.to(device=device) # If available, use the selected GPU if args.gpu is not None: @@ -505,13 +505,9 @@ def main(): max_accumulator_bit_width=args.gpxq_accumulator_bit_width, max_accumulator_tile_size=args.gpxq_accumulator_tile_size) - if args.learned_round_type != "none": - # Initialisation of rounding method - if args.learned_round_type =="auto_round": - learned_round = AutoRound(iters=args.learned_round_iters) - elif args.learned_round_type == "ada_round": - learned_round = AdaRound(iters=args.learned_round_iters) - # Optimizer to tune the + if args.learned_round: + print("Applying Learned Round:") + # Optimizer to tune the rounding if args.optimizer == "adam": optimizer_class = torch.optim.Adam elif args.optimizer == "sign_sgd": @@ -523,21 +519,26 @@ def main(): block_check_fn = _is_layer elif args.learned_round_mode == "blockwise": block_check_fn = _is_resnet_block - - learned_round_vision_utils = LearnedRoundVisionUtils() - learned_round_optimiser = LearnedRoundOptimizer( - learned_round=learned_round, - learned_round_utils=learned_round_vision_utils, - optimizer_class=optimizer_class, - lr_scheduler_class= None if args.optimizer == "adam" else torch.optim.lr_scheduler.LinearLR, - optimizer_lr=args.learned_round_lr, - batch_size=args.learned_round_batch_size, + # Instantiate optimizer + learned_round_optimizer = instantiate_learned_round_optimizer( + utils_type="imagenet_classification", + method_type=args.learned_round, iters=args.learned_round_iters, - ) - learned_round_optimiser.apply_learned_round( + optimizer_params={ + "optimizer_lr": + args.learned_round_lr, + "optimizer_class": + optimizer_class, + "lr_scheduler_class": + None if args.optimizer == "adam" else torch.optim.lr_scheduler.LinearLR, + "batch_size": + args.learned_round_batch_size, + "use_best_model": + False if args.learned_round == "ada_round" else True,}) + learned_round_optimizer.apply_learned_round( model, - data_loader=calib_loader, - block_check_fn=block_check_fn + calib_loader, + block_check_fn=block_check_fn, ) if args.calibrate_bn: diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index cb976ce57..223035e53 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -19,9 +19,10 @@ class LearnedRoundLLMUtils(LearnedRoundModelUtils): - def __init__(self) -> None: + def __init__(self, loss_scaling_factor: float = 1000.) -> None: super(LearnedRoundLLMUtils, self).__init__() self.llm_cache_state = None + self.loss_scaling_factor = loss_scaling_factor def default_block_check_fn(self, module: nn.Module, module_name: str) -> bool: return isinstance(module, LlamaDecoderLayer) or isinstance(module, OPTDecoderLayer) @@ -210,4 +211,4 @@ def loss_scaler( self, loss: torch.Tensor, ) -> torch.Tensor: - return loss * 1000 + return loss * self.loss_scaling_factor diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index d62212e07..b426c6ecf 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -23,6 +23,8 @@ from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks from brevitas_examples.common.generative.quantize import generate_quant_maps from brevitas_examples.common.generative.quantize import generate_quantizers +from brevitas_examples.common.learned_round.learned_round_builder import \ + instantiate_learned_round_optimizer from brevitas_examples.common.learned_round.learned_round_method import AutoRound from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer from brevitas_examples.common.parse_utils import quant_format_validator @@ -372,13 +374,11 @@ def main(args): if args.learned_round: print("Applying learned round...") - - learned_round_llm_utils = LearnedRoundLLMUtils() - learned_round = AutoRound() - learned_round_optimiser = LearnedRoundOptimizer( - learned_round=learned_round, learned_round_utils=learned_round_llm_utils) - learned_round_optimiser.apply_learned_round(model, calibration_loader) - + learned_round_optimizer = instantiate_learned_round_optimizer( + utils_type="llm", + method_type=args.learned_round, + ) + learned_round_optimizer.apply_learned_round(model, calibration_loader) print("Learned round applied.") if args.act_calibration: diff --git a/tests/brevitas_examples/test_imagenet.py b/tests/brevitas_examples/test_imagenet.py index 4d7afdc7c..e4e117f1a 100644 --- a/tests/brevitas_examples/test_imagenet.py +++ b/tests/brevitas_examples/test_imagenet.py @@ -1,3 +1,6 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + from hypothesis import given import pytest import pytest_cases diff --git a/tests/brevitas_examples/test_learned_round_utils.py b/tests/brevitas_examples/test_learned_round_utils.py index 6c1deb143..8c2a62c14 100644 --- a/tests/brevitas_examples/test_learned_round_utils.py +++ b/tests/brevitas_examples/test_learned_round_utils.py @@ -1,3 +1,6 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + from hypothesis import given import pytest import pytest_cases From ea09841a9e848e96871b1aa9b8149a419fdc259e Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 11 Nov 2024 11:24:17 +0000 Subject: [PATCH 16/48] Remove utils hierarchy --- src/brevitas/optim/sign_sgd.py | 110 ++--- .../learned_round/learned_round_builder.py | 52 --- .../learned_round/learned_round_optimizer.py | 186 ++++++-- .../ptq/learned_round_utils.py | 441 ++++++------------ .../imagenet_classification/ptq/ptq_common.py | 22 - .../ptq/ptq_evaluate.py | 50 +- .../llm/llm_quant/learned_round_utils.py | 307 +++++------- src/brevitas_examples/llm/main.py | 12 +- 8 files changed, 459 insertions(+), 721 deletions(-) delete mode 100644 src/brevitas_examples/common/learned_round/learned_round_builder.py diff --git a/src/brevitas/optim/sign_sgd.py b/src/brevitas/optim/sign_sgd.py index e9a0a1223..bd26b40d4 100644 --- a/src/brevitas/optim/sign_sgd.py +++ b/src/brevitas/optim/sign_sgd.py @@ -59,6 +59,54 @@ class SignSGD(SGD): + """Implements signed stochastic gradient descent (optionally with momentum). + + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta) + \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\ + &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)}, + \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\ + &\hspace{10mm}\textbf{if} \: t > 1 \\ + &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\ + &\hspace{10mm}\textbf{else} \\ + &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\ + &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\ + &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\ + &\hspace{10mm}\textbf{else} \\[-1.ex] + &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma \text{sign}(g_t) \\[-1.ex] + &\hspace{5mm}\textbf{else} \\[-1.ex] + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \text{sign}(g_t) \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + + Example: + >>> optimizer = torch.optim.SignSGD(model.parameters(), lr=0.1, momentum=0.9) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + + """ def __init__( self, @@ -133,68 +181,6 @@ def step(self, closure=None): return loss -SignSGD.__doc__ = ( - r"""Implements signed stochastic gradient descent (optionally with momentum). - - .. math:: - \begin{aligned} - &\rule{110mm}{0.4pt} \\ - &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta) - \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\ - &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)}, - \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex] - &\rule{110mm}{0.4pt} \\ - &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ - &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ - &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ - &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ - &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\ - &\hspace{10mm}\textbf{if} \: t > 1 \\ - &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\ - &\hspace{10mm}\textbf{else} \\ - &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\ - &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\ - &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\ - &\hspace{10mm}\textbf{else} \\[-1.ex] - &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\ - &\hspace{5mm}\textbf{if} \: \textit{maximize} \\ - &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma \text{sign}(g_t) \\[-1.ex] - &\hspace{5mm}\textbf{else} \\[-1.ex] - &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \text{sign}(g_t) \\[-1.ex] - &\rule{110mm}{0.4pt} \\[-1.ex] - &\bf{return} \: \theta_t \\[-1.ex] - &\rule{110mm}{0.4pt} \\[-1.ex] - \end{aligned} - - Nesterov momentum is based on the formula from - `On the importance of initialization and momentum in deep learning`__. - """ + rf""" - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - momentum (float, optional): momentum factor (default: 0) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - dampening (float, optional): dampening for momentum (default: 0) - nesterov (bool, optional): enables Nesterov momentum (default: False) - {_maximize_doc} - {_foreach_doc} - {_differentiable_doc} - {_fused_doc} - """ + r""" - - Example: - >>> # xdoctest: +SKIP - >>> optimizer = torch.optim.SignSGD(model.parameters(), lr=0.1, momentum=0.9) - >>> optimizer.zero_grad() - >>> loss_fn(model(input), target).backward() - >>> optimizer.step() - - __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf - - """) - - def sign_sgd( params: List[Tensor], d_p_list: List[Tensor], diff --git a/src/brevitas_examples/common/learned_round/learned_round_builder.py b/src/brevitas_examples/common/learned_round/learned_round_builder.py deleted file mode 100644 index b8504355d..000000000 --- a/src/brevitas_examples/common/learned_round/learned_round_builder.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -from typing import Dict - -from brevitas_examples.common.learned_round.learned_round_method import AdaRound -from brevitas_examples.common.learned_round.learned_round_method import AutoRound -from brevitas_examples.common.learned_round.learned_round_method import LearnedRound -from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundModelUtils -from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import \ - LearnedRoundVisionUtils -from brevitas_examples.llm.llm_quant.learned_round_utils import LearnedRoundLLMUtils - - -def solve_learned_round_utils_cls(utils_type) -> LearnedRoundModelUtils: - if utils_type == "imagenet_classification": - return LearnedRoundVisionUtils - elif utils_type == "llm": - return LearnedRoundLLMUtils - else: - raise Exception(f"Learned round utilities for {utils_type} are not recognized.") - - -def solve_learned_round_method_cls(method_type) -> LearnedRound: - if method_type == "ada_round": - return AdaRound - elif method_type == "auto_round": - return AutoRound - else: - raise Exception(f"Learned round method {method_type} is not available.") - - -def instantiate_learned_round_optimizer( - utils_type: str, - method_type: str = "auto_round", - iters: int = 200, - method_params: Dict = {}, - optimizer_params: Dict = {}, - utils_params: Dict = {}) -> LearnedRoundOptimizer: - # Instantiate learned round utilities - learned_round_utils_cls = solve_learned_round_utils_cls(utils_type) - learned_round_utils = learned_round_utils_cls(**utils_params) - - # Instantiate learned round method - learned_round_method_cls = solve_learned_round_method_cls(method_type) - learned_round_method = learned_round_method_cls(iters, **method_params) - - # Make sure that the iterations of the learned round method and optimizer match - optimizer_params["iters"] = iters - # Instantiate optimizer - return LearnedRoundOptimizer(learned_round_method, learned_round_utils, **optimizer_params) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index 16b97f806..c18a27c8e 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -186,9 +186,10 @@ from abc import abstractmethod import copy import itertools -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings +from accelerate.utils.operations import send_to_device import torch from torch import autocast from torch import nn @@ -199,6 +200,9 @@ from tqdm import tqdm from brevitas import config +from brevitas.graph.calibrate import disable_return_quant_tensor +from brevitas.graph.calibrate import DisableEnableQuantization +from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.optim.sign_sgd import SignSGD from brevitas_examples.common.learned_round.learned_round_method import LearnedRound @@ -224,63 +228,68 @@ def _get_blocks(module: nn.Module): return blocks -class LearnedRoundModelUtils(ABC): +class StopFwdException(Exception): + """Used to throw and catch an exception to stop traversing the graph.""" + pass - def __init__(self) -> None: - pass + +class Cache(ABC): @abstractmethod - def default_block_check_fn(self, module: nn.Module, module_name: str) -> bool: + def __len__(self) -> int: pass @abstractmethod - def init_model_learned_round(self, model: nn.Module) -> None: + def store_inputs(self, args: Any, kwargs: Any) -> None: pass @abstractmethod - def finish_model_learned_round(self, model: nn.Module) -> None: + def store_output(self, output: Any) -> None: pass @abstractmethod - def init_cache(self) -> Any: + def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: pass @abstractmethod - def populate_cache( - self, - cache: Any, - model: nn.Module, - block: nn.Module, - data_loader: DataLoader, - keep_gpu: bool = True, - **kwargs, - ) -> int: + def initialize_cache(self) -> None: pass @abstractmethod - def sample_cache( - self, - block: nn.Module, - cache: Any, - indices: torch.Tensor, - **kwargs, - ) -> Tuple[Any, torch.Tensor]: + def clear_cache(self) -> None: pass @abstractmethod - def run_forward( - self, - block: nn.Module, - inputs: Any, - ) -> torch.Tensor: + def reset_cache(self) -> None: pass - @abstractmethod - def loss_scaler( + +class DataSaverHook: + + def __init__( self, - loss: torch.Tensor, - ) -> torch.Tensor: - pass + cache: Cache, + store_inputs: bool = True, + store_output: bool = True, + keep_gpu: bool = True, + ) -> None: + self.cache = cache + self.store_inputs = store_inputs + self.store_output = store_output + self.keep_gpu = keep_gpu + + def __call__(self, module, args, kwargs, output) -> None: + if self.store_inputs: + if not self.keep_gpu: + args = send_to_device(args, 'cpu') + kwargs = send_to_device(kwargs, 'cpu') + self.cache.store_inputs(args, kwargs) + if self.store_output: + if not self.keep_gpu: + output = send_to_device(output, 'cpu') + self.cache.store_output(output) + + raise StopFwdException class LearnedRoundOptimizer: @@ -288,7 +297,6 @@ class LearnedRoundOptimizer: def __init__( self, learned_round: LearnedRound, - learned_round_utils: LearnedRoundModelUtils, *, optimizer_class: Optimizer = SignSGD, lr_scheduler_class: LRScheduler = LinearLR, @@ -298,7 +306,8 @@ def __init__( use_best_model: bool = True, use_amp: bool = True, amp_dtype: torch.dtype = torch.float16, - optimizer_kwargs: Dict = {}, + loss_scaling_factor: float = 1000., + optimizer_kwargs: Dict = None, lr_scheduler_kwargs: Dict = { "start_factor": 1.0, "end_factor": 0.0, @@ -309,7 +318,6 @@ def __init__( "The number of iterations passed to the learned round optimiser is different " "to that of the learned round method, which might lead to unexpected behaviour.") self.learned_round = learned_round - self.learned_round_utils = learned_round_utils self.optimizer_class = optimizer_class self.lr_scheduler_class = lr_scheduler_class self.optimizer_lr = optimizer_lr @@ -318,6 +326,7 @@ def __init__( self.use_best_model = use_best_model self.use_amp = use_amp self.amp_dtype = amp_dtype + self.loss_scaling_factor = loss_scaling_factor self.optimizer_kwargs = optimizer_kwargs self.lr_scheduler_kwargs = lr_scheduler_kwargs @@ -338,7 +347,7 @@ def _collect_round_params(self, block: nn.Module) -> Dict: return params def _scale_loss_and_backward(self, loss: torch.Tensor) -> torch.Tensor: - scaled_loss = self.learned_round_utils.loss_scaler(loss) + scaled_loss = loss * self.loss_scaling_factor scaled_loss.backward() return scaled_loss @@ -348,16 +357,88 @@ def _step(self, optimizer: Optimizer, lr_scheduler: LRScheduler) -> None: if lr_scheduler: lr_scheduler.step() + def _save_inputs_output( + self, + model: nn.Module, + model_forward: Callable, + module: nn.Module, + dataloader: DataLoader, + cache: Cache, + store_inputs: bool = True, + store_output: bool = False, + keep_gpu: bool = True, + disable_quant: bool = False) -> None: + if disable_quant: + disable_quant_class = DisableEnableQuantization() + disable_quant_class.disable_act_quantization(model, False) + disable_quant_class.disable_param_quantization(model, False) + return_quant_tensor_state = disable_return_quant_tensor(model) + + data_saver = DataSaverHook( + cache, store_inputs=store_inputs, store_output=store_output, keep_gpu=keep_gpu) + handle = module.register_forward_hook(data_saver, with_kwargs=True) + with torch.no_grad(): + for inps in dataloader: + try: + model_forward(model, inps) + except StopFwdException: + pass + handle.remove() + if disable_quant: + disable_quant_class.enable_act_quantization(model, False) + disable_quant_class.enable_param_quantization(model, False) + restore_return_quant_tensor(model, return_quant_tensor_state) + + def _populate_cache( + self, + cache: Cache, + model: nn.Module, + model_forward: nn.Module, + block: nn.Module, + data_loader: DataLoader, + keep_gpu: bool = True, + capture_quant_input: bool = True, + capture_quant_output: bool = False, + ) -> None: + # Populate the cache with new inputs and outputs + self._save_inputs_output( + model, + model_forward, + block, + data_loader, + cache, + store_inputs=True, + store_output=capture_quant_input == capture_quant_output, + keep_gpu=keep_gpu, + disable_quant=not capture_quant_input, + ) + if capture_quant_input != capture_quant_output: + self._save_inputs_output( + model, + model_forward, + block, + data_loader, + cache, + store_inputs=False, + store_output=True, + keep_gpu=keep_gpu, + disable_quant=not capture_quant_output, + ) + def apply_learned_round( self, model: nn.Module, + model_forward: Callable, + block_forward: Callable, data_loader: DataLoader, - block_check_fn: Callable = None, + cache: Cache, + block_check_fn: Callable, + model_prepare_fn: Optional[Callable] = None, + model_finish_fn: Optional[Callable] = None, keep_gpu: bool = True) -> None: - # Prepare model for optimization - self.learned_round_utils.init_model_learned_round(model) - block_check_fn = block_check_fn if block_check_fn else self.learned_round_utils.default_block_check_fn + model_dict = None if model_prepare_fn is None else model_prepare_fn(model) + # Retrieve blocks using the appropiate function to check blocks blocks = get_blocks(model, block_check_fn) @@ -365,7 +446,7 @@ def apply_learned_round( print(f"Number of blocks {len(blocks)}") # Initialise cache to store partial inputs and outputs for each block - cache = self.learned_round_utils.init_cache() + cache.initialize_cache() # Loop across blocks to optimise rounding within each for block_idx, (block, block_loss, block_learned_round_modules) in enumerate( @@ -390,15 +471,21 @@ def apply_learned_round( optimal_rounding_params = {} + cache.clear_cache() torch.cuda.empty_cache() # Populate cache for the given block - n_samples = self.learned_round_utils.populate_cache( + self._populate_cache( cache, model, + model_forward, block, data_loader, keep_gpu=keep_gpu, + capture_quant_input=True, + capture_quant_output=False, ) + # Retrieve number of samples + n_samples = len(cache) # Enable training model in quantizer modules for learned_round_module in block_learned_round_modules: learned_round_module.train() @@ -407,10 +494,10 @@ def apply_learned_round( for i in pbar: # Sample mini-batch from cache idxs = torch.randperm(n_samples)[:self.batch_size] - inputs, fp_outs = self.learned_round_utils.sample_cache(block, cache, idxs) + inputs, fp_outs = cache.sample_batch(idxs) # Run block forward to obtain quant outputs - quant_outs = self.learned_round_utils.run_forward(block, inputs) + quant_outs = block_forward(block, inputs) if self.use_amp: with autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", @@ -437,9 +524,10 @@ def apply_learned_round( block_idx + 1, len(blocks), block_loss.format_loss_components(*loss_components))) - pbar.update(1) # Make sure no updates are received in the progress bar pbar.close() + # Reset cache for other blocks + cache.reset_cache() # Set back quantizers to eval mode for learned_round_module in block_learned_round_modules: @@ -456,6 +544,6 @@ def apply_learned_round( f"Quantized block {block_idx+1}/{len(blocks)}, " f"initial loss: {init_loss:.6f}, best loss: {best_loss:.6f}, at iteration {last_best_iter}." ) - # Finish optimisation - self.learned_round_utils.finish_model_learned_round(model) + if model_finish_fn is not None: + model_finish_fn(model, model_dict) diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 5abe912c5..2ba90e9df 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -27,7 +27,8 @@ # SOFTWARE. import re -from typing import Any, Tuple +from typing import Any, Callable, Tuple, Union +import warnings from accelerate.utils.operations import send_to_device import torch @@ -35,313 +36,149 @@ from torch.utils.data.dataloader import DataLoader from brevitas import config -from brevitas.graph.calibrate import disable_return_quant_tensor -from brevitas.graph.calibrate import DisableEnableQuantization -from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL +from brevitas.optim.sign_sgd import SignSGD from brevitas.quant_tensor import QuantTensor -from brevitas_examples.common.learned_round.learned_round_method import StopFwdException -from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundModelUtils +from brevitas_examples.common.learned_round.learned_round_method import AdaRound +from brevitas_examples.common.learned_round.learned_round_method import AutoRound +from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer config.IGNORE_MISSING_KEYS = True -class LearnedRoundVisionUtils(LearnedRoundModelUtils): +class CacheCNN(dict): def __init__(self) -> None: - pass - - def init_model_learned_round(self, model: nn.Module) -> None: - pass - - def finish_model_learned_round(self, model: nn.Module) -> None: - pass - - def default_block_check_fn(self, module: nn.Module, module_name: str) -> bool: - return (re.search(r"layer\d+", module_name) is not None) - - class _DataSaverHook: - - def __init__(self, store_output: False): - self.store_output = store_output - self.input_store = None - self.output_store = None - - def __call__(self, module, input_batch, output_batch): - input_batch = input_batch[0] - if isinstance(input_batch, QuantTensor): - input_batch = input_batch.value - - if hasattr(input_batch, 'names') and 'N' in input_batch.names: - batch_dim = input_batch.names.index('N') - - input_batch.rename_(None) - input_batch = input_batch.transpose(0, batch_dim) - if self.store_output: - output_batch.rename_(None) - output_batch = output_batch.transpose(0, batch_dim) - - if self.store_output: - self.output_store = output_batch - self.input_store = input_batch - raise StopFwdException - - def _save_inp_out_data( - self, - model: nn.Module, - module: nn.Module, - learned_round_modules: List[nn.Module], - weight: float = 0.01, - max_count: int = 1000, - b_range: Tuple = (20, 2), - warmup: float = 0.2, - decay_start: float = 0.0) -> None: super().__init__() - # AdaRound operates in a layer-wise manner, so integrity needs to be checked - assert isinstance(module, QuantWBIOL), "AdaRound can only accept a single QuantWBIOL layer." - assert len(learned_round_modules) == 1, "AdaRound can only accept a single learned round module." - - self.weight = weight - self.module = module - self.loss_start = max_count * warmup - self.temp_decay = LinearTempDecay( - max_count, - start_b=b_range[0], - end_b=b_range[1], - rel_start_decay=warmup + (1.0 - warmup) * decay_start) - self.iter = 0 - self.learned_round_module = learned_round_modules[0] - - def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: - self.iter += 1 - - rec_loss = F.mse_loss(pred, tgt, reduction='none').sum(1).mean() - - if self.iter < self.loss_start: - b = self.temp_decay(self.iter) - round_loss = 0 - else: # 1 - |(h-0.5)*2|**b - b = self.temp_decay(self.iter) - round_vals = self.learned_round_module.p_forward() - round_loss = self.weight * (1 - ((round_vals - 0.5).abs() * 2).pow(b)).sum() - - total_loss = rec_loss + round_loss - return total_loss, (total_loss, rec_loss, round_loss, b) - - def format_loss_components(self, loss: float, rec_loss: float, round_loss: float, b) -> str: - return "loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format( - loss, rec_loss, round_loss, b) - - -class AutoRoundLoss(LearnedRoundLoss): - - def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: - loss = F.mse_loss(pred, tgt, reduction='none').sum(1).mean() - return loss, (loss,) - - def format_loss_components(self, loss: float) -> str: - return "loss = {:.4f}".format(loss) - - -class LearnedRound(ABC): - - def __init__(self, iters: int = 100) -> None: - self.iters = iters - - def _insert_learned_round_quantizer(self, block: nn.Module) -> None: - for module in block.modules(): - if isinstance(module, QuantWBIOL) and len( - self._find_learned_round_modules(module)) == 0: - self._insert_learned_round_quantizer_to_layer(module) - module.weight_quant.init_tensor_quant(preserve_state_dict=True) - - @abstractmethod - def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: - pass - - @abstractmethod - def _is_learned_round_module(self, module: nn.Module) -> bool: - pass - - @abstractmethod - def _instantiate_loss( - self, block: nn.Module, learned_round_modules: List[nn.Module]) -> LearnedRoundLoss: - pass - - def _find_learned_round_modules(self, block: nn.Module) -> List[nn.Module]: - round_modules = [] - for module in block.modules(): - if self._is_learned_round_module(module): - round_modules.append(module) - return round_modules - - def learned_round_iterator( - self, - blocks: List[nn.Module]) -> Generator[nn.Module, LearnedRoundLoss, List[nn.Module]]: - for block in blocks: - # Insert learned round quantizers into the appropiate submodules - self._insert_learned_round_quantizer(block) - # Freeze block parameters - for params in block.parameters(): - params.requires_grad = False - # Retrieve learned round modules - learned_round_modules = self._find_learned_round_modules(block) - # Enable gradient tracking in learned round modules - for round_module in learned_round_modules: - for params in round_module.parameters(): - params.requires_grad = True - block_loss = self._instantiate_loss(block, learned_round_modules) - yield block, block_loss, learned_round_modules - block.eval() - - -class AdaRound(LearnedRound): - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - - def _is_learned_round_module(self, module: nn.Module) -> bool: - return isinstance(module, LearnedRoundSte) - - def _insert_learned_round_quantizer_to_layer( - self, - layer: nn.Module, - learned_round_zeta: float = 1.1, - learned_round_gamma: float = -0.1) -> None: - floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale) - delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight - value = -torch.log((learned_round_zeta - learned_round_gamma) / - (delta - learned_round_gamma) - 1) - layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( - float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND, - learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID, - learned_round_gamma=learned_round_gamma, - learned_round_zeta=learned_round_zeta, - learned_round_init=value) - - def _instantiate_loss( - self, block: nn.Module, learned_round_modules: List[nn.Module]) -> AdaRoundLoss: - return AdaRoundLoss(block, learned_round_modules, max_count=self.iters) - - -class AutoRound(LearnedRound): - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - - def _is_learned_round_module(self, module: nn.Module) -> bool: - return isinstance(module, AutoRoundSte) - - def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: - value = torch.zeros_like(layer.weight.data) - layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( - float_to_int_impl_type=FloatToIntImplType.AUTO_ROUND, - learned_round_init=value, - ) - - def _instantiate_loss( - self, block: nn.Module, learned_round_modules: List[nn.Module]) -> AutoRoundLoss: - return AutoRoundLoss() - - - def _save_inp_out_data( - self, - model: nn.Module, - module: nn.Module, - dataloader: DataLoader, - store_inp: bool = False, - store_out: bool = False, - keep_gpu: bool = True, - disable_quant: bool = False): - if disable_quant: - disable_quant_class = DisableEnableQuantization() - disable_quant_class.disable_act_quantization(model, False) - disable_quant_class.disable_param_quantization(model, False) - return_quant_tensor_state = disable_return_quant_tensor(model) - - device = next(model.parameters()).device - data_saver = LearnedRoundVisionUtils._DataSaverHook(store_output=store_out) - handle = module.register_forward_hook(data_saver) - cached = [[], []] - with torch.no_grad(): - for img, t in dataloader: - try: - _ = model(img.to(device)) - except StopFwdException: - pass - if store_inp: - if keep_gpu: - cached[0].append(data_saver.input_store.detach()) - else: - cached[0].append(data_saver.input_store.detach().cpu()) - if store_out: - if keep_gpu: - cached[1].append(data_saver.output_store.detach()) - else: - cached[1].append(data_saver.output_store.detach().cpu()) - if store_inp: - cached[0] = torch.cat([x for x in cached[0]], dim=0) - if store_out: - cached[1] = torch.cat([x for x in cached[1]], dim=0) - handle.remove() - if disable_quant: - disable_quant_class.enable_act_quantization(model, False) - disable_quant_class.enable_param_quantization(model, False) - restore_return_quant_tensor(model, return_quant_tensor_state) - return cached - - def init_cache(self) -> Any: - return [], [] - - def populate_cache( - self, - cache: Any, - model: nn.Module, - block: nn.Module, - data_loader: DataLoader, - keep_gpu: bool = True, - **kwargs, - ) -> int: - cache_input, cache_output = cache - # Clear caches - cache_input.clear() - cache_output.clear() - - _, all_fp_out = self._save_inp_out_data(model, block, data_loader, store_inp=False, store_out=True, keep_gpu=keep_gpu, disable_quant=True) - all_quant_inp, _ = self._save_inp_out_data(model, block, data_loader, store_inp=True, store_out=True, keep_gpu=keep_gpu, disable_quant=False) - - # Add elements to the caches - cache_input.append(all_quant_inp) - cache_output.append(all_fp_out) - - # Number of samples - return all_fp_out.shape[0] - - def sample_cache( - self, - block: nn.Module, - cache: Any, - indices: torch.Tensor, - **kwargs, - ) -> Tuple[Any, torch.Tensor]: - cache_input, cache_output = cache - device = next(block.parameters()).device - - input, output = cache_input[0][indices], cache_output[0][indices] - input = send_to_device(input, device) - output = send_to_device(output, device) - - return input, output - - def run_forward( - self, - block: nn.Module, - inputs: Any, - ) -> torch.Tensor: - return block(inputs) - - def loss_scaler( - self, - loss: torch.Tensor, - ) -> torch.Tensor: - return loss + self.batch_dim = 0 + + def store_inputs(self, args, kwargs) -> None: + input_batch = args[0] + if isinstance(input_batch, QuantTensor): + input_batch = input_batch.value + + if hasattr(input_batch, 'names') and 'N' in input_batch.names: + self.batch_dim = input_batch.names.index('N') + input_batch.rename_(None) + input_batch = input_batch.transpose(0, self.batch_dim) + + self["inputs"].append(input_batch) + + def store_output(self, output) -> None: + if self.batch_dim is not None: + output.rename_(None) + output = output.transpose(0, self.batch_dim) + + self["output"].append(output) + + def initialize_cache(self) -> None: + self["inputs"] = [] + self["output"] = [] + + def clear_cache(self) -> None: + del self["inputs"] + del self["output"] + self["inputs"] = [] + self["output"] = [] + + def reset_cache(self) -> None: + del self["inputs"] + del self["output"] + self["inputs"] = [] + self["output"] = [] + + def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: + if isinstance(self["inputs"], list): + self["inputs"] = torch.cat(self["inputs"], dim=self.batch_dim) + if isinstance(self["output"], list): + self["output"] = torch.cat(self["output"], dim=self.batch_dim) + + return self["inputs"][indices], self["output"][indices] + + def __len__(self): + return ( + len(self["inputs"]) + if isinstance(self["inputs"], list) else self["inputs"].shape[self.batch_dim]) + + +def cnn_forward(model: nn.Module, inputs: Any) -> None: + device = next(model.parameters()).device + img, _ = inputs + img = send_to_device(img, device) + model(img) + + +def cnn_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor: + device = next(block.parameters()).device + inputs = send_to_device(inputs, device) + return block(inputs) + + +def is_resnet_block(module: nn.Module, module_name: str) -> bool: + return (re.search(r"layer\d+", module_name) is not None) + + +def is_layer(module: nn.Module, module_name: str) -> bool: + return isinstance(module, QuantWBIOL) + + +def apply_learned_round( + model: nn.Module, + calibration_loader: DataLoader, + learned_round_name: str = "ada_round", + optimizer: str = "adam", + learned_round_mode: str = "layerwise", + iters: int = 1000, + optimizer_lr: float = 1e-3, + batch_size: int = 1, +) -> None: + optimizer_classes = {"adam": torch.optim.Adam, "sign_sgd": SignSGD} + if optimizer not in optimizer_classes: + raise ValueError(f"{optimizer} is not a valid optimizer.") + optimizer_class = optimizer_classes[optimizer] + + block_check_fns = {"layerwise": is_layer, "blockwise": is_resnet_block} + if learned_round_mode not in block_check_fns: + learned_round_mode = "layerwise" + warnings.warn( + f"{learned_round_mode} is not a valid learned round mode. Defaulting to layerwise.") + block_check_fn = block_check_fns[learned_round_mode] + + learned_round_methods = {"ada_round": AdaRound, "auto_round": AutoRound} + if learned_round_name not in learned_round_methods: + raise ValueError(f"Learned round method {learned_round_name} is not available.") + learned_round = learned_round_methods[learned_round_name](iters=iters) + + lr_scheduler_class = None if optimizer == "adam" else torch.optim.lr_scheduler.LinearLR + use_best_model = False if learned_round_name == "ada_round" else True + use_amp = True + amp_dtype = torch.float16 + loss_scaling_factor = 1. + optimizer_kwargs = None + lr_scheduler_kwargs = { + "start_factor": 1.0, + "end_factor": 0.0, + "verbose": False,} + learned_round_optimizer = LearnedRoundOptimizer( + learned_round=learned_round, + optimizer_class=optimizer_class, + lr_scheduler_class=lr_scheduler_class, + optimizer_lr=optimizer_lr, + batch_size=batch_size, + iters=iters, + use_best_model=use_best_model, + use_amp=use_amp, + amp_dtype=amp_dtype, + loss_scaling_factor=loss_scaling_factor, + optimizer_kwargs={} if optimizer_kwargs is None else optimizer_kwargs, + lr_scheduler_kwargs=lr_scheduler_kwargs) + cache = CacheCNN() + learned_round_optimizer.apply_learned_round( + model=model, + model_forward=cnn_forward, + block_forward=cnn_block_forward, + data_loader=calibration_loader, + cache=cache, + block_check_fn=block_check_fn, + keep_gpu=True, + ) \ No newline at end of file diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 2cd44443b..bd245b7e6 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -2,17 +2,9 @@ # SPDX-License-Identifier: BSD-3-Clause from functools import partial -import itertools import math -import re -from typing import Callable, List -from warnings import warn import torch -from torch import nn -import torch.backends.cudnn as cudnn -from torch.optim.optimizer import Optimizer -from torch.utils.data.dataloader import DataLoader from tqdm import tqdm from brevitas.core.function_wrapper.shape import OverBatchOverTensorView @@ -20,23 +12,17 @@ from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint from brevitas.graph.calibrate import bias_correction_mode from brevitas.graph.calibrate import calibration_mode -from brevitas.graph.calibrate import disable_return_quant_tensor -from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.calibrate import norm_correction_mode -from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.gpfq import gpfq_mode from brevitas.graph.gpfq import GPFQv2 from brevitas.graph.gptq import GPTQ from brevitas.graph.gptq import gptq_mode -from brevitas.graph.gpxq import StopFwdException from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.quantize import quantize from brevitas.graph.target.flexml import quantize_flexml from brevitas.inject import value import brevitas.nn as qnn -from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL -from brevitas.optim.sign_sgd import SignSGD from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloatMSE from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat @@ -656,14 +642,6 @@ def apply_gpfq( gpfq.update() -def _is_resnet_block(module: nn.Module, module_name: str) -> bool: - return (re.search(r"layer\d+", module_name) is not None) - - -def _is_layer(module: nn.Module, module_name: str) -> bool: - return isinstance(module, QuantWBIOL) - - def check_positive_int(*args): """ We check that every inputted value is positive, and an integer. diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index fb2a06e79..6aba4bb70 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -21,16 +21,7 @@ from brevitas.export.inference import quant_inference_mode from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize -from brevitas.optim.sign_sgd import SignSGD -from brevitas_examples.common.learned_round.learned_round_builder import \ - instantiate_learned_round_optimizer -from brevitas_examples.common.learned_round.learned_round_method import AdaRound -from brevitas_examples.common.learned_round.learned_round_method import AutoRound -from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import \ - LearnedRoundVisionUtils -from brevitas_examples.imagenet_classification.ptq.ptq_common import _is_layer -from brevitas_examples.imagenet_classification.ptq.ptq_common import _is_resnet_block +from brevitas_examples.imagenet_classification.ptq.learned_round_utils import apply_learned_round from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq @@ -507,38 +498,15 @@ def main(): if args.learned_round: print("Applying Learned Round:") - # Optimizer to tune the rounding - if args.optimizer == "adam": - optimizer_class = torch.optim.Adam - elif args.optimizer == "sign_sgd": - optimizer_class = SignSGD - else: - raise ValueError(f"{args.optimizer} is not a valid optimizer.") - # Granularity of the rounding blocks - if args.learned_round_mode == "layerwise": - block_check_fn = _is_layer - elif args.learned_round_mode == "blockwise": - block_check_fn = _is_resnet_block - # Instantiate optimizer - learned_round_optimizer = instantiate_learned_round_optimizer( - utils_type="imagenet_classification", - method_type=args.learned_round, + apply_learned_round( + model=quant_model, + calibration_loader=calib_loader, + learned_round_name=args.learned_round, + optimizer=args.optimizer, + learned_round_mode=args.learned_round_mode, iters=args.learned_round_iters, - optimizer_params={ - "optimizer_lr": - args.learned_round_lr, - "optimizer_class": - optimizer_class, - "lr_scheduler_class": - None if args.optimizer == "adam" else torch.optim.lr_scheduler.LinearLR, - "batch_size": - args.learned_round_batch_size, - "use_best_model": - False if args.learned_round == "ada_round" else True,}) - learned_round_optimizer.apply_learned_round( - model, - calib_loader, - block_check_fn=block_check_fn, + optimizer_lr=args.learned_round_lr, + batch_size=args.learned_round_batch_size, ) if args.calibrate_bn: diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index 223035e53..e10358ca2 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -1,173 +1,60 @@ # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, List, Tuple +from typing import Any, Callable, Dict, List, Tuple, Union from accelerate.utils.operations import send_to_device import torch from torch import nn +from torch.optim.lr_scheduler import LinearLR from torch.utils.data.dataloader import DataLoader from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.models.opt.modeling_opt import OPTDecoderLayer -from brevitas.graph.calibrate import disable_return_quant_tensor -from brevitas.graph.calibrate import DisableEnableQuantization -from brevitas.graph.calibrate import restore_return_quant_tensor -from brevitas_examples.common.learned_round.learned_round_method import StopFwdException -from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundModelUtils - - -class LearnedRoundLLMUtils(LearnedRoundModelUtils): - - def __init__(self, loss_scaling_factor: float = 1000.) -> None: - super(LearnedRoundLLMUtils, self).__init__() - self.llm_cache_state = None - self.loss_scaling_factor = loss_scaling_factor - - def default_block_check_fn(self, module: nn.Module, module_name: str) -> bool: - return isinstance(module, LlamaDecoderLayer) or isinstance(module, OPTDecoderLayer) - - class _DataSaverHookLLM: - - def __init__( - self, - cache_args: List, - cache_kwargs: List, - cache_outs: List, - store_args: bool = True, - store_kwargs: bool = True, - store_outs: bool = True, - keep_gpu: bool = True): - self.cache_args = cache_args - self.cache_kwargs = cache_kwargs - self.cache_outs = cache_outs - - self.store_args = store_args - self.store_kwargs = store_kwargs - self.store_outs = store_outs - - self.keep_gpu = keep_gpu - - def __call__(self, module, args, kwargs, output): - # NOTE: If args/kwargs are QuantTensors, should include logic to unpack their values - if isinstance(output, (tuple, list)): - output = output[0] - - # Store each element in the appropiate cache - for element_to_cache, should_cache, cache in zip( - [args, kwargs, output], - [self.store_args, self.store_kwargs, self.store_outs], - [self.cache_args, self.cache_kwargs, self.cache_outs] - ): - if should_cache: - if not self.keep_gpu: - element_to_cache = send_to_device(element_to_cache, 'cpu') - cache.append(element_to_cache) - - raise StopFwdException - - def _save_inp_out_data( - self, - model: nn.Module, - module: nn.Module, - dataloader: DataLoader, - cache_args: List, - cache_kwargs: List, - cache_outs: List, - store_args: bool = True, - store_kwargs: bool = False, - store_outs: bool = True, - keep_gpu: bool = True, - disable_quant=False) -> None: - if disable_quant: - disable_quant_class = DisableEnableQuantization() - disable_quant_class.disable_act_quantization(model, False) - disable_quant_class.disable_param_quantization(model, False) - return_quant_tensor_state = disable_return_quant_tensor(model) - - device = next(module.parameters()).device - data_saver = LearnedRoundLLMUtils._DataSaverHookLLM( - cache_args, cache_kwargs, cache_outs, store_args, store_kwargs, store_outs, keep_gpu) - handle = module.register_forward_hook(data_saver, with_kwargs=True) - with torch.no_grad(): - for inps in dataloader: - try: - inps = send_to_device(inps, device) - model(**inps) - except StopFwdException: - pass - handle.remove() - if disable_quant: - disable_quant_class.enable_act_quantization(model, False) - disable_quant_class.enable_param_quantization(model, False) - restore_return_quant_tensor(model, return_quant_tensor_state) - - def init_model_learned_round(self, model: nn.Module) -> None: - self.llm_cache_state = model.config.use_cache - model.config.use_cache = False - - def finish_model_learned_round(self, model: nn.Module) -> None: - model.config.use_cache = self.llm_cache_state - self.llm_cache_state = None - - def init_cache(self) -> Any: - # cache_args, cache_kwargs, cache_outs - return [], [], [] - - def populate_cache( - self, - cache: Any, - model: nn.Module, - block: nn.Module, - data_loader: DataLoader, - keep_gpu: bool = True, - **kwargs, - ) -> int: - # Unpack cache - cache_args, cache_kwargs, cache_outs = cache - # Cache needs to be cleaned between blocks. No need to clear the - # kwargs cache, as this is only updated for the first block. - cache_args.clear() - cache_outs.clear() - # Save FP output - self._save_inp_out_data( - model, - block, - data_loader, - cache_args, - cache_kwargs, - cache_outs, - store_args=False, - store_kwargs=False, - store_outs=True, - keep_gpu=keep_gpu, - disable_quant=True) - # Save Quant input - self._save_inp_out_data( - model, - block, - data_loader, - cache_args, - cache_kwargs, - cache_outs, - store_args=True, - store_kwargs=len(cache_kwargs) == 0, - store_outs=False, - keep_gpu=keep_gpu, - disable_quant=False) - # Return number of samples in calibration set - return len(cache_args) - - def sample_cache( - self, - block: nn.Module, - cache: Any, - indices: torch.Tensor, - input_dim: int = 0, - **kwargs_fn, - ) -> Tuple[Any, torch.Tensor]: - cache_args, cache_kwargs, cache_outs = cache - device = next(block.parameters()).device +from brevitas.optim.sign_sgd import SignSGD +from brevitas_examples.common.learned_round.learned_round_method import AutoRound +from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer + + +class CacheLLM(dict): + + def __init__(self) -> None: + super().__init__() + self.store_kwargs = True + + def store_inputs(self, args, kwargs) -> None: + self["args"].append(args) + if self.store_kwargs: + self["kwargs"].append(kwargs) + + def store_output(self, output) -> None: + if isinstance(output, (tuple, list)): + output = output[0] + self["output"].append(output) + + def initialize_cache(self) -> None: + self["args"] = [] + self["kwargs"] = [] + self["output"] = [] + + def clear_cache(self) -> None: + del self["args"] + del self["output"] + self["args"] = [] + self["output"] = [] + self.store_kwargs = len(self["kwargs"]) == 0 + + def reset_cache(self) -> None: + del self["args"] + del self["kwargs"] + del self["output"] + self.store_kwargs = True + self["args"] = [] + self["kwargs"] = [] + self["output"] = [] + + def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: + cache_args, cache_kwargs, cache_outs = self["args"], self["kwargs"], self["output"] # Positional arguments args = [cache_args[i] for i in indices] args = tuple(torch.cat(arg_tensor, dim=0) for arg_tensor in zip(*args)) @@ -185,30 +72,84 @@ def sample_cache( kwargs[key] = value for key, value in kwargs.items(): if isinstance(value, list) and len(value) > 0: - kwargs[key] = torch.cat(kwargs[key], dim=input_dim) + kwargs[key] = torch.cat(kwargs[key], dim=0) # FP outputs - outs = torch.cat([cache_outs[i] for i in indices], dim=input_dim) - # Make sure that the inputs and outputs are in the same device as block, - # before running its forward pass. - args = send_to_device(args, device) - kwargs = send_to_device(kwargs, device) - outs = send_to_device(outs, device) - + outs = torch.cat([cache_outs[i] for i in indices], dim=0) return (args, kwargs), outs - def run_forward( - self, - block: nn.Module, - inputs: Any, - ) -> torch.Tensor: - args, kwargs = inputs - quant_outs = block(*args, **kwargs) - if isinstance(quant_outs, tuple): - quant_outs = quant_outs[0] - return quant_outs - - def loss_scaler( - self, - loss: torch.Tensor, - ) -> torch.Tensor: - return loss * self.loss_scaling_factor + def __len__(self): + return len(self["args"]) + + +def llm_learned_round_prepare_fn(model: nn.Module) -> None: + llm_cache_state = model.config.use_cache + model.config.use_cache = False + return llm_cache_state + + +def llm_learned_round_finish_fn(model: nn.Module, llm_cache_state: Dict) -> None: + model.config.use_cache = llm_cache_state + + +def llm_forward(model: nn.Module, inputs: Any) -> None: + device = next(model.parameters()).device + inputs = send_to_device(inputs, device) + model(**inputs) + + +def llm_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor: + device = next(block.parameters()).device + args, kwargs = inputs + args = send_to_device(args, device) + kwargs = send_to_device(kwargs, device) + out = block(*args, **kwargs) + if isinstance(out, tuple): + out = out[0] + return out + + +def llm_block_check_fn(module: nn.Module, module_name: str) -> bool: + return isinstance(module, LlamaDecoderLayer) or isinstance(module, OPTDecoderLayer) + + +def apply_learned_round(model: nn.Module, calibration_loader: DataLoader) -> None: + iters = 200 + learned_round = AutoRound(iters=200) + optimizer_class = SignSGD + lr_scheduler_class = LinearLR + optimizer_lr = 5e-3 + batch_size = 8 + use_best_model = True + use_amp = True + amp_dtype = torch.float16 + loss_scaling_factor = 1000. + optimizer_kwargs = None + lr_scheduler_kwargs = { + "start_factor": 1.0, + "end_factor": 0.0, + "verbose": False,} + learned_round_optimizer = LearnedRoundOptimizer( + learned_round=learned_round, + optimizer_class=optimizer_class, + lr_scheduler_class=lr_scheduler_class, + optimizer_lr=optimizer_lr, + batch_size=batch_size, + iters=iters, + use_best_model=use_best_model, + use_amp=use_amp, + amp_dtype=amp_dtype, + loss_scaling_factor=loss_scaling_factor, + optimizer_kwargs={} if optimizer_kwargs is None else optimizer_kwargs, + lr_scheduler_kwargs=lr_scheduler_kwargs) + cache = CacheLLM() + learned_round_optimizer.apply_learned_round( + model=model, + model_forward=llm_forward, + block_forward=llm_block_forward, + data_loader=calibration_loader, + cache=cache, + block_check_fn=llm_block_check_fn, + model_prepare_fn=llm_learned_round_prepare_fn, + model_finish_fn=llm_learned_round_finish_fn, + keep_gpu=True, + ) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index b426c6ecf..55ea94520 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -23,10 +23,6 @@ from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks from brevitas_examples.common.generative.quantize import generate_quant_maps from brevitas_examples.common.generative.quantize import generate_quantizers -from brevitas_examples.common.learned_round.learned_round_builder import \ - instantiate_learned_round_optimizer -from brevitas_examples.common.learned_round.learned_round_method import AutoRound -from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer from brevitas_examples.common.parse_utils import quant_format_validator from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction from brevitas_examples.llm.llm_quant.calibrate import apply_calibration @@ -38,7 +34,7 @@ from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq from brevitas_examples.llm.llm_quant.gpxq import apply_gptq -from brevitas_examples.llm.llm_quant.learned_round_utils import LearnedRoundLLMUtils +from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_to_rmsnorm from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch @@ -374,11 +370,7 @@ def main(args): if args.learned_round: print("Applying learned round...") - learned_round_optimizer = instantiate_learned_round_optimizer( - utils_type="llm", - method_type=args.learned_round, - ) - learned_round_optimizer.apply_learned_round(model, calibration_loader) + apply_learned_round(model, calibration_loader) print("Learned round applied.") if args.act_calibration: From b988fa22240c7f54e7f79f5ae6ecd418779bcc07 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 11 Nov 2024 15:24:35 +0000 Subject: [PATCH 17/48] Unify learned round methods --- .../core/function_wrapper/__init__.py | 1 - .../core/function_wrapper/learned_round.py | 73 ++++++++----------- src/brevitas/inject/enum.py | 2 +- src/brevitas/quant/solver/common.py | 6 +- src/brevitas/utils/quant_utils.py | 2 - .../learned_round/learned_round_method.py | 73 +++++++++---------- .../learned_round/learned_round_optimizer.py | 10 +-- 7 files changed, 70 insertions(+), 97 deletions(-) diff --git a/src/brevitas/core/function_wrapper/__init__.py b/src/brevitas/core/function_wrapper/__init__.py index d9aafa978..3b3e5428b 100644 --- a/src/brevitas/core/function_wrapper/__init__.py +++ b/src/brevitas/core/function_wrapper/__init__.py @@ -5,7 +5,6 @@ from .clamp import FloatClamp from .clamp import ScalarClamp from .clamp import TensorClamp -from .learned_round import AutoRoundSte from .learned_round import LearnedRoundSte from .misc import Identity from .misc import InplaceLogTwo diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py index 2ed008929..82bf0cbc7 100644 --- a/src/brevitas/core/function_wrapper/learned_round.py +++ b/src/brevitas/core/function_wrapper/learned_round.py @@ -25,11 +25,15 @@ class LearnedRoundHardSigmoid(brevitas.jit.ScriptModule): def __init__(self, learned_round_zeta: float = 1.1, learned_round_gamma: float = -0.1) -> None: super(LearnedRoundHardSigmoid, self).__init__() + self.float_to_int_ste = floor_ste + self.is_p_value = True self.learned_round_zeta = learned_round_zeta self.learned_round_gamma = learned_round_gamma @brevitas.jit.script_method - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor: + if training: + return x > 0 p = torch.sigmoid(x) p = p * (self.learned_round_zeta - self.learned_round_gamma) + self.learned_round_gamma p = torch.clamp(p, 0.0, 1.0) @@ -45,84 +49,65 @@ class LearnedRoundSigmoid(brevitas.jit.ScriptModule): def __init__(self, learned_round_temperature: float = 1.) -> None: super(LearnedRoundSigmoid, self).__init__() assert learned_round_temperature != 0, 'Temperature should be different than 0' + self.float_to_int_ste = floor_ste + self.is_p_value = True self.learned_round_temperature = learned_round_temperature @brevitas.jit.script_method - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor: + if training: + return x > 0 p = torch.sigmoid(x / self.learned_round_temperature) return p -# TODO: Change name to AdaRoundSte for consistency -class LearnedRoundSte(brevitas.jit.ScriptModule): +class LearnedRoundIdentity(brevitas.jit.ScriptModule): """ - This Module implements LearnedRound representation, where each weight has a learnable parameter - that decides if "ceil" or "floor" rounding type has to be used. + Implementation for LearnedRound learned parameter + Adapted from https://arxiv.org/abs/2309.05516 """ - def __init__( - self, - learned_round_impl: torch.nn.Module, - learned_round_init: torch.Tensor, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None) -> None: - super(LearnedRoundSte, self).__init__() - self.learned_round_impl = learned_round_impl - learned_round_init = learned_round_init.to(device=device, dtype=dtype) - self.tensor_slicer = SliceTensor() - self.value = torch.nn.Parameter(learned_round_init) + def __init__(self) -> None: + super(LearnedRoundIdentity, self).__init__() + self.float_to_int_ste = round_ste + self.is_p_value = False @brevitas.jit.script_method - def forward(self, x: torch.Tensor) -> torch.Tensor: - p = self.p_forward() - p = self.tensor_slicer(p) - return floor_ste(x) + p.to(x.dtype) - - def p_forward(self): - # In eval mode, performs true quantization, otherwise "soft" quantization - if not self.training: - p = (self.value > 0) - else: - p = self.learned_round_impl(self.value) - return p - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - super(LearnedRoundSte, self)._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - value_key = prefix + 'value' - if config.IGNORE_MISSING_KEYS and value_key in missing_keys: - missing_keys.remove(value_key) + def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor: + return x -class AutoRoundSte(brevitas.jit.ScriptModule): +class LearnedRoundSte(brevitas.jit.ScriptModule): """ - This Module implements AutoRound representation, where each weight has a learnable parameter + This Module implements LearnedRound representation, where each weight has a learnable parameter that decides if "ceil" or "floor" rounding type has to be used. """ def __init__( self, + learned_round_impl: torch.nn.Module, learned_round_init: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None: - super(AutoRoundSte, self).__init__() + super(LearnedRoundSte, self).__init__() + self.learned_round_impl = learned_round_impl learned_round_init = learned_round_init.to(device=device, dtype=dtype) self.tensor_slicer = SliceTensor() self.value = torch.nn.Parameter(learned_round_init) @brevitas.jit.script_method def forward(self, x: torch.Tensor) -> torch.Tensor: - # p should be between [-0.5, 0.5], so this learnable parameter decides whether to "ceil" or "floor" + float_to_int_ste = self.learned_round_impl.float_to_int_ste + is_p_value = self.learned_round_impl.is_p_value p = self.value p = self.tensor_slicer(p) - return round_ste(x + (p.to(x.dtype)).view_as(x)) + p = (p.to(x.dtype)).view_as(x) + return float_to_int_ste(x) + p if is_p_value else float_to_int_ste(x + p) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - super(AutoRoundSte, self)._load_from_state_dict( + super(LearnedRoundSte, self)._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) value_key = prefix + 'value' if config.IGNORE_MISSING_KEYS and value_key in missing_keys: diff --git a/src/brevitas/inject/enum.py b/src/brevitas/inject/enum.py index 67122dddb..fbac29176 100644 --- a/src/brevitas/inject/enum.py +++ b/src/brevitas/inject/enum.py @@ -46,7 +46,6 @@ class FloatToIntImplType(AutoName): DPU = auto() LEARNED_ROUND = auto() STOCHASTIC_ROUND = auto() - AUTO_ROUND = auto() class LearnedRoundImplType(AutoName): @@ -54,6 +53,7 @@ class LearnedRoundImplType(AutoName): """ HARD_SIGMOID = auto() SIGMOID = auto() + IDENTITY = auto() class ScalingImplType(AutoName): diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 568505b19..69b4c9438 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -3,8 +3,8 @@ from brevitas.core.bit_width import * from brevitas.core.function_wrapper import * -from brevitas.core.function_wrapper.learned_round import AutoRoundSte from brevitas.core.function_wrapper.learned_round import LearnedRoundHardSigmoid +from brevitas.core.function_wrapper.learned_round import LearnedRoundIdentity from brevitas.core.function_wrapper.learned_round import LearnedRoundSigmoid from brevitas.core.function_wrapper.learned_round import LearnedRoundSte from brevitas.core.function_wrapper.stochastic_round import StochasticRoundSte @@ -50,8 +50,6 @@ def solve_float_to_int_impl_from_enum(impl_type): return LearnedRoundSte elif impl_type == FloatToIntImplType.STOCHASTIC_ROUND: return StochasticRoundSte - elif impl_type == FloatToIntImplType.AUTO_ROUND: - return AutoRoundSte else: raise Exception(f"{impl_type} not recognized.") @@ -150,6 +148,8 @@ def learned_round_impl(learned_round_impl_type): return LearnedRoundSigmoid if learned_round_impl_type == LearnedRoundImplType.HARD_SIGMOID: return LearnedRoundHardSigmoid + if learned_round_impl_type == LearnedRoundImplType.IDENTITY: + return LearnedRoundIdentity @value def learned_round_init(tracked_parameter_list): diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index 6ba0ebf76..62290b1de 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -214,8 +214,6 @@ def float_to_int_impl_to_enum(module): return FloatToIntImplType.DPU elif isinstance(module, LearnedRoundSte): return FloatToIntImplType.LEARNED_ROUND - elif isinstance(module, AutoRoundSte): - return FloatToIntImplType.AUTO_ROUND elif isinstance(module, StochasticRoundSte): if module.deterministic_inference: return FloatToIntImplType.ROUND diff --git a/src/brevitas_examples/common/learned_round/learned_round_method.py b/src/brevitas_examples/common/learned_round/learned_round_method.py index 7c770b0cd..dd0cfa4ab 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_method.py +++ b/src/brevitas_examples/common/learned_round/learned_round_method.py @@ -3,13 +3,12 @@ from abc import ABC from abc import abstractmethod -from typing import Generator, List, Tuple +from typing import Dict, Generator, List, Tuple, Type import torch from torch import nn import torch.nn.functional as F -from brevitas.core.function_wrapper.learned_round import AutoRoundSte from brevitas.core.function_wrapper.learned_round import LearnedRoundSte from brevitas.inject.enum import FloatToIntImplType from brevitas.inject.enum import LearnedRoundImplType @@ -23,6 +22,10 @@ class StopFwdException(Exception): class LearnedRoundLoss(ABC): + @abstractmethod + def __init__(self, block: nn.Module, learned_round_modules: List[nn.Module], **kwargs) -> None: + pass + @abstractmethod def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: pass @@ -34,8 +37,10 @@ def format_loss_components(self, *args) -> str: class LearnedRound(ABC): - def __init__(self, iters: int = 200, **kwargs) -> None: - self.iters = iters + def __init__( + self, loss_cls: Type[LearnedRoundLoss], loss_params: Dict = None, **kwargs) -> None: + self.loss_cls = loss_cls + self.loss_params = loss_params if loss_params is not None else {} def _insert_and_return_learned_round_quantizers(self, block: nn.Module) -> List[nn.Module]: round_modules = [] @@ -55,11 +60,6 @@ def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: def _is_learned_round_module(self, module: nn.Module) -> bool: pass - @abstractmethod - def _instantiate_loss( - self, block: nn.Module, learned_round_modules: List[nn.Module]) -> LearnedRoundLoss: - pass - def _find_learned_round_modules(self, block: nn.Module) -> List[nn.Module]: round_modules = [] for module in block.modules(): @@ -80,7 +80,7 @@ def learned_round_iterator( for round_module in learned_round_modules: for params in round_module.parameters(): params.requires_grad = True - block_loss = self._instantiate_loss(block, learned_round_modules) + block_loss = self.loss_cls(block, learned_round_modules, **self.loss_params) # Block needs to be in eval mode while the rounding is optimised block.eval() yield block, block_loss, learned_round_modules @@ -102,7 +102,7 @@ def __call__(self, t): return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t)) -class AdaRoundLoss(LearnedRoundLoss): +class RegularisedMSELoss(LearnedRoundLoss): def __init__( self, @@ -112,8 +112,8 @@ def __init__( max_count: int = 1000, b_range: Tuple = (20, 2), warmup: float = 0.2, - decay_start: float = 0.0) -> None: - super().__init__() + decay_start: float = 0.0, + **kwargs) -> None: # AdaRound operates in a layer-wise manner, so integrity needs to be checked assert isinstance(module, QuantWBIOL), "AdaRound can only accept a single QuantWBIOL layer." assert len(learned_round_modules) == 1, "AdaRound can only accept a single learned round module." @@ -154,8 +154,9 @@ class AdaRound(LearnedRound): def __init__( self, + loss_cls: Type[LearnedRoundLoss] = RegularisedMSELoss, + loss_params: Dict = None, iters: int = 200, - *, learned_round_zeta: float = 1.1, learned_round_gamma: float = -0.1, learned_round_impl_type: LearnedRoundImplType = LearnedRoundImplType.HARD_SIGMOID, @@ -165,16 +166,17 @@ def __init__( decay_start: float = 0.0, **kwargs, ) -> None: - super().__init__(iters, **kwargs) + loss_params = { + "max_count": iters, + "weight": weight, + "b_range": b_range, + "warmup": warmup, + "decay_start": decay_start} if loss_params is None else loss_params + super().__init__(loss_cls, loss_params, **kwargs) # Quantiser-related configuration self.learned_round_zeta = learned_round_zeta self.learned_round_gamma = learned_round_gamma self.learned_round_impl_type = learned_round_impl_type - # Loss-related configuration - self.weight = weight - self.b_range = b_range - self.warmup = warmup - self.decay_start = decay_start def _is_learned_round_module(self, module: nn.Module) -> bool: return isinstance(module, LearnedRoundSte) @@ -191,19 +193,11 @@ def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: learned_round_zeta=self.learned_round_zeta, learned_round_init=value) - def _instantiate_loss( - self, block: nn.Module, learned_round_modules: List[nn.Module]) -> AdaRoundLoss: - return AdaRoundLoss( - block, - learned_round_modules, - max_count=self.iters, - weight=self.weight, - warmup=self.warmup, - decay_start=self.decay_start, - ) +class MSELoss(LearnedRoundLoss): -class AutoRoundLoss(LearnedRoundLoss): + def __init__(self, block: nn.Module, learned_round_modules: List[nn.Module], **kwargs) -> None: + pass def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: loss = F.mse_loss(pred, tgt) @@ -215,19 +209,20 @@ def format_loss_components(self, loss: float) -> str: class AutoRound(LearnedRound): - def __init__(self, iters: int = 200, **kwargs) -> None: - super().__init__(iters, **kwargs) + def __init__( + self, + loss_cls: Type[LearnedRoundLoss] = MSELoss, + loss_params: Dict = None, + **kwargs) -> None: + super().__init__(loss_cls, loss_params, **kwargs) def _is_learned_round_module(self, module: nn.Module) -> bool: - return isinstance(module, AutoRoundSte) + return isinstance(module, LearnedRoundSte) def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: value = torch.zeros_like(layer.weight.data) layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( - float_to_int_impl_type=FloatToIntImplType.AUTO_ROUND, + float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND, + learned_round_impl_type=LearnedRoundImplType.IDENTITY, learned_round_init=value, ) - - def _instantiate_loss( - self, block: nn.Module, learned_round_modules: List[nn.Module]) -> AutoRoundLoss: - return AutoRoundLoss() diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index c18a27c8e..ecec393ad 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -186,7 +186,7 @@ from abc import abstractmethod import copy import itertools -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import warnings from accelerate.utils.operations import send_to_device @@ -298,8 +298,8 @@ def __init__( self, learned_round: LearnedRound, *, - optimizer_class: Optimizer = SignSGD, - lr_scheduler_class: LRScheduler = LinearLR, + optimizer_class: Type[Optimizer] = SignSGD, + lr_scheduler_class: Optional[Type[LRScheduler]] = LinearLR, optimizer_lr: float = 5e-3, batch_size: float = 8, iters: int = 200, @@ -313,10 +313,6 @@ def __init__( "end_factor": 0.0, "verbose": False,} ) -> None: - if learned_round.iters != iters: - warnings.warn( - "The number of iterations passed to the learned round optimiser is different " - "to that of the learned round method, which might lead to unexpected behaviour.") self.learned_round = learned_round self.optimizer_class = optimizer_class self.lr_scheduler_class = lr_scheduler_class From f4bd984280a188489e1757670330610f8815a5a1 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 18 Nov 2024 16:34:04 +0000 Subject: [PATCH 18/48] Refactoring and fixes offload --- .../learned_round/learned_round_method.py | 95 ++----- .../learned_round/learned_round_optimizer.py | 249 +++++++++++------- .../ptq/learned_round_utils.py | 91 +++++-- .../ptq/ptq_evaluate.py | 20 +- .../llm/llm_quant/learned_round_utils.py | 74 ++++-- src/brevitas_examples/llm/main.py | 3 + 6 files changed, 313 insertions(+), 219 deletions(-) diff --git a/src/brevitas_examples/common/learned_round/learned_round_method.py b/src/brevitas_examples/common/learned_round/learned_round_method.py index dd0cfa4ab..04e7b3818 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_method.py +++ b/src/brevitas_examples/common/learned_round/learned_round_method.py @@ -3,7 +3,7 @@ from abc import ABC from abc import abstractmethod -from typing import Dict, Generator, List, Tuple, Type +from typing import Dict, Generator, List, Optional, Tuple, Type import torch from torch import nn @@ -37,53 +37,19 @@ def format_loss_components(self, *args) -> str: class LearnedRound(ABC): - def __init__( - self, loss_cls: Type[LearnedRoundLoss], loss_params: Dict = None, **kwargs) -> None: - self.loss_cls = loss_cls - self.loss_params = loss_params if loss_params is not None else {} - - def _insert_and_return_learned_round_quantizers(self, block: nn.Module) -> List[nn.Module]: - round_modules = [] - for module in block.modules(): + def insert_learned_round_quantizers(self, model: nn.Module) -> None: + for module in model.modules(): if isinstance(module, QuantWBIOL) and len( - self._find_learned_round_modules(module)) == 0: + self.return_learned_round_quantizers(module)) == 0: self._insert_learned_round_quantizer_to_layer(module) module.weight_quant.init_tensor_quant(preserve_state_dict=True) - round_modules.append(module.weight_quant.tensor_quant.int_quant.float_to_int_impl) - return round_modules @abstractmethod def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: pass - @abstractmethod - def _is_learned_round_module(self, module: nn.Module) -> bool: - pass - - def _find_learned_round_modules(self, block: nn.Module) -> List[nn.Module]: - round_modules = [] - for module in block.modules(): - if self._is_learned_round_module(module): - round_modules.append(module) - return round_modules - - def learned_round_iterator( - self, - blocks: List[nn.Module]) -> Generator[nn.Module, LearnedRoundLoss, List[nn.Module]]: - for block in blocks: - # Insert learned round quantizers into the appropiate submodules - learned_round_modules = self._insert_and_return_learned_round_quantizers(block) - # Freeze block parameters - for params in block.parameters(): - params.requires_grad = False - # Enable gradient tracking in learned round modules - for round_module in learned_round_modules: - for params in round_module.parameters(): - params.requires_grad = True - block_loss = self.loss_cls(block, learned_round_modules, **self.loss_params) - # Block needs to be in eval mode while the rounding is optimised - block.eval() - yield block, block_loss, learned_round_modules + def return_learned_round_quantizers(self, block: nn.Module) -> List[nn.Module]: + return [module for module in block.modules() if isinstance(module, LearnedRoundSte)] class LinearTempDecay: @@ -146,41 +112,27 @@ def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, return total_loss, (total_loss, rec_loss, round_loss, b) def format_loss_components(self, loss: float, rec_loss: float, round_loss: float, b) -> str: - return "loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format( - loss, rec_loss, round_loss, b) + return "Loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format( + loss, + rec_loss.detach().cpu().item(), + round_loss if isinstance(round_loss, float) else round_loss.detach().cpu().item(), + b) class AdaRound(LearnedRound): def __init__( - self, - loss_cls: Type[LearnedRoundLoss] = RegularisedMSELoss, - loss_params: Dict = None, - iters: int = 200, - learned_round_zeta: float = 1.1, - learned_round_gamma: float = -0.1, - learned_round_impl_type: LearnedRoundImplType = LearnedRoundImplType.HARD_SIGMOID, - weight: float = 0.01, - b_range: Tuple = (20, 2), - warmup: float = 0.2, - decay_start: float = 0.0, - **kwargs, + self, + learned_round_zeta: float = 1.1, + learned_round_gamma: float = -0.1, + learned_round_impl_type: LearnedRoundImplType = LearnedRoundImplType.HARD_SIGMOID, + **kwargs, ) -> None: - loss_params = { - "max_count": iters, - "weight": weight, - "b_range": b_range, - "warmup": warmup, - "decay_start": decay_start} if loss_params is None else loss_params - super().__init__(loss_cls, loss_params, **kwargs) # Quantiser-related configuration self.learned_round_zeta = learned_round_zeta self.learned_round_gamma = learned_round_gamma self.learned_round_impl_type = learned_round_impl_type - def _is_learned_round_module(self, module: nn.Module) -> bool: - return isinstance(module, LearnedRoundSte) - def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale) delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight @@ -201,23 +153,16 @@ def __init__(self, block: nn.Module, learned_round_modules: List[nn.Module], **k def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: loss = F.mse_loss(pred, tgt) - return loss, (loss,) + return loss, (loss.detach().cpu().item(),) def format_loss_components(self, loss: float) -> str: - return "loss = {:.4f}".format(loss) + return "Loss = {:.4f}".format(loss) class AutoRound(LearnedRound): - def __init__( - self, - loss_cls: Type[LearnedRoundLoss] = MSELoss, - loss_params: Dict = None, - **kwargs) -> None: - super().__init__(loss_cls, loss_params, **kwargs) - - def _is_learned_round_module(self, module: nn.Module) -> bool: - return isinstance(module, LearnedRoundSte) + def __init__(self, **kwargs) -> None: + pass def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: value = torch.zeros_like(layer.weight.data) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index ecec393ad..9bb12d529 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -185,6 +185,7 @@ from abc import ABC from abc import abstractmethod import copy +from functools import partial import itertools from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import warnings @@ -200,11 +201,15 @@ from tqdm import tqdm from brevitas import config +from brevitas.core.function_wrapper.learned_round import LearnedRoundSte from brevitas.graph.calibrate import disable_return_quant_tensor from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.optim.sign_sgd import SignSGD +from brevitas_examples.common.accelerate_utils.accelerate import offload_model +from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks from brevitas_examples.common.learned_round.learned_round_method import LearnedRound +from brevitas_examples.common.learned_round.learned_round_method import LearnedRoundLoss config.IGNORE_MISSING_KEYS = True @@ -297,6 +302,7 @@ class LearnedRoundOptimizer: def __init__( self, learned_round: LearnedRound, + learned_round_loss_class: Type[LearnedRoundLoss], *, optimizer_class: Type[Optimizer] = SignSGD, lr_scheduler_class: Optional[Type[LRScheduler]] = LinearLR, @@ -307,11 +313,9 @@ def __init__( use_amp: bool = True, amp_dtype: torch.dtype = torch.float16, loss_scaling_factor: float = 1000., - optimizer_kwargs: Dict = None, - lr_scheduler_kwargs: Dict = { - "start_factor": 1.0, - "end_factor": 0.0, - "verbose": False,} + learned_round_loss_kwargs: Optional[Dict] = None, + optimizer_kwargs: Optional[Dict] = None, + lr_scheduler_kwargs: Optional[Dict] = None, ) -> None: self.learned_round = learned_round self.optimizer_class = optimizer_class @@ -323,30 +327,30 @@ def __init__( self.use_amp = use_amp self.amp_dtype = amp_dtype self.loss_scaling_factor = loss_scaling_factor - self.optimizer_kwargs = optimizer_kwargs - - self.lr_scheduler_kwargs = lr_scheduler_kwargs + self.optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs + self.lr_scheduler_kwargs = {} if lr_scheduler_kwargs is None else lr_scheduler_kwargs self.lr_scheduler_kwargs["total_iters"] = self.iters + learned_round_loss_kwargs = {} if learned_round_loss_kwargs is None else learned_round_loss_kwargs + self.learned_round_loss_init = partial( + learned_round_loss_class, **learned_round_loss_kwargs) + + # TODO: FIX @torch.no_grad() def _load_round_params(self, block: nn.Module, round_params: Dict) -> None: for n, m in block.named_modules(): if n in round_params: m.load_state_dict(round_params[n]) + # TODO: FIX @torch.no_grad() def _collect_round_params(self, block: nn.Module) -> Dict: params = {} for n, m in block.named_modules(): - if self.learned_round._is_learned_round_module(m): + if isinstance(m, LearnedRoundSte): params[n] = copy.deepcopy(m.state_dict()) return params - def _scale_loss_and_backward(self, loss: torch.Tensor) -> torch.Tensor: - scaled_loss = loss * self.loss_scaling_factor - scaled_loss.backward() - return scaled_loss - def _step(self, optimizer: Optimizer, lr_scheduler: LRScheduler) -> None: optimizer.step() optimizer.zero_grad() @@ -421,6 +425,83 @@ def _populate_cache( disable_quant=not capture_quant_output, ) + def _optimize_learned_round_block( + self, + block: nn.Module, + block_learned_round_modules: List[nn.Module], + cache: Cache, + block_loss: LearnedRoundLoss, + block_forward: Callable, + ) -> Tuple[float, float, int]: + # Initilalize optimizer and LR scheduler + optimizer = self.optimizer_class( + itertools.chain( + *[ + block_learned_round_module.parameters() + for block_learned_round_module in block_learned_round_modules]), + lr=self.optimizer_lr, + **self.optimizer_kwargs, + ) + lr_scheduler = ( + self.lr_scheduler_class(optimizer, **self.lr_scheduler_kwargs) + if self.lr_scheduler_class else None) + + # Variables needed for printing + best_loss = torch.finfo(torch.float).max + init_loss = -1.0 + last_best_iter = self.iters + + # Dictionary to store the rounding parameters yielding the lowest + # training loss + optimal_rounding_params = {} + + n_samples = len(cache) + pbar = tqdm(range(self.iters), desc='') + for i in pbar: + # Sample mini-batch from cache + idxs = torch.randperm(n_samples)[:self.batch_size] + inputs, fp_outs = cache.sample_batch(idxs) + + # Run block forward to obtain quant outputs + quant_outs = block_forward(block, inputs) + fp_outs = send_to_device(fp_outs, quant_outs.device) + if self.use_amp: + with autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", + dtype=self.amp_dtype): + loss, loss_components = block_loss(quant_outs, fp_outs) + else: + loss, loss_components = block_loss(quant_outs.to(torch.float32), fp_outs.to(torch.float32)) + + # Save best parameters before taking gradient step + curr_loss = loss.detach().cpu().item() + init_loss = curr_loss if i == 0 else init_loss + if loss < best_loss: + best_loss = curr_loss + last_best_iter = i + 1 + if self.use_best_model: + optimal_rounding_params = self._collect_round_params(block) + + # Scale loss and perform gradient step + loss = loss * self.loss_scaling_factor + loss.backward() + self._step(optimizer, lr_scheduler) + + # Update progress bar + pbar.set_description("{}".format(block_loss.format_loss_components(*loss_components))) + + # Make sure no updates are received in the progress bar + pbar.close() + + if self.use_best_model: + with torch.no_grad(): + self._load_round_params(block, optimal_rounding_params) + else: + # Override if the model with the lowest training error is not used + best_loss = curr_loss + last_best_iter = self.iters + + return init_loss, best_loss, last_best_iter + def apply_learned_round( self, model: nn.Module, @@ -433,43 +514,29 @@ def apply_learned_round( model_finish_fn: Optional[Callable] = None, keep_gpu: bool = True) -> None: + # Perform any needed preprocessing before rounding optimisation, e.g. disabling caching in LLMs model_dict = None if model_prepare_fn is None else model_prepare_fn(model) + # Insert quantizers within the appropiate model blocks + self.learned_round.insert_learned_round_quantizers(model) + # Retrieve blocks using the appropiate function to check blocks blocks = get_blocks(model, block_check_fn) print(f"Total Iterations per block {self.iters}") print(f"Number of blocks {len(blocks)}") - # Initialise cache to store partial inputs and outputs for each block + # Initialize cache to store partial inputs and outputs for each block cache.initialize_cache() - # Loop across blocks to optimise rounding within each - for block_idx, (block, block_loss, block_learned_round_modules) in enumerate( - self.learned_round.learned_round_iterator(blocks)): - # Initialise optimiser and LR scheduler - optimizer = self.optimizer_class( - itertools.chain( - *[ - learned_round_module.parameters() - for learned_round_module in block_learned_round_modules]), - lr=self.optimizer_lr, - **self.optimizer_kwargs, - ) - lr_scheduler = ( - self.lr_scheduler_class(optimizer, **self.lr_scheduler_kwargs) - if self.lr_scheduler_class else None) - - # Variables needed for printing - best_loss = torch.finfo(torch.float).max - init_loss = -1.0 - last_best_iter = self.iters - - optimal_rounding_params = {} - + # Iterate over blocks and optimise the rounding parameters within each of them + for block_idx, block in enumerate(blocks): + # Distribute the model across devices to run a forward pass to capture + # inputs/outputs to the given block + model = offload_model(model) + # Cache needs to be cleared before populating it with the inputs and outputs + # to the block under optimization. cache.clear_cache() - torch.cuda.empty_cache() - # Populate cache for the given block self._populate_cache( cache, model, @@ -480,66 +547,58 @@ def apply_learned_round( capture_quant_input=True, capture_quant_output=False, ) - # Retrieve number of samples - n_samples = len(cache) - # Enable training model in quantizer modules - for learned_round_module in block_learned_round_modules: - learned_round_module.train() - - pbar = tqdm(range(self.iters), desc='') - for i in pbar: - # Sample mini-batch from cache - idxs = torch.randperm(n_samples)[:self.batch_size] - inputs, fp_outs = cache.sample_batch(idxs) - - # Run block forward to obtain quant outputs - quant_outs = block_forward(block, inputs) - - if self.use_amp: - with autocast(device_type="cuda" if torch.cuda.is_available() else "cpu", - dtype=self.amp_dtype): - loss, loss_components = block_loss(quant_outs, fp_outs) - else: - loss, loss_components = block_loss(quant_outs.to(torch.float32), fp_outs.to(torch.float32)) - - init_loss = loss.item() if i == 0 else init_loss - - if loss < best_loss: - best_loss = loss.item() - last_best_iter = i + 1 - if self.use_best_model: - optimal_rounding_params = self._collect_round_params(block) - - # Scale loss and perform gradient step - self._scale_loss_and_backward(loss) - self._step(optimizer, lr_scheduler) - - # Update progress bar - pbar.set_description( - "Block = {:d}/{:d}, {}".format( - block_idx + 1, - len(blocks), - block_loss.format_loss_components(*loss_components))) - # Make sure no updates are received in the progress bar - pbar.close() - # Reset cache for other blocks - cache.reset_cache() - - # Set back quantizers to eval mode - for learned_round_module in block_learned_round_modules: - learned_round_module.eval() + # Remove hooks needed to offload the model blocks to cpu + remove_hooks(model) + + # The parameters of the block that are not part of the rounding quantizers + # need to be frozen, as only the rounding needs to be optimized. + block.eval() + for params in block.parameters(): + params.requires_grad = False + # However, the rounding parameters are tuned + block_learned_round_modules = self.learned_round.return_learned_round_quantizers(block) + for block_learned_round_module in block_learned_round_modules: + block_learned_round_module.train() + for params in block_learned_round_module.parameters(): + params.requires_grad = True + + # Move block to GPU if available + if torch.cuda.is_available(): + block.cuda() + + # Loss function for computing the rounding loss within each block + block_loss = self.learned_round_loss_init( + block, + block_learned_round_modules, + ) - if self.use_best_model: - self._load_round_params(block, optimal_rounding_params) - else: - # Override if the model with the lowest training error is not used - best_loss = loss.item() - last_best_iter = self.iters + # Optimize block rounding + init_loss, best_loss, last_best_iter = self._optimize_learned_round_block( + block=block, + block_learned_round_modules=block_learned_round_modules, + cache=cache, + block_loss=block_loss, + block_forward=block_forward, + ) print( f"Quantized block {block_idx+1}/{len(blocks)}, " f"initial loss: {init_loss:.6f}, best loss: {best_loss:.6f}, at iteration {last_best_iter}." ) - # Finish optimisation + + # After finishing the optimization, the block rounding parameters are frozen + for block_learned_round_module in block_learned_round_modules: + block_learned_round_module.eval() + for params in block_learned_round_module.parameters(): + params.requires_grad = False + + # Move the block back to CPU + block.cpu() + + # TODO: This call might not be needed, check_clear and reset_cache methods + # Reset cache after optimisation + cache.reset_cache() + + # The original configuration of the model is restored after finishing the optimization if model_finish_fn is not None: model_finish_fn(model, model_dict) diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 2ba90e9df..9dfef86c5 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -27,12 +27,13 @@ # SOFTWARE. import re -from typing import Any, Callable, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import warnings from accelerate.utils.operations import send_to_device import torch from torch import nn +from torch.optim.lr_scheduler import LinearLR from torch.utils.data.dataloader import DataLoader from brevitas import config @@ -41,11 +42,37 @@ from brevitas.quant_tensor import QuantTensor from brevitas_examples.common.learned_round.learned_round_method import AdaRound from brevitas_examples.common.learned_round.learned_round_method import AutoRound +from brevitas_examples.common.learned_round.learned_round_method import MSELoss +from brevitas_examples.common.learned_round.learned_round_method import RegularisedMSELoss from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer config.IGNORE_MISSING_KEYS = True +def is_resnet_block(module: nn.Module, module_name: str) -> bool: + return (re.search(r"layer\d+", module_name) is not None) + + +def is_layer(module: nn.Module, module_name: str) -> bool: + return isinstance(module, QuantWBIOL) + + +LEARNED_ROUND_MAP = { + "auto_round": AutoRound, + "ada_round": AdaRound,} +LEARNED_ROUND_LOSS_MAP = { + "mse": MSELoss, + "regularised_mse": RegularisedMSELoss,} +OPTIMIZER_MAP = { + "adam": torch.optim.Adam, + "sign_sgd": SignSGD,} +BLOCK_CHECK_MAP = { + "layerwise": is_layer, + "sign_sgd": is_resnet_block,} +LR_SCHEDULER_MAP = { + "linear": LinearLR,} + + class CacheCNN(dict): def __init__(self) -> None: @@ -114,53 +141,58 @@ def cnn_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor: return block(inputs) -def is_resnet_block(module: nn.Module, module_name: str) -> bool: - return (re.search(r"layer\d+", module_name) is not None) - - -def is_layer(module: nn.Module, module_name: str) -> bool: - return isinstance(module, QuantWBIOL) - - def apply_learned_round( model: nn.Module, calibration_loader: DataLoader, - learned_round_name: str = "ada_round", - optimizer: str = "adam", - learned_round_mode: str = "layerwise", iters: int = 1000, + learned_round: str = "ada_round", + learned_round_loss: str = "regularised_mse", + optimizer: str = "adam", + lr_scheduler: Optional[str] = None, optimizer_lr: float = 1e-3, batch_size: int = 1, + use_best_model: bool = False, + use_amp: bool = True, + amp_dtype: torch.dtype = torch.float16, + loss_scaling_factor: float = 1., + learned_round_loss_kwargs: Optional[Dict] = None, + optimizer_kwargs: Optional[Dict] = None, + lr_scheduler_kwargs: Optional[Dict] = None, + learned_round_mode: str = "layerwise", ) -> None: + if learned_round not in LEARNED_ROUND_MAP: + raise ValueError(f"Learned round method {learned_round} is not available.") + learned_round = LEARNED_ROUND_MAP[learned_round]() + + if learned_round_loss not in LEARNED_ROUND_LOSS_MAP: + raise ValueError(f"Learned round loss {learned_round_loss} is not available.") + learned_round_loss_class = LEARNED_ROUND_LOSS_MAP[learned_round_loss] + + if optimizer not in OPTIMIZER_MAP: + raise ValueError(f"Optimizer {optimizer} is not available.") + optimizer_class = OPTIMIZER_MAP[optimizer] + + if lr_scheduler is not None and lr_scheduler not in LR_SCHEDULER_MAP: + raise ValueError(f"Learning rate scheduler {lr_scheduler} is not available.") + lr_scheduler_class = None if lr_scheduler is None else LR_SCHEDULER_MAP[lr_scheduler] + optimizer_classes = {"adam": torch.optim.Adam, "sign_sgd": SignSGD} if optimizer not in optimizer_classes: raise ValueError(f"{optimizer} is not a valid optimizer.") optimizer_class = optimizer_classes[optimizer] - block_check_fns = {"layerwise": is_layer, "blockwise": is_resnet_block} - if learned_round_mode not in block_check_fns: + if learned_round_mode not in BLOCK_CHECK_MAP: learned_round_mode = "layerwise" warnings.warn( f"{learned_round_mode} is not a valid learned round mode. Defaulting to layerwise.") - block_check_fn = block_check_fns[learned_round_mode] - - learned_round_methods = {"ada_round": AdaRound, "auto_round": AutoRound} - if learned_round_name not in learned_round_methods: - raise ValueError(f"Learned round method {learned_round_name} is not available.") - learned_round = learned_round_methods[learned_round_name](iters=iters) - - lr_scheduler_class = None if optimizer == "adam" else torch.optim.lr_scheduler.LinearLR - use_best_model = False if learned_round_name == "ada_round" else True - use_amp = True - amp_dtype = torch.float16 - loss_scaling_factor = 1. - optimizer_kwargs = None + block_check_fn = BLOCK_CHECK_MAP[learned_round_mode] lr_scheduler_kwargs = { "start_factor": 1.0, "end_factor": 0.0, - "verbose": False,} + "verbose": False,} if lr_scheduler_kwargs is None else lr_scheduler_kwargs learned_round_optimizer = LearnedRoundOptimizer( learned_round=learned_round, + learned_round_loss_class=learned_round_loss_class, optimizer_class=optimizer_class, lr_scheduler_class=lr_scheduler_class, optimizer_lr=optimizer_lr, @@ -170,7 +202,8 @@ def apply_learned_round( use_amp=use_amp, amp_dtype=amp_dtype, loss_scaling_factor=loss_scaling_factor, - optimizer_kwargs={} if optimizer_kwargs is None else optimizer_kwargs, + learned_round_loss_kwargs=learned_round_loss_kwargs, + optimizer_kwargs=optimizer_kwargs, lr_scheduler_kwargs=lr_scheduler_kwargs) cache = CacheCNN() learned_round_optimizer.apply_learned_round( diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 6aba4bb70..65b9e07c6 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -163,6 +163,12 @@ def validate_args(args): default=None, type=str, choices=[None, 'ada_round', 'auto_round'], + help='Learned round type (default: None)') +parser.add_argument( + '--learned-round-loss', + default='regularised_mse', + type=str, + choices=['regularised_mse', 'mse'], help='Learned round type (default: none)') parser.add_argument( '--learned-round-mode', @@ -174,6 +180,12 @@ def validate_args(args): default=1000, type=int, help='Numbers of iterations for learned round for each layer (default: 1000)') +parser.add_argument( + '--learned-round-lr-scheduler', + default=None, + type=str, + choices=[None, 'linear'], + help='Learning rate scheduler for learned round (default: None)') parser.add_argument( '--learned-round-lr', default=1e-3, @@ -501,12 +513,14 @@ def main(): apply_learned_round( model=quant_model, calibration_loader=calib_loader, - learned_round_name=args.learned_round, - optimizer=args.optimizer, - learned_round_mode=args.learned_round_mode, iters=args.learned_round_iters, + learned_round=args.learned_round, + learned_round_loss=args.learned_round_loss, + optimizer=args.optimizer, + lr_scheduler=args.learned_round_lr_scheduler, optimizer_lr=args.learned_round_lr, batch_size=args.learned_round_batch_size, + learned_round_mode=args.learned_round_mode, ) if args.calibrate_bn: diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index e10358ca2..c384dd559 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -1,20 +1,34 @@ # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from accelerate.utils.operations import send_to_device import torch from torch import nn from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import LRScheduler +from torch.optim.optimizer import Optimizer from torch.utils.data.dataloader import DataLoader from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.models.opt.modeling_opt import OPTDecoderLayer from brevitas.optim.sign_sgd import SignSGD from brevitas_examples.common.learned_round.learned_round_method import AutoRound +from brevitas_examples.common.learned_round.learned_round_method import LearnedRound +from brevitas_examples.common.learned_round.learned_round_method import LearnedRoundLoss +from brevitas_examples.common.learned_round.learned_round_method import MSELoss from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer +LEARNED_ROUND_MAP = { + "auto_round": AutoRound,} +LEARNED_ROUND_LOSS_MAP = { + "mse": MSELoss,} +OPTIMIZER_MAP = { + "sign_sgd": SignSGD,} +LR_SCHEDULER_MAP = { + "linear": LinearLR,} + class CacheLLM(dict): @@ -44,6 +58,7 @@ def clear_cache(self) -> None: self["output"] = [] self.store_kwargs = len(self["kwargs"]) == 0 + # TODO: Rename to remove cache def reset_cache(self) -> None: del self["args"] del self["kwargs"] @@ -93,7 +108,8 @@ def llm_learned_round_finish_fn(model: nn.Module, llm_cache_state: Dict) -> None def llm_forward(model: nn.Module, inputs: Any) -> None: device = next(model.parameters()).device - inputs = send_to_device(inputs, device) + if device != torch.device("meta"): + inputs = send_to_device(inputs, device) model(**inputs) @@ -112,24 +128,47 @@ def llm_block_check_fn(module: nn.Module, module_name: str) -> bool: return isinstance(module, LlamaDecoderLayer) or isinstance(module, OPTDecoderLayer) -def apply_learned_round(model: nn.Module, calibration_loader: DataLoader) -> None: - iters = 200 - learned_round = AutoRound(iters=200) - optimizer_class = SignSGD - lr_scheduler_class = LinearLR - optimizer_lr = 5e-3 - batch_size = 8 - use_best_model = True - use_amp = True - amp_dtype = torch.float16 - loss_scaling_factor = 1000. - optimizer_kwargs = None +def apply_learned_round( + model: nn.Module, + calibration_loader: DataLoader, + iters: int = 200, + learned_round: str = "auto_round", + learned_round_loss: str = "mse", + optimizer: str = "sign_sgd", + lr_scheduler: Optional[str] = "linear", + optimizer_lr: float = 5e-3, + batch_size: int = 8, + use_best_model: bool = True, + use_amp: bool = True, + amp_dtype: torch.dtype = torch.float16, + loss_scaling_factor: float = 1000, + optimizer_kwargs: Optional[Dict] = None, + lr_scheduler_kwargs: Optional[Dict] = None, + learned_round_loss_kwargs: Optional[Dict] = None, +) -> None: + if learned_round not in LEARNED_ROUND_MAP: + raise ValueError(f"Learned round method {learned_round} is not available.") + learned_round = LEARNED_ROUND_MAP[learned_round]() + + if learned_round_loss not in LEARNED_ROUND_LOSS_MAP: + raise ValueError(f"Learned round loss {learned_round_loss} is not available.") + learned_round_loss_class = LEARNED_ROUND_LOSS_MAP[learned_round_loss] + + if optimizer not in OPTIMIZER_MAP: + raise ValueError(f"Optimizer {optimizer} is not available.") + optimizer_class = OPTIMIZER_MAP[optimizer] + + if lr_scheduler is not None and lr_scheduler not in LR_SCHEDULER_MAP: + raise ValueError(f"Learning rate scheduler {lr_scheduler} is not available.") + lr_scheduler_class = None if lr_scheduler is None else LR_SCHEDULER_MAP[lr_scheduler] + lr_scheduler_kwargs = { "start_factor": 1.0, "end_factor": 0.0, - "verbose": False,} + "verbose": False,} if lr_scheduler_kwargs is None else lr_scheduler_kwargs learned_round_optimizer = LearnedRoundOptimizer( learned_round=learned_round, + learned_round_loss_class=learned_round_loss_class, optimizer_class=optimizer_class, lr_scheduler_class=lr_scheduler_class, optimizer_lr=optimizer_lr, @@ -139,7 +178,8 @@ def apply_learned_round(model: nn.Module, calibration_loader: DataLoader) -> Non use_amp=use_amp, amp_dtype=amp_dtype, loss_scaling_factor=loss_scaling_factor, - optimizer_kwargs={} if optimizer_kwargs is None else optimizer_kwargs, + learned_round_loss_kwargs=learned_round_loss_kwargs, + optimizer_kwargs=optimizer_kwargs, lr_scheduler_kwargs=lr_scheduler_kwargs) cache = CacheLLM() learned_round_optimizer.apply_learned_round( @@ -151,5 +191,5 @@ def apply_learned_round(model: nn.Module, calibration_loader: DataLoader) -> Non block_check_fn=llm_block_check_fn, model_prepare_fn=llm_learned_round_prepare_fn, model_finish_fn=llm_learned_round_finish_fn, - keep_gpu=True, + keep_gpu=False, ) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 55ea94520..04385276b 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -370,9 +370,12 @@ def main(args): if args.learned_round: print("Applying learned round...") + remove_hooks(model) apply_learned_round(model, calibration_loader) print("Learned round applied.") + model = offload_model(model) + if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader) From 494a01761f9d51a823bde795e6acaf669b28b135 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 19 Nov 2024 16:11:37 +0000 Subject: [PATCH 19/48] Initial implementation distributed training --- .../common/accelerate_utils/accelerate.py | 7 +- .../learned_round/learned_round_optimizer.py | 157 +++++++++++++++++- .../ptq/learned_round_utils.py | 9 +- .../llm/llm_quant/learned_round_utils.py | 31 +++- 4 files changed, 199 insertions(+), 5 deletions(-) diff --git a/src/brevitas_examples/common/accelerate_utils/accelerate.py b/src/brevitas_examples/common/accelerate_utils/accelerate.py index ead616ed2..876837081 100644 --- a/src/brevitas_examples/common/accelerate_utils/accelerate.py +++ b/src/brevitas_examples/common/accelerate_utils/accelerate.py @@ -405,8 +405,13 @@ def offload_model( device_map = infer_fx_auto_device_map(model, memory_map) offload_call_function(model, device_map) else: + # Some models do no have the attribute _no_split_modules, so a check is needed to prevent + # this call to crash. device_map = infer_auto_device_map( - model, memory_map, no_split_module_classes=model._no_split_modules) + model, + memory_map, + no_split_module_classes=model._no_split_modules + if hasattr(model, "_no_split_modules") else None) model = dispatch_model(model, device_map) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index 9bb12d529..58a350662 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -190,7 +190,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import warnings +from accelerate import Accelerator +from accelerate.utils import tqdm as tqdm_accelerate +from accelerate.utils.dataclasses import PrecisionType from accelerate.utils.operations import send_to_device +from datasets import Dataset import torch from torch import autocast from torch import nn @@ -198,6 +202,7 @@ from torch.optim.lr_scheduler import LRScheduler from torch.optim.optimizer import Optimizer from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataloader import RandomSampler from tqdm import tqdm from brevitas import config @@ -268,6 +273,14 @@ def clear_cache(self) -> None: def reset_cache(self) -> None: pass + @abstractmethod + def cache_to_dataset(self) -> Dataset: + pass + + @abstractmethod + def collate_fn(self, batch: Any) -> Any: + pass + class DataSaverHook: @@ -313,6 +326,7 @@ def __init__( use_amp: bool = True, amp_dtype: torch.dtype = torch.float16, loss_scaling_factor: float = 1000., + use_accelerate: bool = False, learned_round_loss_kwargs: Optional[Dict] = None, optimizer_kwargs: Optional[Dict] = None, lr_scheduler_kwargs: Optional[Dict] = None, @@ -335,6 +349,10 @@ def __init__( self.learned_round_loss_init = partial( learned_round_loss_class, **learned_round_loss_kwargs) + # TODO: Remove once validated and expose the flag + # self.use_accelerate = use_accelerate + self.use_accelerate = False + # TODO: FIX @torch.no_grad() def _load_round_params(self, block: nn.Module, round_params: Dict) -> None: @@ -433,7 +451,19 @@ def _optimize_learned_round_block( block_loss: LearnedRoundLoss, block_forward: Callable, ) -> Tuple[float, float, int]: - # Initilalize optimizer and LR scheduler + # Move block to GPU if available + if torch.cuda.is_available(): + try: + block.cuda() + except RuntimeError as exc: + if 'out of memory' in str(exc): + warnings.warn( + "Out of memory error was raised when moving the block to GPU. Defaulting to CPU." + ) + else: + raise exc + + # Initialize optimizer and LR scheduler optimizer = self.optimizer_class( itertools.chain( *[ @@ -500,6 +530,125 @@ def _optimize_learned_round_block( best_loss = curr_loss last_best_iter = self.iters + # Move the block back to CPU + block.cpu() + + return init_loss, best_loss, last_best_iter + + # TODO: Enable saving best parameters + def _accelerate_optimize_learned_round_block( + self, + block: nn.Module, + block_learned_round_modules: List[nn.Module], + cache: Cache, + block_loss: LearnedRoundLoss, + block_forward: Callable, + ) -> Tuple[float, float, int]: + # Enable running in mixed precision + TORCH_DTYPE_TO_PRECISION_TYPE_MAP = { + torch.float16: PrecisionType.FP16, + torch.bfloat16: PrecisionType.BF16,} + raise_warning_dtype = False + if not self.use_amp: + mixed_precision_type = None + else: + if self.amp_dtype not in TORCH_DTYPE_TO_PRECISION_TYPE_MAP: + raise_warning_dtype = True + mixed_precision_type = None + else: + mixed_precision_type = TORCH_DTYPE_TO_PRECISION_TYPE_MAP[self.amp_dtype] + # Instantiate accelerator to run in a multi-GPU setting + accelerator = Accelerator(mixed_precision=mixed_precision_type) + + # Raise warning if the AMP dtype was defaulted to float32. This warning is raised after + # the instantiation of accelerator, to use its print functionality so the message is only + # printed once. + if raise_warning_dtype: + accelerator.print( + f"The dtype {self.amp_dtype} cannot be used for AMP training with accelerate. Defaulting to float32." + ) + + # Initilalize optimizer and LR scheduler + optimizer = self.optimizer_class( + itertools.chain( + *[ + block_learned_round_module.parameters() + for block_learned_round_module in block_learned_round_modules]), + lr=self.optimizer_lr, + **self.optimizer_kwargs, + ) + lr_scheduler = ( + self.lr_scheduler_class(optimizer, **self.lr_scheduler_kwargs) + if self.lr_scheduler_class else None) + + # Prepare dataset from cache + cache_dataset = cache.cache_to_dataset() + # NOTE: Intuitively, the total samples retrieved during optimization should + # be self.batch_size*self.iters. However, a StopIteration is raised mid-training + # signaling that this is not correct. Should check why this is the case. + random_sampler = RandomSampler( + cache_dataset, replacement=True, num_samples=2 * self.batch_size * self.iters) + cache_dataloader = DataLoader( + cache_dataset, + batch_size=self.batch_size, + sampler=random_sampler, + collate_fn=cache.collate_fn) + + # Prepare elements for training + cache_dataloader, block, optimizer, lr_scheduler = accelerator.prepare(cache_dataloader, block, optimizer, lr_scheduler) + + # Variables needed for printing + best_loss = torch.finfo(torch.float).max + init_loss = -1.0 + last_best_iter = self.iters + + # Initialize an iterator to extract elements from the cache dataloader + cache_iterator = iter(cache_dataloader) + + pbar = tqdm_accelerate(range(self.iters), desc='') + for i in pbar: + # Sample mini-batch from cache + inputs, fp_outs = next(cache_iterator) + + # Run block forward to obtain quant outputs + quant_outs = block_forward(block, inputs) + # Compute loss using the block loss function + loss, loss_components = block_loss(quant_outs, fp_outs) + + # Save best parameters before taking gradient step + curr_loss = loss.detach().cpu().item() + init_loss = curr_loss if i == 0 else init_loss + if loss < best_loss: + best_loss = curr_loss + last_best_iter = i + 1 + + # Scale loss and perform gradient step + # loss = loss * self.loss_scaling_factor + accelerator.backward(loss) + self._step(optimizer, lr_scheduler) + + # Update progress bar + pbar.set_description("{}".format(block_loss.format_loss_components(*loss_components))) + + # Make sure no updates are received in the progress bar + pbar.close() + + # TODO: Include support for saving the best configuration during training + if not self.use_best_model: + # Override if the model with the lowest training error is not used + best_loss = curr_loss + last_best_iter = self.iters + + # TODO: Verify if this call is actually needed + # Wait for everyone before proceding to next block + accelerator.wait_for_everyone() + # Remove all the wrapper around the block + block = accelerator.unwrap_model(block) + # Clear memory + accelerator.free_memory() + # Move the block back to CPU + block.cpu() + return init_loss, best_loss, last_best_iter def apply_learned_round( @@ -573,7 +722,11 @@ def apply_learned_round( ) # Optimize block rounding - init_loss, best_loss, last_best_iter = self._optimize_learned_round_block( + init_loss, best_loss, last_best_iter = ( + self._optimize_learned_round_block + if not self.use_accelerate + else self._accelerate_optimize_learned_round_block + )( block=block, block_learned_round_modules=block_learned_round_modules, cache=cache, diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 9dfef86c5..8d83d6510 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -31,6 +31,7 @@ import warnings from accelerate.utils.operations import send_to_device +from datasets import Dataset import torch from torch import nn from torch.optim.lr_scheduler import LinearLR @@ -68,7 +69,7 @@ def is_layer(module: nn.Module, module_name: str) -> bool: "sign_sgd": SignSGD,} BLOCK_CHECK_MAP = { "layerwise": is_layer, - "sign_sgd": is_resnet_block,} + "blockwise": is_resnet_block,} LR_SCHEDULER_MAP = { "linear": LinearLR,} @@ -122,6 +123,12 @@ def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: return self["inputs"][indices], self["output"][indices] + def cache_to_dataset(self) -> Dataset: + raise NotImplementedError("This method is still not available for CNNs.") + + def collate_fn(self, batch: Any) -> Any: + raise NotImplementedError("This method is still not available for CNNs.") + def __len__(self): return ( len(self["inputs"]) diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index c384dd559..dd0842702 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from accelerate.utils.operations import send_to_device +from datasets import Dataset import torch from torch import nn from torch.optim.lr_scheduler import LinearLR @@ -58,7 +59,6 @@ def clear_cache(self) -> None: self["output"] = [] self.store_kwargs = len(self["kwargs"]) == 0 - # TODO: Rename to remove cache def reset_cache(self) -> None: del self["args"] del self["kwargs"] @@ -92,6 +92,35 @@ def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: outs = torch.cat([cache_outs[i] for i in indices], dim=0) return (args, kwargs), outs + def cache_to_dataset(self) -> Dataset: + inputs_list = list(zip(self["args"], self["kwargs"])) + return list(zip(inputs_list, self["output"])) + + def collate_fn(self, batch: Any) -> Any: + # Format of the dataset is ((args, kwargs), outs) + # See cache_to_dataset + inputs, outs = map(list, zip(*batch)) + args, kwargs_dict = map(list, zip(*inputs)) + # Positional arguments + args = tuple(torch.cat(arg_tensor, dim=0) for arg_tensor in zip(*args)) + # Keyword arguments + kwargs = {} + for curr_dict in kwargs_dict: + for key, value in curr_dict.items(): + if isinstance(value, torch.Tensor): + if key not in kwargs: + kwargs[key] = [] + kwargs[key].append(value) + else: + if key not in kwargs: + kwargs[key] = value + for key, value in kwargs.items(): + if isinstance(value, list) and len(value) > 0: + kwargs[key] = torch.cat(kwargs[key], dim=0) + # FP outputs + outs = torch.cat(outs, dim=0) + return ((args, kwargs), outs) + def __len__(self): return len(self["args"]) From d79bfdc8df6d6b60252f84fe7b4163f5c2bd4c49 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 19 Nov 2024 16:19:47 +0000 Subject: [PATCH 20/48] Minor cleanup --- .../common/learned_round/learned_round_optimizer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index 58a350662..9c964333b 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -353,14 +353,12 @@ def __init__( # self.use_accelerate = use_accelerate self.use_accelerate = False - # TODO: FIX @torch.no_grad() def _load_round_params(self, block: nn.Module, round_params: Dict) -> None: for n, m in block.named_modules(): if n in round_params: m.load_state_dict(round_params[n]) - # TODO: FIX @torch.no_grad() def _collect_round_params(self, block: nn.Module) -> Dict: params = {} @@ -523,8 +521,7 @@ def _optimize_learned_round_block( pbar.close() if self.use_best_model: - with torch.no_grad(): - self._load_round_params(block, optimal_rounding_params) + self._load_round_params(block, optimal_rounding_params) else: # Override if the model with the lowest training error is not used best_loss = curr_loss From 4042cc8e9db3ccbdb58aac0654f79b61e0009a35 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Wed, 20 Nov 2024 17:54:10 +0000 Subject: [PATCH 21/48] Fix tests --- .../core/function_wrapper/learned_round.py | 8 +- .../learned_round/learned_round_optimizer.py | 68 +-- tests/brevitas/core/test_float_to_int.py | 135 +++--- tests/brevitas/hyp_helper.py | 10 +- tests/brevitas/optim/test_sign_sgd.py | 316 ++++--------- tests/brevitas_examples/test_imagenet.py | 34 -- .../test_learned_round_utils.py | 441 +++++++++--------- 7 files changed, 436 insertions(+), 576 deletions(-) delete mode 100644 tests/brevitas_examples/test_imagenet.py diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py index 82bf0cbc7..2d3e76aeb 100644 --- a/src/brevitas/core/function_wrapper/learned_round.py +++ b/src/brevitas/core/function_wrapper/learned_round.py @@ -32,11 +32,11 @@ def __init__(self, learned_round_zeta: float = 1.1, learned_round_gamma: float = @brevitas.jit.script_method def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor: - if training: - return x > 0 p = torch.sigmoid(x) p = p * (self.learned_round_zeta - self.learned_round_gamma) + self.learned_round_gamma p = torch.clamp(p, 0.0, 1.0) + if not training: + return p > 0.5 return p @@ -55,7 +55,7 @@ def __init__(self, learned_round_temperature: float = 1.) -> None: @brevitas.jit.script_method def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor: - if training: + if not training: return x > 0 p = torch.sigmoid(x / self.learned_round_temperature) return p @@ -99,7 +99,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: float_to_int_ste = self.learned_round_impl.float_to_int_ste is_p_value = self.learned_round_impl.is_p_value - p = self.value + p = self.learned_round_impl(self.value, self.training) p = self.tensor_slicer(p) p = (p.to(x.dtype)).view_as(x) return float_to_int_ste(x) + p if is_p_value else float_to_int_ste(x + p) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index 9c964333b..cf9dde3ce 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -310,6 +310,38 @@ def __call__(self, module, args, kwargs, output) -> None: raise StopFwdException +def save_inputs_output( + model: nn.Module, + model_forward: Callable, + module: nn.Module, + dataloader: DataLoader, + cache: Cache, + store_inputs: bool = True, + store_output: bool = False, + keep_gpu: bool = True, + disable_quant: bool = False) -> None: + if disable_quant: + disable_quant_class = DisableEnableQuantization() + disable_quant_class.disable_act_quantization(model, False) + disable_quant_class.disable_param_quantization(model, False) + return_quant_tensor_state = disable_return_quant_tensor(model) + + data_saver = DataSaverHook( + cache, store_inputs=store_inputs, store_output=store_output, keep_gpu=keep_gpu) + handle = module.register_forward_hook(data_saver, with_kwargs=True) + with torch.no_grad(): + for inps in dataloader: + try: + model_forward(model, inps) + except StopFwdException: + pass + handle.remove() + if disable_quant: + disable_quant_class.enable_act_quantization(model, False) + disable_quant_class.enable_param_quantization(model, False) + restore_return_quant_tensor(model, return_quant_tensor_state) + + class LearnedRoundOptimizer: def __init__( @@ -373,38 +405,6 @@ def _step(self, optimizer: Optimizer, lr_scheduler: LRScheduler) -> None: if lr_scheduler: lr_scheduler.step() - def _save_inputs_output( - self, - model: nn.Module, - model_forward: Callable, - module: nn.Module, - dataloader: DataLoader, - cache: Cache, - store_inputs: bool = True, - store_output: bool = False, - keep_gpu: bool = True, - disable_quant: bool = False) -> None: - if disable_quant: - disable_quant_class = DisableEnableQuantization() - disable_quant_class.disable_act_quantization(model, False) - disable_quant_class.disable_param_quantization(model, False) - return_quant_tensor_state = disable_return_quant_tensor(model) - - data_saver = DataSaverHook( - cache, store_inputs=store_inputs, store_output=store_output, keep_gpu=keep_gpu) - handle = module.register_forward_hook(data_saver, with_kwargs=True) - with torch.no_grad(): - for inps in dataloader: - try: - model_forward(model, inps) - except StopFwdException: - pass - handle.remove() - if disable_quant: - disable_quant_class.enable_act_quantization(model, False) - disable_quant_class.enable_param_quantization(model, False) - restore_return_quant_tensor(model, return_quant_tensor_state) - def _populate_cache( self, cache: Cache, @@ -417,7 +417,7 @@ def _populate_cache( capture_quant_output: bool = False, ) -> None: # Populate the cache with new inputs and outputs - self._save_inputs_output( + save_inputs_output( model, model_forward, block, @@ -429,7 +429,7 @@ def _populate_cache( disable_quant=not capture_quant_input, ) if capture_quant_input != capture_quant_output: - self._save_inputs_output( + save_inputs_output( model, model_forward, block, diff --git a/tests/brevitas/core/test_float_to_int.py b/tests/brevitas/core/test_float_to_int.py index 0ec7f10e6..41b74b4d6 100644 --- a/tests/brevitas/core/test_float_to_int.py +++ b/tests/brevitas/core/test_float_to_int.py @@ -1,16 +1,19 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from hypothesis import given +from hypothesis.strategies import floats import pytest import pytest_cases import torch from brevitas import config -from brevitas.core.function_wrapper.learned_round import AutoRoundSte from brevitas.core.function_wrapper.learned_round import LearnedRoundHardSigmoid +from brevitas.core.function_wrapper.learned_round import LearnedRoundIdentity from brevitas.core.function_wrapper.learned_round import LearnedRoundSigmoid from brevitas.core.function_wrapper.learned_round import LearnedRoundSte import brevitas.nn as qnn +from tests.brevitas.hyp_helper import two_float_tensor_random_shape_st OUT_CH = 16 IN_CH = 8 @@ -19,70 +22,94 @@ LearnedRoundSigmoid(), # Sigmoid Implementation LearnedRoundSigmoid(learned_round_temperature=2.), # Sigmoid + Temperature LearnedRoundHardSigmoid(), # Hard Sigmoid + LearnedRoundIdentity(), # AutoRound Implement ] class TestLearnedRound(): - @pytest_cases.fixture() + def instantiate_learnedround_float_to_int_impl(self, impl, weights, value): + impl = LearnedRoundSte(impl, torch.full(weights.shape, 0.)) + # For methods with p_value=False, it is required that value is within [-0.5, 0.5] + if not impl.learned_round_impl.is_p_value: + min_value, max_value = torch.min(value), torch.max(value) + # Prevent division by zero when all the elements of the tensor are the same + if max_value - min_value < 1e-8: + # Make sure that the division is safe + if torch.abs(max_value) > 1e-8: + value = value / max_value - 0.5 + else: + value = (value - min_value) / (max_value - min_value) - 0.5 + # Simulate learned round + impl.value.data = value + return impl + + # NOTE: The min/max values are set to the exactly representable float32 + # closer to sys.maxsize for a given machine. @pytest_cases.parametrize('impl', LEARNEDROUND_IMPL) - def learnedround_float_to_int_impl(self, impl): - sample_weight = torch.randn(OUT_CH, IN_CH, KERNEL_SIZE, KERNEL_SIZE) - impl = LearnedRoundSte(impl, torch.full(sample_weight.shape, 0.)) - - # Simulate learned parameter - impl.value.data = torch.randn_like(impl.value) - return impl, sample_weight - @pytest_cases.parametrize('training', [True, False]) - def test_learnedround(self, learnedround_float_to_int_impl, training): - impl, sample_weight = learnedround_float_to_int_impl + @given( + weights_value=two_float_tensor_random_shape_st( + min_val=-9.223372036854776e+18, max_val=9.223372036854776e+18)) + def test_learnedround(self, impl, training, weights_value): + # Unpack tuple of hypothesis generated tensors + weights, value = weights_value + # Instantiate LearnedRoundSte using fabric method + impl = self.instantiate_learnedround_float_to_int_impl(impl, weights, value) impl.train(training) - - out = impl(sample_weight) - if training: - # Soft quantization. All values are at most distant +/- 1 from the nearest integer - assert torch.all(torch.abs(out - torch.round(out)) < 1) + print(impl.value) + out = impl(weights) + # The FP values and its quantized values must differ by at most +/- 1 + assert torch.all(torch.abs(out - weights) <= 1) + # For is_p_value=True, the rounding can be soft while training=True + if impl.learned_round_impl.is_p_value: + if training: + # Soft quantization. All values are at most distant +/- 1 from the nearest integer + assert torch.all(torch.abs(out - torch.round(out)) <= 1) + else: + # Hard quantization. All values are integers + assert torch.allclose(out, torch.round(out)) else: - # Hard quantization. All values are integers + # All values should be integers when is_p_value=False assert torch.allclose(out, torch.round(out)) - def test_learnedround_load_dict(self, learnedround_float_to_int_impl): - config.IGNORE_MISSING_KEYS = True - - impl, _ = learnedround_float_to_int_impl - quant_conv = qnn.QuantConv2d(IN_CH, OUT_CH, KERNEL_SIZE, weight_float_to_int_impl=impl) - fp_conv = torch.nn.Conv2d(IN_CH, OUT_CH, KERNEL_SIZE) - try: - quant_conv.load_state_dict(fp_conv.state_dict()) - except RuntimeError as e: - pytest.fail(str(e)) - - -class TestAutoRound(): + @given( + learned_round_zeta=floats(min_value=0.0, max_value=3.0), + learned_round_gamma=floats(min_value=-3.0, max_value=-0.05), + value=floats(min_value=-5.0, max_value=5.0), + ) + def test_learnedround_float_to_int_impl_hard_sigmoid( + self, learned_round_zeta, learned_round_gamma, value): + value = torch.tensor([value], dtype=torch.float32) + weight = torch.zeros_like(value) + # Initialise learned round script module + learned_round_hard_sigmoid = LearnedRoundHardSigmoid( + learned_round_zeta=learned_round_zeta, + learned_round_gamma=learned_round_gamma, + ) + value_eval = learned_round_hard_sigmoid(value, training=False) + value_train = learned_round_hard_sigmoid(value, training=True) + + out_eval = weight + value_eval + out_train = weight + (value_train > 0.5) + + assert torch.allclose(out_eval, out_train) @pytest_cases.fixture() - def autoround_float_to_int_impl(self): + @pytest_cases.parametrize('impl', LEARNEDROUND_IMPL) + def learnedround_float_to_int_impl(self, impl): sample_weight = torch.randn(OUT_CH, IN_CH, KERNEL_SIZE, KERNEL_SIZE) - impl = AutoRoundSte(torch.full(sample_weight.shape, 0.)) - - # Simulate learned parameter, values should be in the interval (-0.5, 0.5) - impl.value.data = torch.rand_like(impl.value) * 0.5 - return impl, sample_weight - - def test_autoround(self, autoround_float_to_int_impl): - impl, sample_weight = autoround_float_to_int_impl + impl = LearnedRoundSte(impl, torch.full(sample_weight.shape, 0.)) - out = impl(sample_weight) - # Check that all values are integers - assert torch.allclose(out, torch.round(out)) - # Check that the values differ by at most 1 unit - assert torch.all(torch.abs(sample_weight - out) < 1) + # Simulate learned parameter + value = torch.randn_like(impl.value) + impl.value.data = value + return impl, sample_weight, value - def test_autoround_load_dict(self, autoround_float_to_int_impl): + def test_learnedround_load_dict(self, learnedround_float_to_int_impl): config.IGNORE_MISSING_KEYS = True - impl, _ = autoround_float_to_int_impl + impl, _ = learnedround_float_to_int_impl quant_conv = qnn.QuantConv2d(IN_CH, OUT_CH, KERNEL_SIZE, weight_float_to_int_impl=impl) fp_conv = torch.nn.Conv2d(IN_CH, OUT_CH, KERNEL_SIZE) try: @@ -90,11 +117,11 @@ def test_autoround_load_dict(self, autoround_float_to_int_impl): except RuntimeError as e: pytest.fail(str(e)) - def test_autoround_edge_cases(self): - sample_weight = torch.tensor([-1.000, -0.500, 0.000, 0.500, 1.000]) - impl_data = torch.tensor([-0.500, 0.500, 0.000, -0.500, 0.500]) - impl = AutoRoundSte(impl_data) + def test_learnedround_state_dict(self, learnedround_float_to_int_impl): + impl, _, value = learnedround_float_to_int_impl + state_dict = impl.state_dict() - out = impl(sample_weight) - # Check that all values are integers - assert torch.allclose(out, torch.tensor([-2.000, 0.000, 0.000, 0.000, 2.000])) + # Verify that the state dict contains the entry corresponding to the + # learnable round parameter. + assert len(state_dict.keys()) == 1 + assert torch.allclose(state_dict["value"], value) diff --git a/tests/brevitas/hyp_helper.py b/tests/brevitas/hyp_helper.py index 1a9157214..c3fd6a82a 100644 --- a/tests/brevitas/hyp_helper.py +++ b/tests/brevitas/hyp_helper.py @@ -174,14 +174,18 @@ def float_tensor_random_size_st( @st.composite def two_float_tensor_random_shape_st( - draw, min_dims=1, max_dims=4, max_size=3, width=FP32_BIT_WIDTH): + draw, min_dims=1, max_dims=4, max_size=3, min_val=None, max_val=None, width=FP32_BIT_WIDTH): """ Generate a tuple of float tensors of the same random shape. """ shape = draw(random_tensor_shape_st(min_dims, max_dims, max_size)) size = reduce(mul, shape, 1) - float_list1 = draw(st.lists(float_st(width=width), min_size=size, max_size=size)) - float_list2 = draw(st.lists(float_st(width=width), min_size=size, max_size=size)) + float_list1 = draw( + st.lists( + float_st(min_val=min_val, max_val=max_val, width=width), min_size=size, max_size=size)) + float_list2 = draw( + st.lists( + float_st(min_val=min_val, max_val=max_val, width=width), min_size=size, max_size=size)) tensor1 = torch.tensor(float_list1).view(shape) tensor2 = torch.tensor(float_list2).view(shape) return tensor1, tensor2 diff --git a/tests/brevitas/optim/test_sign_sgd.py b/tests/brevitas/optim/test_sign_sgd.py index d4a7a6424..b1e61cc3e 100644 --- a/tests/brevitas/optim/test_sign_sgd.py +++ b/tests/brevitas/optim/test_sign_sgd.py @@ -40,6 +40,8 @@ POSSIBILITY OF SUCH DAMAGE. """ +from copy import deepcopy +from itertools import product import math import sys from typing import List, Union @@ -52,51 +54,15 @@ import torch from torch.nn import Parameter import torch.nn as nn -from torch.optim.lr_scheduler import ConstantLR -from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import PolynomialLR -from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.optim.lr_scheduler import StepLR -from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_device_type import onlyCUDA -from torch.testing._internal.common_device_type import tol -from torch.testing._internal.common_device_type import toleranceOverride -from torch.testing._internal.common_optimizers import DecorateInfo -from torch.testing._internal.common_optimizers import optim_error_inputs_func_sgd -from torch.testing._internal.common_optimizers import optim_inputs_func_sgd -from torch.testing._internal.common_optimizers import OptimizerErrorEnum -from torch.testing._internal.common_optimizers import OptimizerInfo -from torch.testing._internal.common_optimizers import optims -from torch.testing._internal.common_optimizers import skipIfTorchDynamo -from torch.testing._internal.common_utils import markDynamoStrictTest -from torch.testing._internal.common_utils import parametrize -from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.common_utils import skipIfTorchDynamo -from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO -from torch.testing._internal.common_utils import TestCase - -from brevitas.graph.calibrate import bias_correction_mode -from brevitas.graph.calibrate import calibration_mode -from brevitas.graph.calibrate import disable_return_quant_tensor -from brevitas.graph.calibrate import DisableEnableQuantization -from brevitas.graph.calibrate import load_quant_model_mode -from brevitas.graph.calibrate import restore_return_quant_tensor -from brevitas.inject.enum import RestrictValueType -import brevitas.nn as qnn + from brevitas.optim.sign_sgd import SignSGD -from brevitas.quant import Int8ActPerTensorFixedPoint -from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat -from brevitas.quant.scaled_int import Int8ActPerTensorFloat -from brevitas.quant_tensor import QuantTensor -# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations -from brevitas.utils.torch_utils import kthvalue -from tests.brevitas.hyp_helper import float_tensor_random_size_st from tests.conftest import SEED torch.manual_seed(SEED) +from torch.testing._internal.common_optimizers import OptimizerInput + REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]]) REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010], [1.4573, -0.9074, -0.2708]]) @@ -105,54 +71,22 @@ REFERENCE_WEIGHTS_SIGN_GRAD = torch.tensor([[1.0000, 0.0000, 1.0000], [-1.0000, -1.0000, -1.0000], [1.0000, -1.0000, -1.0000]]) -optim_db: List[OptimizerInfo] = [ - OptimizerInfo( - SignSGD, - optim_inputs_func=optim_inputs_func_sgd, - scheduler_inputs=( - [lambda opt: StepLR(opt, gamma=0.9, step_size=10)], - [lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.8, total_iters=4)], - [ - lambda opt: StepLR(opt, gamma=0.9, step_size=10), - lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.6, total_iters=4),], - [ - lambda opt: StepLR(opt, gamma=0.99, step_size=10), - lambda opt: ExponentialLR(opt, gamma=0.99), - lambda opt: ReduceLROnPlateau(opt),], - [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)], - [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)], - [ - lambda opt: StepLR(opt, gamma=0.9, step_size=10), - lambda opt: ReduceLROnPlateau(opt),], - ), - optim_error_inputs_func=optim_error_inputs_func_sgd, - supported_impls=("foreach", "differentiable", "fused"), - supports_sparse=True, - metadata_for_sparse=( - { - "lr": 4.8e-3, - "maximize": False, - "momentum": 0, - "nesterov": False, - "weight_decay": 0,}, - [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)], - ), - supports_fused_on=( - "cpu", - "cuda", - "mps", - ), - skips=(), - ),] - - -@markDynamoStrictTest -class TestOptimSignSGD(TestCase): - - @parametrize("lr", [0.1]) - @optims(optim_db, dtypes=[torch.float32]) - def test_sign_sgd_update(self, device, dtype, optim_info, lr): - optim_cls = optim_info.optim_cls +OPTIMIZER_KWARGS = [{}, {"maximize": True}, {"lr": 1e-2}, {"lr": torch.tensor(0.001)}] +LR_SCHEDULER_ARGS = [ + None, + (LinearLR, { + "start_factor": 1.0, "end_factor": 0.0, "total_iters": 20}),] +DEVICES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] +DTYPES = [torch.float16, torch.float32] + +device_dtype_parametrize = pytest_cases.parametrize("device, dtype", list(product(DEVICES, DTYPES))) + + +class TestOptimSignSGD: + + @device_dtype_parametrize + @pytest_cases.parametrize("lr", [0.1]) + def test_sign_sgd_single_update(self, device, dtype, lr): # Initialize weights and grads weights = Parameter(REFERENCE_WEIGHTS.to(device=device, dtype=dtype)) # Initialize tensors to compute expected result @@ -160,7 +94,7 @@ def test_sign_sgd_update(self, device, dtype, optim_info, lr): weight_grad = REFERENCE_WEIGHTS_GRAD.to(device=device, dtype=dtype) weight_sign_grad = REFERENCE_WEIGHTS_SIGN_GRAD.to(device=device, dtype=dtype) - optimizer = optim_cls([weights], lr=lr) + optimizer = SignSGD([weights], lr=lr) # Perform a SignSGD update optimizer.zero_grad() @@ -169,142 +103,72 @@ def test_sign_sgd_update(self, device, dtype, optim_info, lr): assert torch.allclose(weights, initial_weights - lr * weight_sign_grad) - @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None], - dtypes=[torch.float32]) - def test_errors(self, device, dtype, optim_info): - optim_cls = optim_info.optim_cls - error_inputs = optim_info.optim_error_inputs_func(device=device, dtype=dtype) - - for error_input in error_inputs: - optim_input = error_input.optimizer_error_input - params, kwargs = optim_input.params, optim_input.kwargs - if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR: - if issubclass(error_input.error_type, Warning): - with self.assertWarnsRegex(error_input.error_type, error_input.error_regex): - optim_cls(params, **kwargs) - else: - with self.assertRaisesRegex(error_input.error_type, error_input.error_regex): - optim_cls(params, **kwargs) - elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR: - optim = optim_cls(params, **kwargs) - if issubclass(error_input.error_type, Warning): - with self.assertWarnsRegex(error_input.error_type, error_input.error_regex): - optim.step() - else: - with self.assertRaisesRegex(error_input.error_type, error_input.error_regex): - optim.step() + from torch.testing._internal.common_optimizers import optims + + @device_dtype_parametrize + @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS) + @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) + def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args): + optim_cls = SignSGD + weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) + bias = Parameter(torch.randn((10), device=device, dtype=dtype)) + input = torch.randn(5, device=device, dtype=dtype) + + optimizer = optim_cls([weight, bias], **deepcopy(optimizer_kwargs)) + scheduler = None if lr_scheduler_args is None else lr_scheduler_args[0]( + optimizer, **lr_scheduler_args[1]) + + def closure(): + optimizer.zero_grad() + loss = (weight.mv(input) + bias).pow(2).sum() + loss.backward() + return loss + + initial_value = closure().item() + for _ in range(20): + closure() + optimizer.step() + print(bias) + if scheduler is not None: + scheduler.step() + + if optimizer_kwargs.get("maximize", False): + assert closure().item() > initial_value + else: + assert closure().item() < initial_value + + @pytest.mark.skipif( + torch.cuda.device_count() <= 1, reason="At least two GPUs are required for this test.") + @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS) + @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) + @pytest_cases.parametrize("dtype", [torch.float16, torch.float32]) + def test_forloop_goes_right_direction_multigpu( + self, dtype, optimizer_kwargs, lr_scheduler_args): + optim_cls = SignSGD + # Learnable parameters + weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype)) + bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype)) + input = torch.randn(5, device="cuda:0", dtype=dtype) + + optimizer = optim_cls([weight, bias], **deepcopy(optimizer_kwargs)) + scheduler = None if lr_scheduler_args is None else lr_scheduler_args[0]( + optimizer, **lr_scheduler_args[1]) + + def closure(): + optimizer.zero_grad() + loss = (weight.mv(input).cuda(1) + bias).pow(2).sum() + loss.backward() + return loss + + initial_value = closure().item() + for _ in range(20): + closure() + optimizer.step() + + if scheduler is not None: + scheduler.step() + + if optimizer_kwargs.get("maximize", False): + assert closure().item() > initial_value else: - raise NotImplementedError(f"Unknown error type {error_input.error_on}") - - @parametrize("contiguous", [True, False]) - @parametrize("with_lrsched", [True, False]) - @optims(optim_db, dtypes=[torch.float32]) - def test_forloop_goes_right_direction( - self, device, dtype, optim_info, contiguous, with_lrsched): - optim_cls = optim_info.optim_cls - schedulers_constructors = (optim_info.scheduler_inputs if with_lrsched else [None]) - - for schedulers_constructor in schedulers_constructors: - # with tensor LR we need fresh inputs for each scheduler - # or mutating it will carry across iters - optim_inputs = optim_info.optim_inputs_func(device=device) - for optim_input in optim_inputs: - if "foreach" in optim_info.supported_impls: - optim_input.kwargs["foreach"] = False # force forloop - if contiguous: - weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) - bias = Parameter(torch.randn((10), device=device, dtype=dtype)) - else: - weight = Parameter(torch.randn((10, 5, 2), device=device, dtype=dtype)[..., 0]) - bias = Parameter(torch.randn((10, 2), device=device, dtype=dtype)[..., 0]) - input = torch.randn(5, device=device, dtype=dtype) - - optimizer = optim_cls([weight, bias], **optim_input.kwargs) - schedulers = [ - s(optimizer) - for s in (schedulers_constructor if schedulers_constructor else [])] - - def closure(): - optimizer.zero_grad() - loss = (weight.mv(input) + bias).pow(2).sum() - loss.backward() - if optim_info.only_supports_sparse_grads: - # For this test, we naively convert the Tensor layout, which we know does - # NOT represent the expected use case for optims like SparseAdam! - weight.grad = weight.grad.to_sparse() - bias.grad = bias.grad.to_sparse() - return loss - - initial_value = closure().item() - for _ in range(20): - if optim_info.step_requires_closure: - loss = optimizer.step(closure) - else: - loss = closure() - optimizer.step() - - for scheduler in schedulers: - if isinstance(scheduler, ReduceLROnPlateau): - scheduler.step(loss) - else: - scheduler.step() - - if optim_input.kwargs.get("maximize", False): - self.assertGreater(closure().item(), initial_value) - else: - self.assertLess(closure().item(), initial_value) - - @onlyCUDA - @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") - @parametrize("with_lrsched", [True, False]) - @optims(optim_db, dtypes=[torch.float32]) - def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info, with_lrsched): - optim_cls = optim_info.optim_cls - schedulers_constructors = (optim_info.scheduler_inputs if with_lrsched else [None]) - for schedulers_constructor in schedulers_constructors: - # We need a fresh set of inputs if we have a tensor LR - # to not carry mutations across iterations. - optim_inputs = optim_info.optim_inputs_func(device=device) - for optim_input in optim_inputs: - if "foreach" in optim_info.supported_impls: - optim_input.kwargs["foreach"] = False # force forloop - - weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype)) - bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype)) - inpt = torch.randn(5, device="cuda:0", dtype=dtype) - - optimizer = optim_cls([weight, bias], **optim_input.kwargs) - schedulers = [ - s(optimizer) - for s in (schedulers_constructor if schedulers_constructor else [])] - - def closure(): - optimizer.zero_grad() - loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum() - loss.backward() - if optim_info.only_supports_sparse_grads: - # For this test, we naively convert the Tensor layout, which we know does - # NOT represent the expected use case for optims like SparseAdam! - weight.grad = weight.grad.to_sparse() - bias.grad = bias.grad.to_sparse() - return loss - - initial_value = closure().item() - for _ in range(20): - loss = optimizer.step(closure) - for scheduler in schedulers: - if isinstance(scheduler, ReduceLROnPlateau): - scheduler.step(loss) - else: - scheduler.step() - - if optim_input.kwargs.get("maximize", False): - self.assertGreater(closure().item(), initial_value) - else: - self.assertLess(closure().item(), initial_value) - - -instantiate_device_type_tests(TestOptimSignSGD, globals(), allow_mps=True) - -if __name__ == "__main__": - run_tests() + assert closure().item() < initial_value diff --git a/tests/brevitas_examples/test_imagenet.py b/tests/brevitas_examples/test_imagenet.py deleted file mode 100644 index e4e117f1a..000000000 --- a/tests/brevitas_examples/test_imagenet.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -from hypothesis import given -import pytest -import pytest_cases -from pytest_cases import fixture -import torch -import torch.nn as nn - -from brevitas_examples.imagenet_classification.ptq.utils import get_torchvision_model - -DTYPE = torch.float32 - - -class TestImageNet: - - @fixture - def model(): - # Get the model from torchvision - model = get_torchvision_model("resnet18") - model = model.to(DTYPE) - model.eval() - - return model - - def test_model_can_be_loaded(model): - print(f"The model class IS: {type(model)}") - assert False - - -if __name__ == "__main__": - # Run pytest on the current file - pytest.main(["-s", __file__]) diff --git a/tests/brevitas_examples/test_learned_round_utils.py b/tests/brevitas_examples/test_learned_round_utils.py index 8c2a62c14..e9501668e 100644 --- a/tests/brevitas_examples/test_learned_round_utils.py +++ b/tests/brevitas_examples/test_learned_round_utils.py @@ -1,6 +1,9 @@ # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from typing import Any, Union + +from accelerate.utils.operations import send_to_device from hypothesis import given import pytest import pytest_cases @@ -12,109 +15,147 @@ from torch.utils.data import Dataset from brevitas import config -from brevitas.core.function_wrapper.learned_round import AutoRoundSte from brevitas.core.function_wrapper.learned_round import LearnedRoundSte import brevitas.nn as qnn from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.quant_tensor.base_quant_tensor import QuantTensor from brevitas_examples.common.learned_round.learned_round_method import AdaRound -from brevitas_examples.common.learned_round.learned_round_method import AdaRoundLoss from brevitas_examples.common.learned_round.learned_round_method import AutoRound -from brevitas_examples.common.learned_round.learned_round_method import AutoRoundLoss from brevitas_examples.common.learned_round.learned_round_optimizer import get_blocks -from brevitas_examples.imagenet_classification.ptq.learned_round_utils import \ - LearnedRoundVisionUtils +from brevitas_examples.common.learned_round.learned_round_optimizer import save_inputs_output config.IGNORE_MISSING_KEYS = True -class TestLearnedRound: +class QuantBlock(nn.Module): - @fixture - def quant_model(): + def __init__(self, in_features: int, hidden_dim: int, out_features: int) -> None: + super().__init__() + self.layer1 = qnn.QuantLinear(in_features=in_features, out_features=hidden_dim) + self.layer2 = qnn.QuantLinear(in_features=hidden_dim, out_features=out_features) + self.relu = qnn.QuantReLU(return_quant_tensor=True) - class QuantBlock(nn.Module): + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + out = self.layer1(x) + out = self.relu(out) + out = self.layer2(out) + return self.relu(out) - def __init__(self, in_features: int, hidden_dim: int, out_features: int) -> None: - super().__init__() - self.layer1 = qnn.QuantLinear(in_features=in_features, out_features=hidden_dim) - self.layer2 = qnn.QuantLinear(in_features=hidden_dim, out_features=out_features) - self.relu = qnn.QuantReLU(return_quant_tensor=True) - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - out = self.layer1(x) - out = self.relu(out) - out = self.layer2(out) - return self.relu(out) +class TestQuantModel(nn.Module): - class TestQuantModel(nn.Module): + def __init__(self, in_features: int, out_features: int, hidden_dim: int) -> None: + super().__init__() + self.in_proj_mlp = QuantBlock( + in_features=in_features, hidden_dim=hidden_dim, out_features=hidden_dim) + self.hidden_mlp = QuantBlock( + in_features=hidden_dim, hidden_dim=hidden_dim, out_features=hidden_dim) + self.out_proj_mlp = QuantBlock( + in_features=hidden_dim, hidden_dim=hidden_dim, out_features=out_features) - def __init__(self, in_features: int, out_features: int, hidden_dim: int) -> None: - super().__init__() - self.in_proj_mlp = QuantBlock( - in_features=in_features, hidden_dim=hidden_dim, out_features=hidden_dim) - self.hidden_mlp = QuantBlock( - in_features=hidden_dim, hidden_dim=hidden_dim, out_features=hidden_dim) - self.out_proj_mlp = QuantBlock( - in_features=hidden_dim, hidden_dim=hidden_dim, out_features=out_features) + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + out = self.in_proj_mlp(x, block1_kwarg=0., **kwargs) + out = self.hidden_mlp(out, block2_kwarg=0., **kwargs) + return self.out_proj_mlp(out, block3_kwarg=0., **kwargs) - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - out = self.in_proj_mlp(x) - out = self.hidden_mlp(out) - return self.out_proj_mlp(out) - return TestQuantModel(in_features=2, out_features=1, hidden_dim=4) +class Block(nn.Module): - @fixture - def model(): + def __init__(self, in_features: int, hidden_dim: int, out_features: int) -> None: + super().__init__() + self.layer1 = nn.Linear(in_features=in_features, out_features=hidden_dim) + self.layer2 = nn.Linear(in_features=hidden_dim, out_features=out_features) + self.relu = F.relu - class Block(nn.Module): + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + out = self.layer1(x) + out = self.relu(out) + out = self.layer2(out) + return self.relu(out) - def __init__(self, in_features: int, hidden_dim: int, out_features: int) -> None: - super().__init__() - self.layer1 = nn.Linear(in_features=in_features, out_features=hidden_dim) - self.layer2 = nn.Linear(in_features=hidden_dim, out_features=out_features) - self.relu = F.relu - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - out = self.layer1(x) - out = self.relu(out) - out = self.layer2(out) - return self.relu(out) +class TestModel(nn.Module): - class TestModel(nn.Module): + def __init__(self, in_features: int, out_features: int, hidden_dim: int) -> None: + super().__init__() + self.in_proj_mlp = Block( + in_features=in_features, hidden_dim=hidden_dim, out_features=hidden_dim) + self.hidden_mlp = Block( + in_features=hidden_dim, hidden_dim=hidden_dim, out_features=hidden_dim) + self.out_proj_mlp = Block( + in_features=hidden_dim, hidden_dim=hidden_dim, out_features=out_features) - def __init__(self, in_features: int, out_features: int, hidden_dim: int) -> None: - super().__init__() - self.in_proj_mlp = Block( - in_features=in_features, hidden_dim=hidden_dim, out_features=hidden_dim) - self.hidden_mlp = Block( - in_features=hidden_dim, hidden_dim=hidden_dim, out_features=hidden_dim) - self.out_proj_mlp = Block( - in_features=hidden_dim, hidden_dim=hidden_dim, out_features=out_features) + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + out = self.in_proj_mlp(x, block1_kwarg=0., **kwargs) + out = self.hidden_mlp(out, block2_kwarg=0., **kwargs) + return self.out_proj_mlp(out, block3_kwarg=0., **kwargs) - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - out = self.in_proj_mlp(x) - out = self.hidden_mlp(out) - return self.out_proj_mlp(out) - return TestModel(in_features=2, out_features=1, hidden_dim=4) +class TestDataset(Dataset): - @fixture - def data_loader(): + def __init__(self): + self.data = [ + { + "x": torch.tensor([1.0, 2.0]), + "tensor": torch.tensor([0.0]), + "bool": True, + "float": 0.0,},] + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +class TestCache: + + def __init__(self) -> None: + self.args = [] + self.kwargs = [] + self.output = [] + + def __len__(self) -> int: + return len(self.args) - class TestDataset(Dataset): + def store_inputs(self, args: Any, kwargs: Any) -> None: + self.args.append(args) + self.kwargs.append(kwargs) - def __init__(self): - self.data = torch.tensor([[1.0, 2.0]]) - self.labels = torch.tensor([0]) + def store_output(self, output: Any) -> None: + self.output.append(output) - def __len__(self): - return len(self.data) + def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: + pass - def __getitem__(self, idx): - return self.data[idx], self.labels[idx] + def initialize_cache(self) -> None: + pass + + def clear_cache(self) -> None: + pass + + def reset_cache(self) -> None: + pass + + def cache_to_dataset(self) -> Dataset: + pass + + def collate_fn(self, batch: Any) -> Any: + pass + + +class TestLearnedRound: + + @fixture + def quant_model(): + return TestQuantModel(in_features=2, out_features=1, hidden_dim=4) + + @fixture + def model(): + return TestModel(in_features=2, out_features=1, hidden_dim=4) + @fixture + def data_loader(): return DataLoader(TestDataset(), batch_size=1, shuffle=False) def test_get_blocks(self, quant_model: nn.Module): @@ -144,185 +185,143 @@ def _is_layer(module: nn.Module, module_name: str) -> bool: assert expected_layers == layers @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") - # NOTE: DataSaverHook always returns a torch.Tensor for the input, while for the output it can be either a torch.Tensor or - # a QuantTensor. Is this expected behaviour? For that reason, the argument _assert_type is included in _aux_check_tensors. - # Also, returning an empty list for the save_inp_out_data does not seem very natural, considering a tensors if the appropiate - # store option is activated. - @pytest.mark.parametrize("store_input", [True, False]) - @pytest.mark.parametrize("store_out", [True, False]) + @pytest.mark.parametrize("store_inputs", [True, False]) + @pytest.mark.parametrize("store_output", [True, False]) @pytest.mark.parametrize("keep_gpu", [True, False]) @pytest.mark.parametrize("disable_quant", [True, False]) def test_save_inp_out_data( - self, model, quant_model, data_loader, store_input, store_out, keep_gpu, disable_quant): - # Initialise utils to save tensors - learned_round_vision_utils = LearnedRoundVisionUtils() + self, model, quant_model, data_loader, store_inputs, store_output, keep_gpu, + disable_quant): + + def _compare_tensors(cache_tensor, gt_tensor, disable_quant, keep_gpu): + # The elements should be of the same type + assert isinstance(cache_tensor, torch.Tensor) if disable_quant else isinstance( + cache_tensor, QuantTensor) + # If they are QuantTensors extract their data + if isinstance(cache_tensor, QuantTensor): + cache_tensor, gt_tensor = cache_tensor.value, gt_tensor.value + # Verify that the tensor is in GPU if keep_gpu=True + assert keep_gpu == cache_tensor.is_cuda + # Make sure tensors are in the same device before comparison + cache_tensor = cache_tensor.cpu() + gt_tensor = gt_tensor.cpu() + # Check that their contents match + assert torch.allclose(cache_tensor, gt_tensor) + # Return True to be used in the helper + return True + + def model_forward(model, inputs): + device = next(model.parameters()).device + inputs = send_to_device(inputs, device) + model(**inputs) + # Make sure that the quant and FP models share the same weights quant_model.load_state_dict(model.state_dict()) + # Prepare models model.eval() model = model.cuda() - quant_model.eval() quant_model = quant_model.cuda() - # Retrieve module from quant_model - module = quant_model.hidden_mlp - - cache_quant_partial_input = [] - cache_quant_partial_output = [] - - cache_fp_partial_input = [] - cache_fp_partial_output = [] - - def _aux_check_tensors( - result_tensor, expected_tensor, keep_gpu, disable_quant, assert_type=False): - # Verify that tensor is of the appropiate type - if assert_type: - assert isinstance(result_tensor, torch.Tensor if disable_quant else QuantTensor) - # Extract value tensors - if isinstance(result_tensor, QuantTensor): - result_tensor, expected_tensor = result_tensor.value, expected_tensor.value - # Verify that tensor is in appropiate device - assert result_tensor.is_cuda == keep_gpu - # Make sure tensors are in the same device before comparison - if not keep_gpu: - expected_tensor = expected_tensor.cpu() - - assert torch.allclose(result_tensor, expected_tensor) + device = next(quant_model.parameters()).device + # Compute ground-truth inputs/outputs + fp_args, fp_kwargs, fp_outs = [], [], [] + quant_args, quant_kwargs, quant_outs = [], [], [] # Compute ground truths inputs and outputs with torch.no_grad(): - for batch_data, _ in data_loader: - batch_data = batch_data.cuda() + for inputs in data_loader: + inputs = send_to_device(inputs, device) + kwargs = {k: v for k, v in inputs.items() if k != "x"} # Compute quant inputs to module - quant_partial_input = quant_model.in_proj_mlp(batch_data) - cache_quant_partial_input.append(quant_partial_input) + quant_arg = quant_model.in_proj_mlp(**inputs) + quant_kwarg = {"block2_kwarg": 0.0, **kwargs} # Compute quant outputs of module - quant_partial_output = quant_model.hidden_mlp(quant_partial_input) - cache_quant_partial_output.append(quant_partial_output) - - # Compute FP inputs to module - fp_partial_input = model.in_proj_mlp(batch_data) - cache_fp_partial_input.append(fp_partial_input) - # Compute FP outputs of module - fp_partial_output = model.hidden_mlp(fp_partial_input) - cache_fp_partial_output.append(fp_partial_output) - - # Inputs and outputs are concatenated along the batch dimension. - # See https://github.com/quic/aimet/blob/7c9eded51e3d8328746e7ba4cf68c7162f841712/TrainingExtensions/torch/src/python/aimet_torch/v1/adaround/activation_sampler.py#L231 - cache_quant_partial_input = torch.cat(cache_quant_partial_input, dim=0) - cache_quant_partial_output = torch.cat(cache_quant_partial_output, dim=0) - - cache_fp_partial_input = torch.cat(cache_fp_partial_input, dim=0) - cache_fp_partial_output = torch.cat(cache_fp_partial_output, dim=0) - - # Retrieve input and output data - input_data, out_data = learned_round_vision_utils._save_inp_out_data(quant_model, module, data_loader, store_input, store_out, keep_gpu, disable_quant) - # Verify that empty lists are returned - if store_input: - if disable_quant: - _aux_check_tensors( - input_data, fp_partial_input, keep_gpu, disable_quant, assert_type=True) - else: - _aux_check_tensors(input_data, quant_partial_input, keep_gpu, disable_quant) - else: - assert len(input_data) == 0 - - if store_out: - if disable_quant: - _aux_check_tensors(out_data, fp_partial_output, keep_gpu, disable_quant) - else: - _aux_check_tensors( - out_data, quant_partial_output, keep_gpu, disable_quant, assert_type=True) + quant_out = quant_model.hidden_mlp(quant_arg, **quant_kwarg) + + quant_args.append((quant_arg,)) + quant_kwargs.append(quant_kwarg) + quant_outs.append(quant_out) + + # Compute quant inputs to module + fp_arg = model.in_proj_mlp(**inputs) + fp_kwarg = {"block2_kwarg": 0.0, **kwargs} + # Compute quant outputs of module + fp_out = model.hidden_mlp(fp_arg, **fp_kwarg) + + fp_args.append((fp_arg,)) + fp_kwargs.append(fp_kwarg) + fp_outs.append(fp_out) + + # Prepare to capture inputs/outputs using DataSaverHook + cache = TestCache() + + # Retrieve module from quant_model + module = quant_model.hidden_mlp + + # Make call to save inputs/outputs + save_inputs_output( + model=quant_model, + model_forward=model_forward, + module=module, + dataloader=data_loader, + cache=cache, + store_inputs=store_inputs, + store_output=store_output, + keep_gpu=keep_gpu, + disable_quant=disable_quant, + ) + + # Verify that the lengths of the lists match + if store_inputs: + assert len(cache.args) == len(fp_args) and len(cache.kwargs) == len(fp_kwargs) else: - assert len(out_data) == 0 + assert len(cache.args) == 0 and len(cache.kwargs) == 0 - @pytest.mark.parametrize( - "learned_round_class, rounding_mode, float_to_int_impl", - [(AutoRound, "AUTO_ROUND", AutoRoundSte), (AdaRound, "LEARNED_ROUND", LearnedRoundSte)]) - def test_insert_learned_round_quantizer( - self, quant_model, learned_round_class, rounding_mode, float_to_int_impl): + if store_output: + assert len(cache.output) == len(fp_outs) + else: + assert len(cache.output) == 0 + + # Verify that the arguments are the same + for cache_arg, gt_arg in zip(cache.args, fp_args if disable_quant else quant_args): + _compare_tensors(cache_arg[0], gt_arg[0], disable_quant, keep_gpu) + + for cache_kwarg, gt_kwarg in zip(cache.kwargs, fp_kwargs if disable_quant else quant_kwargs): + # Compare the contents within each dictionary + same_contents = all(( + torch.allclose(cache_kwarg.get(gt_kwarg_k, None).cpu(), gt_kwarg_v.cpu( + )) if isinstance(gt_kwarg_v, torch.Tensor) else cache_kwarg.get(gt_kwarg_k, None) == + gt_kwarg_v) for gt_kwarg_k, + gt_kwarg_v in gt_kwarg.items()) + # Verify that the dictionaries have the same keys and content + assert set(cache_kwarg.keys()) == set(gt_kwarg.keys()) and same_contents + + # Verify that the outputs are the same + for cache_output, gt_output in zip(cache.output, fp_outs if disable_quant else quant_outs): + _compare_tensors(cache_output, gt_output, disable_quant, keep_gpu) + + @pytest.mark.parametrize("learned_round", [AutoRound(), AdaRound()]) + def test_insert_learned_round_quantizers(self, quant_model, learned_round): block = quant_model.in_proj_mlp - learned_round = learned_round_class(iters=100) - learned_round._insert_learned_round_quantizer(block) + learned_round.insert_learned_round_quantizers(block) for module in block.modules(): if hasattr(module, "weight_quant"): - assert module.weight_quant.rounding_mode == rounding_mode + assert module.weight_quant.rounding_mode == "LEARNED_ROUND" assert isinstance( - module.weight_quant.tensor_quant.int_quant.float_to_int_impl, float_to_int_impl) + module.weight_quant.tensor_quant.int_quant.float_to_int_impl, LearnedRoundSte) - @pytest.mark.parametrize("learned_round_class", [AutoRound, AdaRound]) + @pytest.mark.parametrize("learned_round", [AutoRound(), AdaRound()]) @pytest.mark.parametrize( "block_strs, num_round_modules", [([], 0), (["hidden_mlp"], 2), (["in_proj_mlp", "out_proj_mlp"], 4)]) - def test_find_learned_round_modules( - self, quant_model, learned_round_class, block_strs, num_round_modules): - learned_round = learned_round_class(iters=100) + def test_return_learned_round_quantizers( + self, quant_model, learned_round, block_strs, num_round_modules): # Inject quantizers in quant model for block_str in block_strs: block = getattr(quant_model, block_str) - learned_round._insert_learned_round_quantizer(block) - learned_round_modules = learned_round._find_learned_round_modules(quant_model) + learned_round.insert_learned_round_quantizers(block) + learned_round_modules = learned_round.return_learned_round_quantizers(quant_model) assert len(learned_round_modules) == num_round_modules - - @pytest.mark.parametrize( - "learned_round_class, learned_round_loss_class", [(AutoRound, AutoRoundLoss)]) - @pytest.mark.parametrize( - "block_strs, num_round_modules", [([], 0), (["hidden_mlp"], 2), - (["in_proj_mlp", "out_proj_mlp"], 4)]) - def test_learned_round_iter_blockwise( - self, - quant_model, - learned_round_class, - learned_round_loss_class, - block_strs, - num_round_modules): - # Retrieve blocks from quant model - blocks = [getattr(quant_model, block_str) for block_str in block_strs] - learned_round = learned_round_class(iters=100) - - # Counters to verify that the generators returns the appropiate number of items - blocks_count = 0 - learned_round_modules_count = 0 - - for (block, block_loss, - block_learned_round_modules) in learned_round.learned_round_iterator(blocks): - assert isinstance(block_loss, learned_round_loss_class) - - for learned_round_module in block_learned_round_modules: - for params in learned_round_module.parameters(): - assert params.requires_grad - - blocks_count += 1 - learned_round_modules_count += len(block_learned_round_modules) - - assert blocks_count == len(blocks) - assert learned_round_modules_count == num_round_modules - - @pytest.mark.parametrize( - "learned_round_class, learned_round_loss_class", [(AutoRound, AutoRoundLoss), - (AdaRound, AdaRoundLoss)]) - def test_learned_round_iter_layerwise( - self, quant_model, learned_round_class, learned_round_loss_class): - # Retrieve blocks from quant model - blocks = [module for module in quant_model.modules() if isinstance(module, QuantWBIOL)] - learned_round = learned_round_class(iters=100) - - # Counters to verify that the generators returns the appropiate number of items - blocks_count = 0 - learned_round_modules_count = 0 - - for (block, block_loss, - block_learned_round_modules) in learned_round.learned_round_iterator(blocks): - assert isinstance(block_loss, learned_round_loss_class) - - for learned_round_module in block_learned_round_modules: - for params in learned_round_module.parameters(): - assert params.requires_grad - - blocks_count += 1 - learned_round_modules_count += len(block_learned_round_modules) - - assert blocks_count == len(blocks) - assert learned_round_modules_count == len(blocks) From a0d3d54fdb677adb629ffa375eeec042de81e196 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 21 Nov 2024 01:31:00 +0000 Subject: [PATCH 22/48] Enable scale tuning in learned round --- .../learned_round/learned_round_optimizer.py | 80 +++++++++++++++++-- .../llm/llm_quant/learned_round_utils.py | 2 + src/brevitas_examples/llm/main.py | 16 +++- 3 files changed, 90 insertions(+), 8 deletions(-) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index cf9dde3ce..0f2939acc 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -201,6 +201,7 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import LRScheduler from torch.optim.optimizer import Optimizer +from torch.optim.sgd import SGD from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import RandomSampler from tqdm import tqdm @@ -211,6 +212,7 @@ from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.optim.sign_sgd import SignSGD +from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase from brevitas_examples.common.accelerate_utils.accelerate import offload_model from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks from brevitas_examples.common.learned_round.learned_round_method import LearnedRound @@ -238,6 +240,24 @@ def _get_blocks(module: nn.Module): return blocks +def return_scale_parameters(block: nn.Module) -> List[nn.Parameter]: + + scale_parameters = [] + + def _get_scale_parameters(module: nn.Module): + for module_child in module.children(): + if isinstance(module, WeightQuantProxyFromInjectorBase): + for submodule_name, submodule in module_child.named_parameters(): + if submodule_name.endswith('scaling_impl.value'): + scale_parameters.append(submodule) + else: + _get_scale_parameters(module_child) + + # Run recursion from block + _get_scale_parameters(block) + return scale_parameters + + class StopFwdException(Exception): """Used to throw and catch an exception to stop traversing the graph.""" pass @@ -350,10 +370,12 @@ def __init__( learned_round_loss_class: Type[LearnedRoundLoss], *, optimizer_class: Type[Optimizer] = SignSGD, + scale_optimizer_class: Type[Optimizer] = SGD, lr_scheduler_class: Optional[Type[LRScheduler]] = LinearLR, optimizer_lr: float = 5e-3, batch_size: float = 8, iters: int = 200, + learn_scale: bool = False, use_best_model: bool = True, use_amp: bool = True, amp_dtype: torch.dtype = torch.float16, @@ -365,10 +387,12 @@ def __init__( ) -> None: self.learned_round = learned_round self.optimizer_class = optimizer_class + self.scale_optimizer_class = scale_optimizer_class self.lr_scheduler_class = lr_scheduler_class self.optimizer_lr = optimizer_lr self.batch_size = batch_size self.iters = iters + self.learn_scale = learn_scale self.use_best_model = use_best_model self.use_amp = use_amp self.amp_dtype = amp_dtype @@ -399,11 +423,25 @@ def _collect_round_params(self, block: nn.Module) -> Dict: params[n] = copy.deepcopy(m.state_dict()) return params - def _step(self, optimizer: Optimizer, lr_scheduler: LRScheduler) -> None: - optimizer.step() - optimizer.zero_grad() - if lr_scheduler: - lr_scheduler.step() + def _optim_step(self, *optimizers: Optimizer) -> None: + for optimizer in optimizers: + if optimizer: + optimizer.step() + optimizer.zero_grad() + + def _lr_sched_step(self, *lr_schedulers: LRScheduler) -> None: + for lr_scheduler in lr_schedulers: + if lr_scheduler: + lr_scheduler.step() + + def _step(self, optimizers: List[Optimizer], lr_schedulers: List[LRScheduler]) -> None: + for optimizer in optimizers: + if optimizer: + optimizer.step() + optimizer.zero_grad() + for lr_scheduler in lr_schedulers: + if lr_scheduler: + lr_scheduler.step() def _populate_cache( self, @@ -448,6 +486,7 @@ def _optimize_learned_round_block( cache: Cache, block_loss: LearnedRoundLoss, block_forward: Callable, + scale_params: Optional[nn.Parameter] = None, ) -> Tuple[float, float, int]: # Move block to GPU if available if torch.cuda.is_available(): @@ -474,6 +513,22 @@ def _optimize_learned_round_block( self.lr_scheduler_class(optimizer, **self.lr_scheduler_kwargs) if self.lr_scheduler_class else None) + # Initialize optimizer/LR scheduler for the scale parameters if enabled + if self.learn_scale and scale_params is not None: + optimizer_scale = self.scale_optimizer_class( + scale_params, + lr=self.optimizer_lr, + momentum=0.9, + **self.optimizer_kwargs, + ) + lr_scheduler_scale = ( + self.lr_scheduler_class( + optimizer_scale, start_factor=1, end_factor=0, total_iters=600) + if self.lr_scheduler_class else None) + else: + optimizer_scale = None + lr_scheduler_scale = None + # Variables needed for printing best_loss = torch.finfo(torch.float).max init_loss = -1.0 @@ -482,7 +537,7 @@ def _optimize_learned_round_block( # Dictionary to store the rounding parameters yielding the lowest # training loss optimal_rounding_params = {} - + torch.autograd.set_detect_anomaly(True) n_samples = len(cache) pbar = tqdm(range(self.iters), desc='') for i in pbar: @@ -512,7 +567,8 @@ def _optimize_learned_round_block( # Scale loss and perform gradient step loss = loss * self.loss_scaling_factor loss.backward() - self._step(optimizer, lr_scheduler) + self._optim_step(optimizer, optimizer_scale) + self._lr_sched_step(lr_scheduler, lr_scheduler_scale) # Update progress bar pbar.set_description("{}".format(block_loss.format_loss_components(*loss_components))) @@ -696,6 +752,9 @@ def apply_learned_round( # Remove hooks needed to offload the model blocks to cpu remove_hooks(model) + # Retrieve scales + scale_params = return_scale_parameters(block) + # The parameters of the block that are not part of the rounding quantizers # need to be frozen, as only the rounding needs to be optimized. block.eval() @@ -707,6 +766,10 @@ def apply_learned_round( block_learned_round_module.train() for params in block_learned_round_module.parameters(): params.requires_grad = True + # As well as the scale parameters, if enabled + if self.learn_scale: + for params in scale_params: + params.requires_grad = True # Move block to GPU if available if torch.cuda.is_available(): @@ -729,6 +792,7 @@ def apply_learned_round( cache=cache, block_loss=block_loss, block_forward=block_forward, + scale_params=scale_params, ) print( @@ -741,6 +805,8 @@ def apply_learned_round( block_learned_round_module.eval() for params in block_learned_round_module.parameters(): params.requires_grad = False + for params in scale_params: + params.requires_grad = False # Move the block back to CPU block.cpu() diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index dd0842702..bf2e565cc 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -167,6 +167,7 @@ def apply_learned_round( lr_scheduler: Optional[str] = "linear", optimizer_lr: float = 5e-3, batch_size: int = 8, + learn_scale: bool = False, use_best_model: bool = True, use_amp: bool = True, amp_dtype: torch.dtype = torch.float16, @@ -203,6 +204,7 @@ def apply_learned_round( optimizer_lr=optimizer_lr, batch_size=batch_size, iters=iters, + learn_scale=learn_scale, use_best_model=use_best_model, use_amp=use_amp, amp_dtype=amp_dtype, diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 04385276b..bb5a6cd8b 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -371,7 +371,12 @@ def main(args): if args.learned_round: print("Applying learned round...") remove_hooks(model) - apply_learned_round(model, calibration_loader) + apply_learned_round( + model, + calibration_loader, + iters=args.learned_round_iters, + learn_scale=args.learned_round_scale, + ) print("Learned round applied.") model = offload_model(model) @@ -560,6 +565,15 @@ def parse_args(args): type=int, default=64, help='Group size for per_group input quantization. Default: 64.') + parser.add_argument( + '--learned-round-iters', + type=int, + default=200, + help='Number of iterations for learned round. Default: 200.') + parser.add_argument( + '--learned-round-scale', + action='store_true', + help='Learned scale factor together with round.') parser.add_argument( '--quantize-input-zero-point', action='store_true', help='Quantize input zero-point.') parser.add_argument( From 6b89ccc16324024f579d4300e52294cdf0a0be72 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 21 Nov 2024 14:40:31 +0000 Subject: [PATCH 23/48] Unified learned round methods --- .../core/function_wrapper/learned_round.py | 41 ++++--- .../learned_round/learned_round_method.py | 113 ++++++++++-------- .../ptq/learned_round_utils.py | 15 +-- .../ptq/ptq_evaluate.py | 4 +- .../llm/benchmark/llm_benchmark.py | 3 +- .../llm/llm_quant/learned_round_utils.py | 8 +- src/brevitas_examples/llm/main.py | 2 +- tests/brevitas/core/test_float_to_int.py | 19 ++- .../test_learned_round_utils.py | 16 ++- 9 files changed, 123 insertions(+), 98 deletions(-) diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py index 2d3e76aeb..cfb1cfa5c 100644 --- a/src/brevitas/core/function_wrapper/learned_round.py +++ b/src/brevitas/core/function_wrapper/learned_round.py @@ -25,20 +25,21 @@ class LearnedRoundHardSigmoid(brevitas.jit.ScriptModule): def __init__(self, learned_round_zeta: float = 1.1, learned_round_gamma: float = -0.1) -> None: super(LearnedRoundHardSigmoid, self).__init__() - self.float_to_int_ste = floor_ste - self.is_p_value = True self.learned_round_zeta = learned_round_zeta self.learned_round_gamma = learned_round_gamma @brevitas.jit.script_method - def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor: - p = torch.sigmoid(x) + def forward(self, p: torch.Tensor) -> torch.Tensor: + p = torch.sigmoid(p) p = p * (self.learned_round_zeta - self.learned_round_gamma) + self.learned_round_gamma p = torch.clamp(p, 0.0, 1.0) - if not training: + if not self.training: return p > 0.5 return p + def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + return floor_ste(x) + p + class LearnedRoundSigmoid(brevitas.jit.ScriptModule): """ @@ -49,17 +50,19 @@ class LearnedRoundSigmoid(brevitas.jit.ScriptModule): def __init__(self, learned_round_temperature: float = 1.) -> None: super(LearnedRoundSigmoid, self).__init__() assert learned_round_temperature != 0, 'Temperature should be different than 0' - self.float_to_int_ste = floor_ste - self.is_p_value = True self.learned_round_temperature = learned_round_temperature @brevitas.jit.script_method - def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor: - if not training: - return x > 0 - p = torch.sigmoid(x / self.learned_round_temperature) + def forward(self, p: torch.Tensor) -> torch.Tensor: + if not self.training: + return p > 0 + p = torch.sigmoid(p / self.learned_round_temperature) return p + @brevitas.jit.script_method + def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + return floor_ste(x) + p + class LearnedRoundIdentity(brevitas.jit.ScriptModule): """ @@ -69,12 +72,14 @@ class LearnedRoundIdentity(brevitas.jit.ScriptModule): def __init__(self) -> None: super(LearnedRoundIdentity, self).__init__() - self.float_to_int_ste = round_ste - self.is_p_value = False @brevitas.jit.script_method - def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor: - return x + def forward(self, p: torch.Tensor) -> torch.Tensor: + return p + + @brevitas.jit.script_method + def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + return round_ste(x + p) class LearnedRoundSte(brevitas.jit.ScriptModule): @@ -97,12 +102,10 @@ def __init__( @brevitas.jit.script_method def forward(self, x: torch.Tensor) -> torch.Tensor: - float_to_int_ste = self.learned_round_impl.float_to_int_ste - is_p_value = self.learned_round_impl.is_p_value - p = self.learned_round_impl(self.value, self.training) + p = self.learned_round_impl(self.value) p = self.tensor_slicer(p) p = (p.to(x.dtype)).view_as(x) - return float_to_int_ste(x) + p if is_p_value else float_to_int_ste(x + p) + return self.learned_round_impl.round_forward(x, p) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, diff --git a/src/brevitas_examples/common/learned_round/learned_round_method.py b/src/brevitas_examples/common/learned_round/learned_round_method.py index 04e7b3818..ae6ab8392 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_method.py +++ b/src/brevitas_examples/common/learned_round/learned_round_method.py @@ -3,7 +3,7 @@ from abc import ABC from abc import abstractmethod -from typing import Dict, Generator, List, Optional, Tuple, Type +from typing import Callable, Dict, Generator, List, Optional, Tuple, Type import torch from torch import nn @@ -35,18 +35,72 @@ def format_loss_components(self, *args) -> str: pass +def learned_round_value_init_non_linear( + layer: nn.Module, + learned_round_zeta: float = 1.1, + learned_round_gamma: float = -0.1, + **learned_round_impl_kwargs, +) -> torch.Tensor: + floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale) + delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight + value = -torch.log((learned_round_zeta - learned_round_gamma) / + (delta - learned_round_gamma) - 1) + return value + + +def learned_round_value_init_linear( + layer: nn.Module, + **learned_round_impl_kwargs, +) -> torch.Tensor: + value = torch.zeros_like(layer.weight.data) + return value + + +LEARNED_ROUND_VALUE_INIT_MAP = { + LearnedRoundImplType.HARD_SIGMOID.value: learned_round_value_init_non_linear, + LearnedRoundImplType.SIGMOID.value: learned_round_value_init_non_linear, + LearnedRoundImplType.IDENTITY.value: learned_round_value_init_linear,} + + class LearnedRound(ABC): + def __init__( + self, + learned_round_impl_type: LearnedRoundImplType = LearnedRoundImplType.HARD_SIGMOID, + learned_round_value_init_fn: Optional[Callable] = None, + **learned_round_impl_kwargs, + ) -> None: + self.learned_round_impl_type = learned_round_impl_type + self.learned_round_value_init_fn = learned_round_value_init_fn + self.learned_round_impl_kwargs = learned_round_impl_kwargs + + def learned_round_value_init( + self, + layer: nn.Module, + ) -> torch.Tensor: + # A custom initialization function for the learned round parameter can be passed + if self.learned_round_value_init_fn is not None: + return self.learned_round_value_init_fn(layer, **self.learned_round_impl_kwargs) + # If not provided, the default function, as defined in LEARNED_ROUND_VALUE_INIT_MAP + # is leveraged + return LEARNED_ROUND_VALUE_INIT_MAP[self.learned_round_impl_type.value]( + layer, **self.learned_round_impl_kwargs) + + def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: + value = self.learned_round_value_init(layer) + layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( + float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND, + learned_round_impl_type=self.learned_round_impl_type, + learned_round_init=value, + **self.learned_round_impl_kwargs, + ) + layer.weight_quant.init_tensor_quant(preserve_state_dict=True) + def insert_learned_round_quantizers(self, model: nn.Module) -> None: for module in model.modules(): if isinstance(module, QuantWBIOL) and len( self.return_learned_round_quantizers(module)) == 0: self._insert_learned_round_quantizer_to_layer(module) - module.weight_quant.init_tensor_quant(preserve_state_dict=True) - - @abstractmethod - def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: - pass def return_learned_round_quantizers(self, block: nn.Module) -> List[nn.Module]: return [module for module in block.modules() if isinstance(module, LearnedRoundSte)] @@ -80,9 +134,9 @@ def __init__( warmup: float = 0.2, decay_start: float = 0.0, **kwargs) -> None: - # AdaRound operates in a layer-wise manner, so integrity needs to be checked - assert isinstance(module, QuantWBIOL), "AdaRound can only accept a single QuantWBIOL layer." - assert len(learned_round_modules) == 1, "AdaRound can only accept a single learned round module." + # This loss operates in a layer-wise manner, so integrity needs to be checked + assert isinstance(module, QuantWBIOL), "Regularised MSE loss can only accept a single QuantWBIOL layer." + assert len(learned_round_modules) == 1, "Regularised MSE loss can only accept a single learned round module." self.weight = weight self.module = module @@ -119,33 +173,6 @@ def format_loss_components(self, loss: float, rec_loss: float, round_loss: float b) -class AdaRound(LearnedRound): - - def __init__( - self, - learned_round_zeta: float = 1.1, - learned_round_gamma: float = -0.1, - learned_round_impl_type: LearnedRoundImplType = LearnedRoundImplType.HARD_SIGMOID, - **kwargs, - ) -> None: - # Quantiser-related configuration - self.learned_round_zeta = learned_round_zeta - self.learned_round_gamma = learned_round_gamma - self.learned_round_impl_type = learned_round_impl_type - - def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: - floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale) - delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight - value = -torch.log((self.learned_round_zeta - self.learned_round_gamma) / - (delta - self.learned_round_gamma) - 1) - layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( - float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND, - learned_round_impl_type=self.learned_round_impl_type, - learned_round_gamma=self.learned_round_gamma, - learned_round_zeta=self.learned_round_zeta, - learned_round_init=value) - - class MSELoss(LearnedRoundLoss): def __init__(self, block: nn.Module, learned_round_modules: List[nn.Module], **kwargs) -> None: @@ -157,17 +184,3 @@ def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, def format_loss_components(self, loss: float) -> str: return "Loss = {:.4f}".format(loss) - - -class AutoRound(LearnedRound): - - def __init__(self, **kwargs) -> None: - pass - - def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: - value = torch.zeros_like(layer.weight.data) - layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( - float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND, - learned_round_impl_type=LearnedRoundImplType.IDENTITY, - learned_round_init=value, - ) diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 8d83d6510..aa943f473 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -38,11 +38,11 @@ from torch.utils.data.dataloader import DataLoader from brevitas import config +from brevitas.inject.enum import LearnedRoundImplType from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.optim.sign_sgd import SignSGD from brevitas.quant_tensor import QuantTensor -from brevitas_examples.common.learned_round.learned_round_method import AdaRound -from brevitas_examples.common.learned_round.learned_round_method import AutoRound +from brevitas_examples.common.learned_round.learned_round_method import LearnedRound from brevitas_examples.common.learned_round.learned_round_method import MSELoss from brevitas_examples.common.learned_round.learned_round_method import RegularisedMSELoss from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer @@ -59,8 +59,9 @@ def is_layer(module: nn.Module, module_name: str) -> bool: LEARNED_ROUND_MAP = { - "auto_round": AutoRound, - "ada_round": AdaRound,} + "linear_round": LearnedRoundImplType.IDENTITY, + "hard_sigmoid_round": LearnedRoundImplType.HARD_SIGMOID, + "sigmoid_round": LearnedRoundImplType.SIGMOID,} LEARNED_ROUND_LOSS_MAP = { "mse": MSELoss, "regularised_mse": RegularisedMSELoss,} @@ -152,7 +153,7 @@ def apply_learned_round( model: nn.Module, calibration_loader: DataLoader, iters: int = 1000, - learned_round: str = "ada_round", + learned_round: str = "hard_sigmoid_round", learned_round_loss: str = "regularised_mse", optimizer: str = "adam", lr_scheduler: Optional[str] = None, @@ -169,7 +170,7 @@ def apply_learned_round( ) -> None: if learned_round not in LEARNED_ROUND_MAP: raise ValueError(f"Learned round method {learned_round} is not available.") - learned_round = LEARNED_ROUND_MAP[learned_round]() + learned_round = LearnedRound(learned_round_impl_type=LEARNED_ROUND_MAP[learned_round]) if learned_round_loss not in LEARNED_ROUND_LOSS_MAP: raise ValueError(f"Learned round loss {learned_round_loss} is not available.") @@ -221,4 +222,4 @@ def apply_learned_round( cache=cache, block_check_fn=block_check_fn, keep_gpu=True, - ) \ No newline at end of file + ) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 65b9e07c6..348213bdb 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -162,7 +162,7 @@ def validate_args(args): '--learned-round', default=None, type=str, - choices=[None, 'ada_round', 'auto_round'], + choices=[None, 'linear_round', 'hard_sigmoid_round', 'sigmoid_round'], help='Learned round type (default: None)') parser.add_argument( '--learned-round-loss', @@ -430,7 +430,7 @@ def main(): equalize_merge_bias=args.graph_eq_merge_bias, merge_bn=not args.calibrate_bn) elif args.target_backend == 'fx' or args.target_backend == 'layerwise': - if args.learned_round != "auto_round": + if args.learned_round_mode != "blockwise": model = preprocess_for_quantize( model, equalize_iters=args.graph_eq_iterations, diff --git a/src/brevitas_examples/llm/benchmark/llm_benchmark.py b/src/brevitas_examples/llm/benchmark/llm_benchmark.py index dec0e81c3..c21036be5 100644 --- a/src/brevitas_examples/llm/benchmark/llm_benchmark.py +++ b/src/brevitas_examples/llm/benchmark/llm_benchmark.py @@ -112,7 +112,8 @@ def unique(sequence): 'export_prefix': [None], # Path prefix to use for the various export flows. 'checkpoint_name': [None], # Filename to save checkpoint. 'fuse_sequences': [False], # Whether to merge the dataset sequences. - 'learned_round': [None, "auto_round"], # Whether to use learned round. If `None`, RTN is used. + 'learned_round': [None, + "linear_round"], # Whether to use learned round. If `None`, RTN is used. } parser = argparse.ArgumentParser(description='PyTorch LLM PTQ Validation') diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index bf2e565cc..f402ce211 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -14,15 +14,15 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.models.opt.modeling_opt import OPTDecoderLayer +from brevitas.inject.enum import LearnedRoundImplType from brevitas.optim.sign_sgd import SignSGD -from brevitas_examples.common.learned_round.learned_round_method import AutoRound from brevitas_examples.common.learned_round.learned_round_method import LearnedRound from brevitas_examples.common.learned_round.learned_round_method import LearnedRoundLoss from brevitas_examples.common.learned_round.learned_round_method import MSELoss from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer LEARNED_ROUND_MAP = { - "auto_round": AutoRound,} + "linear_round": LearnedRoundImplType.IDENTITY,} LEARNED_ROUND_LOSS_MAP = { "mse": MSELoss,} OPTIMIZER_MAP = { @@ -161,7 +161,7 @@ def apply_learned_round( model: nn.Module, calibration_loader: DataLoader, iters: int = 200, - learned_round: str = "auto_round", + learned_round: str = "linear_round", learned_round_loss: str = "mse", optimizer: str = "sign_sgd", lr_scheduler: Optional[str] = "linear", @@ -178,7 +178,7 @@ def apply_learned_round( ) -> None: if learned_round not in LEARNED_ROUND_MAP: raise ValueError(f"Learned round method {learned_round} is not available.") - learned_round = LEARNED_ROUND_MAP[learned_round]() + learned_round = LearnedRound(learned_round_impl_type=LEARNED_ROUND_MAP[learned_round]) if learned_round_loss not in LEARNED_ROUND_LOSS_MAP: raise ValueError(f"Learned round loss {learned_round_loss} is not available.") diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index bb5a6cd8b..b0c36185f 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -684,7 +684,7 @@ def parse_args(args): parser.add_argument( '--learned-round', default=None, - choices=[None, 'auto_round'], + choices=[None, 'linear_round'], help='Whether to use learned round. If `None`, RTN is used (default: %(default)s)') return parser.parse_args(args) diff --git a/tests/brevitas/core/test_float_to_int.py b/tests/brevitas/core/test_float_to_int.py index 41b74b4d6..2701e78a3 100644 --- a/tests/brevitas/core/test_float_to_int.py +++ b/tests/brevitas/core/test_float_to_int.py @@ -22,16 +22,14 @@ LearnedRoundSigmoid(), # Sigmoid Implementation LearnedRoundSigmoid(learned_round_temperature=2.), # Sigmoid + Temperature LearnedRoundHardSigmoid(), # Hard Sigmoid - LearnedRoundIdentity(), # AutoRound Implement -] + LearnedRoundIdentity(),] class TestLearnedRound(): def instantiate_learnedround_float_to_int_impl(self, impl, weights, value): impl = LearnedRoundSte(impl, torch.full(weights.shape, 0.)) - # For methods with p_value=False, it is required that value is within [-0.5, 0.5] - if not impl.learned_round_impl.is_p_value: + if isinstance(impl.learned_round_impl, LearnedRoundIdentity): min_value, max_value = torch.min(value), torch.max(value) # Prevent division by zero when all the elements of the tensor are the same if max_value - min_value < 1e-8: @@ -61,8 +59,7 @@ def test_learnedround(self, impl, training, weights_value): out = impl(weights) # The FP values and its quantized values must differ by at most +/- 1 assert torch.all(torch.abs(out - weights) <= 1) - # For is_p_value=True, the rounding can be soft while training=True - if impl.learned_round_impl.is_p_value: + if not isinstance(impl.learned_round_impl, LearnedRoundIdentity): if training: # Soft quantization. All values are at most distant +/- 1 from the nearest integer assert torch.all(torch.abs(out - torch.round(out)) <= 1) @@ -70,7 +67,7 @@ def test_learnedround(self, impl, training, weights_value): # Hard quantization. All values are integers assert torch.allclose(out, torch.round(out)) else: - # All values should be integers when is_p_value=False + # All values should be integers for LearnedRoundIdentity assert torch.allclose(out, torch.round(out)) @given( @@ -87,8 +84,10 @@ def test_learnedround_float_to_int_impl_hard_sigmoid( learned_round_zeta=learned_round_zeta, learned_round_gamma=learned_round_gamma, ) - value_eval = learned_round_hard_sigmoid(value, training=False) - value_train = learned_round_hard_sigmoid(value, training=True) + learned_round_hard_sigmoid.train(False) + value_eval = learned_round_hard_sigmoid(value) + learned_round_hard_sigmoid.train(True) + value_train = learned_round_hard_sigmoid(value) out_eval = weight + value_eval out_train = weight + (value_train > 0.5) @@ -109,7 +108,7 @@ def learnedround_float_to_int_impl(self, impl): def test_learnedround_load_dict(self, learnedround_float_to_int_impl): config.IGNORE_MISSING_KEYS = True - impl, _ = learnedround_float_to_int_impl + impl, _, _ = learnedround_float_to_int_impl quant_conv = qnn.QuantConv2d(IN_CH, OUT_CH, KERNEL_SIZE, weight_float_to_int_impl=impl) fp_conv = torch.nn.Conv2d(IN_CH, OUT_CH, KERNEL_SIZE) try: diff --git a/tests/brevitas_examples/test_learned_round_utils.py b/tests/brevitas_examples/test_learned_round_utils.py index e9501668e..751d7790d 100644 --- a/tests/brevitas_examples/test_learned_round_utils.py +++ b/tests/brevitas_examples/test_learned_round_utils.py @@ -16,11 +16,11 @@ from brevitas import config from brevitas.core.function_wrapper.learned_round import LearnedRoundSte +from brevitas.inject.enum import LearnedRoundImplType import brevitas.nn as qnn from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.quant_tensor.base_quant_tensor import QuantTensor -from brevitas_examples.common.learned_round.learned_round_method import AdaRound -from brevitas_examples.common.learned_round.learned_round_method import AutoRound +from brevitas_examples.common.learned_round.learned_round_method import LearnedRound from brevitas_examples.common.learned_round.learned_round_optimizer import get_blocks from brevitas_examples.common.learned_round.learned_round_optimizer import save_inputs_output @@ -302,7 +302,11 @@ def model_forward(model, inputs): for cache_output, gt_output in zip(cache.output, fp_outs if disable_quant else quant_outs): _compare_tensors(cache_output, gt_output, disable_quant, keep_gpu) - @pytest.mark.parametrize("learned_round", [AutoRound(), AdaRound()]) + @pytest.mark.parametrize( + "learned_round", + [ + LearnedRound(learned_round_impl_type=LearnedRoundImplType.IDENTITY), + LearnedRound(learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID)]) def test_insert_learned_round_quantizers(self, quant_model, learned_round): block = quant_model.in_proj_mlp learned_round.insert_learned_round_quantizers(block) @@ -313,7 +317,11 @@ def test_insert_learned_round_quantizers(self, quant_model, learned_round): assert isinstance( module.weight_quant.tensor_quant.int_quant.float_to_int_impl, LearnedRoundSte) - @pytest.mark.parametrize("learned_round", [AutoRound(), AdaRound()]) + @pytest.mark.parametrize( + "learned_round", + [ + LearnedRound(learned_round_impl_type=LearnedRoundImplType.IDENTITY), + LearnedRound(learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID)]) @pytest.mark.parametrize( "block_strs, num_round_modules", [([], 0), (["hidden_mlp"], 2), (["in_proj_mlp", "out_proj_mlp"], 4)]) From 2d1d9dffd94e8a417308a8d6434b7fbff60e0883 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Wed, 27 Nov 2024 12:06:19 +0000 Subject: [PATCH 24/48] Improve argument parsing learned round --- src/brevitas/graph/gpfq.py | 2 +- src/brevitas/graph/gptq.py | 2 +- src/brevitas/graph/gpxq.py | 4 - src/brevitas/utils/torch_utils.py | 5 + .../learned_round/learned_round_method.py | 5 - .../learned_round/learned_round_optimizer.py | 16 +--- .../learned_round/learned_round_parser.py | 96 +++++++++++++++++++ .../ptq/learned_round_utils.py | 52 +++------- src/brevitas_examples/llm/llm_quant/gpxq.py | 2 +- .../llm/llm_quant/learned_round_utils.py | 42 ++------ 10 files changed, 128 insertions(+), 98 deletions(-) create mode 100644 src/brevitas_examples/common/learned_round/learned_round_parser.py diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 92e3da2bf..b16faaf6b 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -24,11 +24,11 @@ from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.graph.gpxq import GPxQ from brevitas.graph.gpxq import gpxq_mode -from brevitas.graph.gpxq import StopFwdException from brevitas.graph.gpxq import SUPPORTED_CONV_OP from brevitas.graph.gpxq import SUPPORTED_TCONV_OP import brevitas.nn as qnn from brevitas.quant_tensor import _unpack_quant_tensor +from brevitas.utils.torch_utils import StopFwdException class GPFQ(GPxQ): diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 667e47d40..6b1945dcc 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -18,9 +18,9 @@ from brevitas import torch_version from brevitas.graph.gpxq import GPxQ from brevitas.graph.gpxq import gpxq_mode -from brevitas.graph.gpxq import StopFwdException from brevitas.graph.gpxq import SUPPORTED_CONV_OP import brevitas.nn as qnn +from brevitas.utils.torch_utils import StopFwdException class GPTQ(GPxQ): diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index e71a273c3..2af0d7f2d 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -28,10 +28,6 @@ SUPPORTED_CONV_OP = (qnn.QuantConv1d, qnn.QuantConv2d, qnn.QuantConv3d, *SUPPORTED_TCONV_OP) -class StopFwdException(Exception): - pass - - @dataclass class LayerHandler: layer_names: Set = field(default_factory=set) diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 8942c513a..461c7bc92 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -11,6 +11,11 @@ from brevitas.function.ops_ste import floor_ste +class StopFwdException(Exception): + """Used to throw and catch an exception to stop traversing the graph.""" + pass + + class TupleSequential(Sequential): def output(self, mod, input): diff --git a/src/brevitas_examples/common/learned_round/learned_round_method.py b/src/brevitas_examples/common/learned_round/learned_round_method.py index ae6ab8392..6fb929136 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_method.py +++ b/src/brevitas_examples/common/learned_round/learned_round_method.py @@ -15,11 +15,6 @@ from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL -class StopFwdException(Exception): - """Used to throw and catch an exception to stop traversing the graph.""" - pass - - class LearnedRoundLoss(ABC): @abstractmethod diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index 0f2939acc..c8220d073 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -213,6 +213,7 @@ from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.optim.sign_sgd import SignSGD from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase +from brevitas.utils.torch_utils import StopFwdException from brevitas_examples.common.accelerate_utils.accelerate import offload_model from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks from brevitas_examples.common.learned_round.learned_round_method import LearnedRound @@ -258,11 +259,6 @@ def _get_scale_parameters(module: nn.Module): return scale_parameters -class StopFwdException(Exception): - """Used to throw and catch an exception to stop traversing the graph.""" - pass - - class Cache(ABC): @abstractmethod @@ -636,16 +632,8 @@ def _accelerate_optimize_learned_round_block( # Prepare dataset from cache cache_dataset = cache.cache_to_dataset() - # NOTE: Intuitively, the total samples retrieved during optimization should - # be self.batch_size*self.iters. However, a StopIteration is raised mid-training - # signaling that this is not correct. Should check why this is the case. - random_sampler = RandomSampler( - cache_dataset, replacement=True, num_samples=2 * self.batch_size * self.iters) cache_dataloader = DataLoader( - cache_dataset, - batch_size=self.batch_size, - sampler=random_sampler, - collate_fn=cache.collate_fn) + cache_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=cache.collate_fn) # Prepare elements for training cache_dataloader, block, optimizer, lr_scheduler = accelerator.prepare(cache_dataloader, block, optimizer, lr_scheduler) diff --git a/src/brevitas_examples/common/learned_round/learned_round_parser.py b/src/brevitas_examples/common/learned_round/learned_round_parser.py new file mode 100644 index 000000000..dfc3d7a13 --- /dev/null +++ b/src/brevitas_examples/common/learned_round/learned_round_parser.py @@ -0,0 +1,96 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import re +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +import warnings + +from accelerate.utils.operations import send_to_device +from datasets import Dataset +import torch +from torch import nn +from torch.optim.lr_scheduler import LRScheduler +from torch.optim.optimizer import Optimizer +from torch.utils.data.dataloader import DataLoader + +from brevitas import config +from brevitas.inject.enum import LearnedRoundImplType +from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL +from brevitas.optim.sign_sgd import SignSGD +from brevitas.quant_tensor import QuantTensor +from brevitas_examples.common.learned_round.learned_round_method import LearnedRound +from brevitas_examples.common.learned_round.learned_round_method import LearnedRoundLoss +from brevitas_examples.common.learned_round.learned_round_method import MSELoss +from brevitas_examples.common.learned_round.learned_round_method import RegularisedMSELoss +from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer + +LEARNED_ROUND_MAP = { + "linear_round": LearnedRoundImplType.IDENTITY, + "hard_sigmoid_round": LearnedRoundImplType.HARD_SIGMOID, + "sigmoid_round": LearnedRoundImplType.SIGMOID,} +LEARNED_ROUND_LOSS_MAP = { + "mse": MSELoss, + "regularised_mse": RegularisedMSELoss,} +OPTIMIZER_MAP = { + "sign_sgd": SignSGD,} +LR_SCHEDULER_MAP = {} + + +def parse_learned_round(learned_round_str: str) -> LearnedRound: + if learned_round_str not in LEARNED_ROUND_MAP: + raise ValueError(f"Learned round method {learned_round_str} is not available.") + return LearnedRound(learned_round_impl_type=LEARNED_ROUND_MAP[learned_round_str]) + + +def parse_learned_round_loss_class(learned_round_loss_str: str) -> Type[LearnedRoundLoss]: + if learned_round_loss_str not in LEARNED_ROUND_LOSS_MAP: + raise ValueError(f"Learned round loss {learned_round_loss_str} is not available.") + return LEARNED_ROUND_LOSS_MAP[learned_round_loss_str] + + +def parse_optimizer_class(optimizer_str: str) -> Type[Optimizer]: + if optimizer_str in OPTIMIZER_MAP: + optimizer_class = OPTIMIZER_MAP[optimizer_str] + else: + optimizer_keys = [ + optimizer_key for optimizer_key in torch.optim.__dict__.keys() if ( + optimizer_key.lower().startswith(optimizer_str.lower()) and + torch.optim.__dict__[optimizer_key] != Optimizer and + isinstance(torch.optim.__dict__[optimizer_key], type) and + issubclass(torch.optim.__dict__[optimizer_key], Optimizer))] + if len(optimizer_keys) == 0: + raise ValueError(f"{optimizer_str} is not a valid optimizer.") + else: + if len(optimizer_keys) > 1: + warnings.warn( + f"There are multiple potential matches for optimizer {optimizer_str}. " + f"Defaulting to {optimizer_keys[0]}") + optimizer_class = getattr(torch.optim, optimizer_keys[0]) + + return optimizer_class + + +def parse_lr_scheduler_class(lr_scheduler_str: str) -> Type[LRScheduler]: + if lr_scheduler_str in LR_SCHEDULER_MAP: + lr_scheduler_class = LR_SCHEDULER_MAP[lr_scheduler_str] + else: + lr_scheduler_keys = [ + lr_scheduler_key for lr_scheduler_key in torch.optim.lr_scheduler.__dict__.keys() if ( + lr_scheduler_key.lower().startswith(lr_scheduler_str.lower()) and + torch.optim.lr_scheduler.__dict__[lr_scheduler_key] != LRScheduler and + isinstance(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], type) and + issubclass(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], LRScheduler))] + print(lr_scheduler_keys) + if len(lr_scheduler_keys) == 0: + warnings.warn( + f"There are no matches for LR scheduler {lr_scheduler_str}. " + f"No LR scheduler is going to be used.") + lr_scheduler_class = None + else: + if len(lr_scheduler_keys) > 1: + warnings.warn( + f"There are multiple potential matches for LR scheduler {lr_scheduler_str}." + f"Defaulting to {lr_scheduler_keys[0]}") + lr_scheduler_class = getattr(torch.optim.lr_scheduler, lr_scheduler_keys[0]) + + return lr_scheduler_class diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index aa943f473..499949b6e 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -34,18 +34,17 @@ from datasets import Dataset import torch from torch import nn -from torch.optim.lr_scheduler import LinearLR from torch.utils.data.dataloader import DataLoader from brevitas import config -from brevitas.inject.enum import LearnedRoundImplType from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL -from brevitas.optim.sign_sgd import SignSGD from brevitas.quant_tensor import QuantTensor -from brevitas_examples.common.learned_round.learned_round_method import LearnedRound -from brevitas_examples.common.learned_round.learned_round_method import MSELoss -from brevitas_examples.common.learned_round.learned_round_method import RegularisedMSELoss from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer +from brevitas_examples.common.learned_round.learned_round_parser import parse_learned_round +from brevitas_examples.common.learned_round.learned_round_parser import \ + parse_learned_round_loss_class +from brevitas_examples.common.learned_round.learned_round_parser import parse_lr_scheduler_class +from brevitas_examples.common.learned_round.learned_round_parser import parse_optimizer_class config.IGNORE_MISSING_KEYS = True @@ -58,24 +57,12 @@ def is_layer(module: nn.Module, module_name: str) -> bool: return isinstance(module, QuantWBIOL) -LEARNED_ROUND_MAP = { - "linear_round": LearnedRoundImplType.IDENTITY, - "hard_sigmoid_round": LearnedRoundImplType.HARD_SIGMOID, - "sigmoid_round": LearnedRoundImplType.SIGMOID,} -LEARNED_ROUND_LOSS_MAP = { - "mse": MSELoss, - "regularised_mse": RegularisedMSELoss,} -OPTIMIZER_MAP = { - "adam": torch.optim.Adam, - "sign_sgd": SignSGD,} BLOCK_CHECK_MAP = { "layerwise": is_layer, "blockwise": is_resnet_block,} -LR_SCHEDULER_MAP = { - "linear": LinearLR,} -class CacheCNN(dict): +class CacheVision(dict): def __init__(self) -> None: super().__init__() @@ -168,26 +155,11 @@ def apply_learned_round( lr_scheduler_kwargs: Optional[Dict] = None, learned_round_mode: str = "layerwise", ) -> None: - if learned_round not in LEARNED_ROUND_MAP: - raise ValueError(f"Learned round method {learned_round} is not available.") - learned_round = LearnedRound(learned_round_impl_type=LEARNED_ROUND_MAP[learned_round]) - - if learned_round_loss not in LEARNED_ROUND_LOSS_MAP: - raise ValueError(f"Learned round loss {learned_round_loss} is not available.") - learned_round_loss_class = LEARNED_ROUND_LOSS_MAP[learned_round_loss] - - if optimizer not in OPTIMIZER_MAP: - raise ValueError(f"Optimizer {optimizer} is not available.") - optimizer_class = OPTIMIZER_MAP[optimizer] - - if lr_scheduler is not None and lr_scheduler not in LR_SCHEDULER_MAP: - raise ValueError(f"Learning rate scheduler {lr_scheduler} is not available.") - lr_scheduler_class = None if lr_scheduler is None else LR_SCHEDULER_MAP[lr_scheduler] - - optimizer_classes = {"adam": torch.optim.Adam, "sign_sgd": SignSGD} - if optimizer not in optimizer_classes: - raise ValueError(f"{optimizer} is not a valid optimizer.") - optimizer_class = optimizer_classes[optimizer] + # Parse strings to obtain the arguments for the optimizer + learned_round = parse_learned_round(learned_round) + learned_round_loss_class = parse_learned_round_loss_class(learned_round_loss) + optimizer_class = parse_optimizer_class(optimizer) + lr_scheduler_class = parse_lr_scheduler_class(lr_scheduler) if learned_round_mode not in BLOCK_CHECK_MAP: learned_round_mode = "layerwise" @@ -213,7 +185,7 @@ def apply_learned_round( learned_round_loss_kwargs=learned_round_loss_kwargs, optimizer_kwargs=optimizer_kwargs, lr_scheduler_kwargs=lr_scheduler_kwargs) - cache = CacheCNN() + cache = CacheVision() learned_round_optimizer.apply_learned_round( model=model, model_forward=cnn_forward, diff --git a/src/brevitas_examples/llm/llm_quant/gpxq.py b/src/brevitas_examples/llm/llm_quant/gpxq.py index 5e61306d4..26e984cfd 100644 --- a/src/brevitas_examples/llm/llm_quant/gpxq.py +++ b/src/brevitas_examples/llm/llm_quant/gpxq.py @@ -15,8 +15,8 @@ from brevitas.graph.gpfq import GPFQv2 from brevitas.graph.gptq import GPTQ from brevitas.graph.gptq import gptq_mode -from brevitas.graph.gpxq import StopFwdException from brevitas.utils.python_utils import recurse_getattr +from brevitas.utils.torch_utils import StopFwdException from brevitas_examples.common.axe import A2GPFQ from brevitas_examples.common.axe import A2GPTQ diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index f402ce211..099b62a2e 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -7,28 +7,16 @@ from datasets import Dataset import torch from torch import nn -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import LRScheduler -from torch.optim.optimizer import Optimizer from torch.utils.data.dataloader import DataLoader from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.models.opt.modeling_opt import OPTDecoderLayer -from brevitas.inject.enum import LearnedRoundImplType -from brevitas.optim.sign_sgd import SignSGD -from brevitas_examples.common.learned_round.learned_round_method import LearnedRound -from brevitas_examples.common.learned_round.learned_round_method import LearnedRoundLoss -from brevitas_examples.common.learned_round.learned_round_method import MSELoss from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer - -LEARNED_ROUND_MAP = { - "linear_round": LearnedRoundImplType.IDENTITY,} -LEARNED_ROUND_LOSS_MAP = { - "mse": MSELoss,} -OPTIMIZER_MAP = { - "sign_sgd": SignSGD,} -LR_SCHEDULER_MAP = { - "linear": LinearLR,} +from brevitas_examples.common.learned_round.learned_round_parser import parse_learned_round +from brevitas_examples.common.learned_round.learned_round_parser import \ + parse_learned_round_loss_class +from brevitas_examples.common.learned_round.learned_round_parser import parse_lr_scheduler_class +from brevitas_examples.common.learned_round.learned_round_parser import parse_optimizer_class class CacheLLM(dict): @@ -176,21 +164,11 @@ def apply_learned_round( lr_scheduler_kwargs: Optional[Dict] = None, learned_round_loss_kwargs: Optional[Dict] = None, ) -> None: - if learned_round not in LEARNED_ROUND_MAP: - raise ValueError(f"Learned round method {learned_round} is not available.") - learned_round = LearnedRound(learned_round_impl_type=LEARNED_ROUND_MAP[learned_round]) - - if learned_round_loss not in LEARNED_ROUND_LOSS_MAP: - raise ValueError(f"Learned round loss {learned_round_loss} is not available.") - learned_round_loss_class = LEARNED_ROUND_LOSS_MAP[learned_round_loss] - - if optimizer not in OPTIMIZER_MAP: - raise ValueError(f"Optimizer {optimizer} is not available.") - optimizer_class = OPTIMIZER_MAP[optimizer] - - if lr_scheduler is not None and lr_scheduler not in LR_SCHEDULER_MAP: - raise ValueError(f"Learning rate scheduler {lr_scheduler} is not available.") - lr_scheduler_class = None if lr_scheduler is None else LR_SCHEDULER_MAP[lr_scheduler] + # Parse strings to obtain the arguments for the optimizer + learned_round = parse_learned_round(learned_round) + learned_round_loss_class = parse_learned_round_loss_class(learned_round_loss) + optimizer_class = parse_optimizer_class(optimizer) + lr_scheduler_class = parse_lr_scheduler_class(lr_scheduler) lr_scheduler_kwargs = { "start_factor": 1.0, From 7110e8b1c2f5cd75eb59318442f54c24b83687fc Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Wed, 27 Nov 2024 14:57:38 +0000 Subject: [PATCH 25/48] Adress comments --- .../learned_round/learned_round_optimizer.py | 12 +- .../learned_round/learned_round_parser.py | 1 - .../ptq/learned_round_utils.py | 14 +- .../llm/benchmark/llm_benchmark.py | 175 ------------------ .../llm/benchmark/post_processing.py | 32 ---- .../llm/llm_quant/learned_round_utils.py | 28 ++- src/brevitas_examples/llm/main.py | 1 + 7 files changed, 22 insertions(+), 241 deletions(-) delete mode 100644 src/brevitas_examples/llm/benchmark/llm_benchmark.py delete mode 100644 src/brevitas_examples/llm/benchmark/post_processing.py diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index c8220d073..ea068474e 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -203,7 +203,6 @@ from torch.optim.optimizer import Optimizer from torch.optim.sgd import SGD from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataloader import RandomSampler from tqdm import tqdm from brevitas import config @@ -285,10 +284,6 @@ def initialize_cache(self) -> None: def clear_cache(self) -> None: pass - @abstractmethod - def reset_cache(self) -> None: - pass - @abstractmethod def cache_to_dataset(self) -> Dataset: pass @@ -699,7 +694,7 @@ def apply_learned_round( block_forward: Callable, data_loader: DataLoader, cache: Cache, - block_check_fn: Callable, + get_blocks_fn: Callable, model_prepare_fn: Optional[Callable] = None, model_finish_fn: Optional[Callable] = None, keep_gpu: bool = True) -> None: @@ -711,7 +706,7 @@ def apply_learned_round( self.learned_round.insert_learned_round_quantizers(model) # Retrieve blocks using the appropiate function to check blocks - blocks = get_blocks(model, block_check_fn) + blocks = get_blocks_fn(model) print(f"Total Iterations per block {self.iters}") print(f"Number of blocks {len(blocks)}") @@ -726,7 +721,6 @@ def apply_learned_round( model = offload_model(model) # Cache needs to be cleared before populating it with the inputs and outputs # to the block under optimization. - cache.clear_cache() self._populate_cache( cache, model, @@ -801,7 +795,7 @@ def apply_learned_round( # TODO: This call might not be needed, check_clear and reset_cache methods # Reset cache after optimisation - cache.reset_cache() + cache.clear_cache() # The original configuration of the model is restored after finishing the optimization if model_finish_fn is not None: diff --git a/src/brevitas_examples/common/learned_round/learned_round_parser.py b/src/brevitas_examples/common/learned_round/learned_round_parser.py index dfc3d7a13..c1e470331 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_parser.py +++ b/src/brevitas_examples/common/learned_round/learned_round_parser.py @@ -80,7 +80,6 @@ def parse_lr_scheduler_class(lr_scheduler_str: str) -> Type[LRScheduler]: torch.optim.lr_scheduler.__dict__[lr_scheduler_key] != LRScheduler and isinstance(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], type) and issubclass(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], LRScheduler))] - print(lr_scheduler_keys) if len(lr_scheduler_keys) == 0: warnings.warn( f"There are no matches for LR scheduler {lr_scheduler_str}. " diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 499949b6e..8994bbc25 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -26,6 +26,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import functools import re from typing import Any, Callable, Dict, Optional, Tuple, Union import warnings @@ -39,6 +40,8 @@ from brevitas import config from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.quant_tensor import QuantTensor +from brevitas_examples.common.learned_round.learned_round_optimizer import Cache +from brevitas_examples.common.learned_round.learned_round_optimizer import get_blocks from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer from brevitas_examples.common.learned_round.learned_round_parser import parse_learned_round from brevitas_examples.common.learned_round.learned_round_parser import \ @@ -62,7 +65,7 @@ def is_layer(module: nn.Module, module_name: str) -> bool: "blockwise": is_resnet_block,} -class CacheVision(dict): +class CacheVision(Cache, dict): def __init__(self) -> None: super().__init__() @@ -97,12 +100,6 @@ def clear_cache(self) -> None: self["inputs"] = [] self["output"] = [] - def reset_cache(self) -> None: - del self["inputs"] - del self["output"] - self["inputs"] = [] - self["output"] = [] - def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: if isinstance(self["inputs"], list): self["inputs"] = torch.cat(self["inputs"], dim=self.batch_dim) @@ -166,6 +163,7 @@ def apply_learned_round( warnings.warn( f"{learned_round_mode} is not a valid learned round mode. Defaulting to layerwise.") block_check_fn = BLOCK_CHECK_MAP[learned_round_mode] + get_blocks_fn = functools.partial(get_blocks, block_check_fn=block_check_fn) lr_scheduler_kwargs = { "start_factor": 1.0, "end_factor": 0.0, @@ -192,6 +190,6 @@ def apply_learned_round( block_forward=cnn_block_forward, data_loader=calibration_loader, cache=cache, - block_check_fn=block_check_fn, + get_blocks_fn=get_blocks_fn, keep_gpu=True, ) diff --git a/src/brevitas_examples/llm/benchmark/llm_benchmark.py b/src/brevitas_examples/llm/benchmark/llm_benchmark.py deleted file mode 100644 index c21036be5..000000000 --- a/src/brevitas_examples/llm/benchmark/llm_benchmark.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -import argparse -from functools import partial -from itertools import product -import os -from types import SimpleNamespace - -import pandas as pd -import torch.backends.cudnn as cudnn -import torch.nn.parallel -import torch.optim -import torch.utils.data -import torch.utils.data.distributed - -from brevitas import __version__ as brevitas_version -from brevitas import config -from brevitas import torch_version -from brevitas_examples.imagenet_classification.ptq.utils import get_gpu_index -# LLM example depends on optimum-amd, which requires PyTorch>=2.2 -from brevitas_examples.llm.main import main as main_llm -from brevitas_examples.llm.main import validate - -config.IGNORE_MISSING_KEYS = True - - -def parse_type(v, default_type): - if v == 'None': - return None - else: - return default_type(v) - - -def parse_bool(v): - if isinstance(v, bool): - return v - if v.lower() in ('yes', 'true', 't', 'y'): - return True - elif v.lower() in ('no', 'false', 'f', 'n'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -class hashabledict(dict): - - def __hash__(self): - return hash(tuple(sorted(self.items()))) - - -def unique(sequence): - seen = set() - return [x for x in sequence if not (x in seen or seen.add(x))] - - -LLM_PPL_MAP = { - 'facebook/opt-125m': None, - 'meta-llama/Llama-2-7b-hf': None,} - -OPTIONS_DEFAULT = { - 'model': list(LLM_PPL_MAP.keys()), # HF model name. Default: facebook/opt-125m. - 'seed': [0], # Seed for sampling the calibration data. Default: 0. - 'nsamples': [128], # Number of calibration data samples. Default: 128. - 'seqlen': [2048], # Sequence length. Default: 2048. - 'eval': [True], # Eval model PPL on the chosen Dataset. - 'dataset': ['wikitext2'], # Dataset to use for quantization (default: wikitext2) - 'gpxq_block_name': [None], # Block name for faster GPxQ optimization. Default: None - 'weight_bit_width': [8], # Weight bit width. Default: 8. - 'weight_param_method': ['stats'], # How scales/zero-point are determined. Default: stats. - 'weight_scale_precision': ['float_scale' - ], # Whether scale is a float value or a po2. Default: po2. - 'weight_quant_type': ['sym'], # Weight quantization type. Default: asym. - 'weight_quant_format': ['int'], # Weight quantization type. Default: int. - 'weight_quant_granularity': [ - 'per_group'], # Granularity for scales/zero-point of weights. Default: per_group. - 'scale_rounding_func_type': [None], # Rounding function to use with Po2 scale. Default: None. - 'weight_group_dim': [ - None], # Override default group_dim for groupsize quantization. Default: layer-dependant - 'weight_group_size': [128], # Group size for per_group weight quantization. Default: 128. - 'quantize_weight_zero_point': [False], # Quantize weight zero-point. - 'input_bit_width': [None], # Input bit width. Default: None (disables input quantization). - 'input_quant_format': ['int'], # Input quantization type. Default: int. - 'input_param_method': ['stats'], # How scales/zero-point are determined. Default: stats. - 'input_scale_precision': ['float_scale' - ], # Whether input scale is a float value or a po2. Default: float. - 'input_scale_type': ['static'], # Whether input scale is a static value or a dynamic value. - 'input_quant_type': ['asym'], # Input quantization type. Default: asym. - 'input_quant_granularity': [ - 'per_tensor'], # Granularity for scales/zero-point of inputs. Default: per_tensor. - 'input_group_size': [64], # Group size for per_group input quantization. Default: 64. - 'quantize_input_zero_point': [False], # Quantize input zero-point. - 'quantize_last_layer': [False], # Quantize last nn.Linear layer. - 'gptq': [False], # Apply GPTQ. - 'gpfq': [False], # Apply GPFQ. - 'gpxq_act_order': [False], # Apply GPxQ activation ordering. - 'gpxq_use_quant_activations': [False], # Use quantized activations in GPxQ. - 'gpxq_create_weight_orig': [False], # Create weight_orig in GPxQ. - 'gpxq_max_accumulator_bit_width': [None], # Maximum accumulator bit width for GPxQ using AXE. - 'gpxq_max_accumulator_tile_size': [None], # Maximum accumulator tile size for GPxQ using AXE. - 'act_calibration': [False], # Apply activation calibration. - 'bias_corr': [False], # Apply bias correction. - 'ln_affine_merge': [False], # Merge LN affine params. - 'no_quantize': [False], # Disable quantization. - 'no_float16': [False], # Disable float16 as base datatype and switch to float32. - 'replace_mha': [False], # Replace HuggingFace Attention with a quantizable version - 'weight_equalization': [ - False], # Apply weight equalization. Relevant to ReLU based models (e.g. OPT). - 'act_equalization': [None], # Apply activation equalization (SmoothQuant). - 'load_awq': [None], # Load the awq search results. - 'export_target': [None], # Model export. - 'export_prefix': [None], # Path prefix to use for the various export flows. - 'checkpoint_name': [None], # Filename to save checkpoint. - 'fuse_sequences': [False], # Whether to merge the dataset sequences. - 'learned_round': [None, - "linear_round"], # Whether to use learned round. If `None`, RTN is used. -} - -parser = argparse.ArgumentParser(description='PyTorch LLM PTQ Validation') -parser.add_argument('idx', type=int) -for option_name, option_value in OPTIONS_DEFAULT.items(): - if isinstance(option_value[0], bool): - type_args = parse_bool - else: - type_args = partial(parse_type, default_type=type(option_value[0])) - parser.add_argument(f'--{option_name}', default=option_value, nargs="+", type=type_args) - - -def main(): - args = parser.parse_args() - - # Generate all possible configurations, including invalid ones - options = {k: getattr(args, k) for k, _ in OPTIONS_DEFAULT.items()} - combinations = list(product(*options.values())) - configs = [] - for combination in combinations: - config_namespace = SimpleNamespace( - **{k: v for k, v in zip(OPTIONS_DEFAULT.keys(), combination)}) - try: - validate(config_namespace) - configs.append(hashabledict(**config_namespace.__dict__)) - except AssertionError: - # Invalid configuration - pass - - configs = unique(configs) - - if args.idx > len(configs) - 1: - return - - config_namespace = SimpleNamespace(**configs[args.idx]) - args.gpu = get_gpu_index(args.idx) - print("Iter {}, GPU {}".format(args.idx, args.gpu)) - - try: - float_ppl, quant_ppl, _ = main_llm(config_namespace) - - # Results are saved in CSV - column_names = [k.replace('_', ' ').capitalize() for k in config_namespace.__dict__.keys() - ] + [ - 'FP perplexity', 'Quant perplexity', 'Torch version', 'Brevitas version'] - values = [v for _, v in config_namespace.__dict__.items()] + [ - float_ppl, quant_ppl, torch_version, brevitas_version] - llm_df = pd.DataFrame([values], columns=column_names) - - folder = './multirun/' + str(args.idx) - os.makedirs(folder, exist_ok=True) - llm_df.to_csv(os.path.join(folder, 'RESULTS_LLM.csv'), index=False) - - except Exception as E: - print("Exception at index {}: {}".format(args.idx, E)) - - -if __name__ == '__main__': - main() diff --git a/src/brevitas_examples/llm/benchmark/post_processing.py b/src/brevitas_examples/llm/benchmark/post_processing.py deleted file mode 100644 index ab33b15dd..000000000 --- a/src/brevitas_examples/llm/benchmark/post_processing.py +++ /dev/null @@ -1,32 +0,0 @@ -import os - -import pandas as pd - - -def main(): - main_dir = './multirun' - - evals = next(os.walk(main_dir))[1] - df = None - for eval in evals: - full_path = os.path.join(main_dir, eval, 'RESULTS_LLM.csv') - if not os.path.exists(full_path): - continue - if df is None: - df = pd.read_csv(full_path) - else: - single_df = pd.read_csv(full_path) - df = pd.concat([df, single_df]) - df = df.sort_values(by=list(df.columns)) - df.to_csv('RESULTS_LLM.csv', index=False, mode='w') - - grouped_df = df.groupby([ - 'Model', 'Weight bit width', 'Weight quant granularity', 'Learned round']) - idx = grouped_df['Quant perplexity'].transform(max) == df['Quant perplexity'] - best_config_df = df[idx] - best_config_df = best_config_df.sort_values(by=['Model', 'Quant perplexity']) - best_config_df.to_csv('RESULTS_LLM_BEST_CONFIGS.csv', index=False, mode='w') - - -if __name__ == '__main__': - main() diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index 099b62a2e..6dad361ab 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -1,6 +1,7 @@ # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +import functools from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from accelerate.utils.operations import send_to_device @@ -11,6 +12,8 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.models.opt.modeling_opt import OPTDecoderLayer +from brevitas.utils.python_utils import recurse_getattr +from brevitas_examples.common.learned_round.learned_round_optimizer import Cache from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer from brevitas_examples.common.learned_round.learned_round_parser import parse_learned_round from brevitas_examples.common.learned_round.learned_round_parser import \ @@ -19,16 +22,14 @@ from brevitas_examples.common.learned_round.learned_round_parser import parse_optimizer_class -class CacheLLM(dict): +class CacheLLM(Cache, dict): def __init__(self) -> None: super().__init__() - self.store_kwargs = True def store_inputs(self, args, kwargs) -> None: self["args"].append(args) - if self.store_kwargs: - self["kwargs"].append(kwargs) + self["kwargs"].append(kwargs) def store_output(self, output) -> None: if isinstance(output, (tuple, list)): @@ -41,17 +42,9 @@ def initialize_cache(self) -> None: self["output"] = [] def clear_cache(self) -> None: - del self["args"] - del self["output"] - self["args"] = [] - self["output"] = [] - self.store_kwargs = len(self["kwargs"]) == 0 - - def reset_cache(self) -> None: del self["args"] del self["kwargs"] del self["output"] - self.store_kwargs = True self["args"] = [] self["kwargs"] = [] self["output"] = [] @@ -141,8 +134,8 @@ def llm_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor: return out -def llm_block_check_fn(module: nn.Module, module_name: str) -> bool: - return isinstance(module, LlamaDecoderLayer) or isinstance(module, OPTDecoderLayer) +def get_blocks(model: nn.Module, block_name_attribute: str) -> List[nn.Module]: + return recurse_getattr(model, block_name_attribute) def apply_learned_round( @@ -151,8 +144,8 @@ def apply_learned_round( iters: int = 200, learned_round: str = "linear_round", learned_round_loss: str = "mse", + block_name_attribute: str = "layers", optimizer: str = "sign_sgd", - lr_scheduler: Optional[str] = "linear", optimizer_lr: float = 5e-3, batch_size: int = 8, learn_scale: bool = False, @@ -160,6 +153,7 @@ def apply_learned_round( use_amp: bool = True, amp_dtype: torch.dtype = torch.float16, loss_scaling_factor: float = 1000, + lr_scheduler: Optional[str] = "linear", optimizer_kwargs: Optional[Dict] = None, lr_scheduler_kwargs: Optional[Dict] = None, learned_round_loss_kwargs: Optional[Dict] = None, @@ -170,6 +164,8 @@ def apply_learned_round( optimizer_class = parse_optimizer_class(optimizer) lr_scheduler_class = parse_lr_scheduler_class(lr_scheduler) + llm_block_check_fn = functools.partial(get_blocks, block_name_attribute=block_name_attribute) + lr_scheduler_kwargs = { "start_factor": 1.0, "end_factor": 0.0, @@ -197,7 +193,7 @@ def apply_learned_round( block_forward=llm_block_forward, data_loader=calibration_loader, cache=cache, - block_check_fn=llm_block_check_fn, + get_blocks_fn=llm_block_check_fn, model_prepare_fn=llm_learned_round_prepare_fn, model_finish_fn=llm_learned_round_finish_fn, keep_gpu=False, diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index b0c36185f..22eabed6c 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -375,6 +375,7 @@ def main(args): model, calibration_loader, iters=args.learned_round_iters, + block_name_attribute=args.gpxq_block_name, learn_scale=args.learned_round_scale, ) print("Learned round applied.") From 30cae345c21395383ebb3889c16ac3c2663078a6 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 28 Nov 2024 18:15:52 +0000 Subject: [PATCH 26/48] Minor changes --- .../learned_round/learned_round_optimizer.py | 125 +----------------- .../ptq/learned_round_utils.py | 8 +- .../llm/llm_quant/learned_round_utils.py | 2 + src/brevitas_examples/llm/main.py | 12 ++ 4 files changed, 22 insertions(+), 125 deletions(-) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index ea068474e..5db0bdd83 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -364,6 +364,7 @@ def __init__( scale_optimizer_class: Type[Optimizer] = SGD, lr_scheduler_class: Optional[Type[LRScheduler]] = LinearLR, optimizer_lr: float = 5e-3, + optimizer_scale_lr: float = 5e-3, batch_size: float = 8, iters: int = 200, learn_scale: bool = False, @@ -371,7 +372,6 @@ def __init__( use_amp: bool = True, amp_dtype: torch.dtype = torch.float16, loss_scaling_factor: float = 1000., - use_accelerate: bool = False, learned_round_loss_kwargs: Optional[Dict] = None, optimizer_kwargs: Optional[Dict] = None, lr_scheduler_kwargs: Optional[Dict] = None, @@ -381,6 +381,7 @@ def __init__( self.scale_optimizer_class = scale_optimizer_class self.lr_scheduler_class = lr_scheduler_class self.optimizer_lr = optimizer_lr + self.optimizer_scale_lr = optimizer_scale_lr self.batch_size = batch_size self.iters = iters self.learn_scale = learn_scale @@ -396,10 +397,6 @@ def __init__( self.learned_round_loss_init = partial( learned_round_loss_class, **learned_round_loss_kwargs) - # TODO: Remove once validated and expose the flag - # self.use_accelerate = use_accelerate - self.use_accelerate = False - @torch.no_grad() def _load_round_params(self, block: nn.Module, round_params: Dict) -> None: for n, m in block.named_modules(): @@ -508,8 +505,7 @@ def _optimize_learned_round_block( if self.learn_scale and scale_params is not None: optimizer_scale = self.scale_optimizer_class( scale_params, - lr=self.optimizer_lr, - momentum=0.9, + lr=self.optimizer_scale_lr, **self.optimizer_kwargs, ) lr_scheduler_scale = ( @@ -579,114 +575,6 @@ def _optimize_learned_round_block( return init_loss, best_loss, last_best_iter - # TODO: Enable saving best parameters - def _accelerate_optimize_learned_round_block( - self, - block: nn.Module, - block_learned_round_modules: List[nn.Module], - cache: Cache, - block_loss: LearnedRoundLoss, - block_forward: Callable, - ) -> Tuple[float, float, int]: - # Enable running in mixed precision - TORCH_DTYPE_TO_PRECISION_TYPE_MAP = { - torch.float16: PrecisionType.FP16, - torch.bfloat16: PrecisionType.BF16,} - raise_warning_dtype = False - if not self.use_amp: - mixed_precision_type = None - else: - if self.amp_dtype not in TORCH_DTYPE_TO_PRECISION_TYPE_MAP: - raise_warning_dtype = True - mixed_precision_type = None - else: - mixed_precision_type = TORCH_DTYPE_TO_PRECISION_TYPE_MAP[self.amp_dtype] - # Instantiate accelerator to run in a multi-GPU setting - accelerator = Accelerator(mixed_precision=mixed_precision_type) - - # Raise warning if the AMP dtype was defaulted to float32. This warning is raised after - # the instantiation of accelerator, to use its print functionality so the message is only - # printed once. - if raise_warning_dtype: - accelerator.print( - f"The dtype {self.amp_dtype} cannot be used for AMP training with accelerate. Defaulting to float32." - ) - - # Initilalize optimizer and LR scheduler - optimizer = self.optimizer_class( - itertools.chain( - *[ - block_learned_round_module.parameters() - for block_learned_round_module in block_learned_round_modules]), - lr=self.optimizer_lr, - **self.optimizer_kwargs, - ) - lr_scheduler = ( - self.lr_scheduler_class(optimizer, **self.lr_scheduler_kwargs) - if self.lr_scheduler_class else None) - - # Prepare dataset from cache - cache_dataset = cache.cache_to_dataset() - cache_dataloader = DataLoader( - cache_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=cache.collate_fn) - - # Prepare elements for training - cache_dataloader, block, optimizer, lr_scheduler = accelerator.prepare(cache_dataloader, block, optimizer, lr_scheduler) - - # Variables needed for printing - best_loss = torch.finfo(torch.float).max - init_loss = -1.0 - last_best_iter = self.iters - - # Initialize an iterator to extract elements from the cache dataloader - cache_iterator = iter(cache_dataloader) - - pbar = tqdm_accelerate(range(self.iters), desc='') - for i in pbar: - # Sample mini-batch from cache - inputs, fp_outs = next(cache_iterator) - - # Run block forward to obtain quant outputs - quant_outs = block_forward(block, inputs) - # Compute loss using the block loss function - loss, loss_components = block_loss(quant_outs, fp_outs) - - # Save best parameters before taking gradient step - curr_loss = loss.detach().cpu().item() - init_loss = curr_loss if i == 0 else init_loss - if loss < best_loss: - best_loss = curr_loss - last_best_iter = i + 1 - - # Scale loss and perform gradient step - # loss = loss * self.loss_scaling_factor - accelerator.backward(loss) - self._step(optimizer, lr_scheduler) - - # Update progress bar - pbar.set_description("{}".format(block_loss.format_loss_components(*loss_components))) - - # Make sure no updates are received in the progress bar - pbar.close() - - # TODO: Include support for saving the best configuration during training - if not self.use_best_model: - # Override if the model with the lowest training error is not used - best_loss = curr_loss - last_best_iter = self.iters - - # TODO: Verify if this call is actually needed - # Wait for everyone before proceding to next block - accelerator.wait_for_everyone() - # Remove all the wrapper around the block - block = accelerator.unwrap_model(block) - # Clear memory - accelerator.free_memory() - # Move the block back to CPU - block.cpu() - - return init_loss, best_loss, last_best_iter - def apply_learned_round( self, model: nn.Module, @@ -764,11 +652,7 @@ def apply_learned_round( ) # Optimize block rounding - init_loss, best_loss, last_best_iter = ( - self._optimize_learned_round_block - if not self.use_accelerate - else self._accelerate_optimize_learned_round_block - )( + init_loss, best_loss, last_best_iter = self._optimize_learned_round_block( block=block, block_learned_round_modules=block_learned_round_modules, cache=cache, @@ -793,7 +677,6 @@ def apply_learned_round( # Move the block back to CPU block.cpu() - # TODO: This call might not be needed, check_clear and reset_cache methods # Reset cache after optimisation cache.clear_cache() diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 8994bbc25..4c9b05ea8 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -120,14 +120,14 @@ def __len__(self): if isinstance(self["inputs"], list) else self["inputs"].shape[self.batch_dim]) -def cnn_forward(model: nn.Module, inputs: Any) -> None: +def vision_forward(model: nn.Module, inputs: Any) -> None: device = next(model.parameters()).device img, _ = inputs img = send_to_device(img, device) model(img) -def cnn_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor: +def vision_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor: device = next(block.parameters()).device inputs = send_to_device(inputs, device) return block(inputs) @@ -186,8 +186,8 @@ def apply_learned_round( cache = CacheVision() learned_round_optimizer.apply_learned_round( model=model, - model_forward=cnn_forward, - block_forward=cnn_block_forward, + model_forward=vision_forward, + block_forward=vision_block_forward, data_loader=calibration_loader, cache=cache, get_blocks_fn=get_blocks_fn, diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index 6dad361ab..ad6d43693 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -147,6 +147,7 @@ def apply_learned_round( block_name_attribute: str = "layers", optimizer: str = "sign_sgd", optimizer_lr: float = 5e-3, + optimizer_scale_lr: float = 5e-3, batch_size: int = 8, learn_scale: bool = False, use_best_model: bool = True, @@ -176,6 +177,7 @@ def apply_learned_round( optimizer_class=optimizer_class, lr_scheduler_class=lr_scheduler_class, optimizer_lr=optimizer_lr, + optimizer_scale_lr=optimizer_scale_lr, batch_size=batch_size, iters=iters, learn_scale=learn_scale, diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 22eabed6c..46d974d23 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -376,6 +376,8 @@ def main(args): calibration_loader, iters=args.learned_round_iters, block_name_attribute=args.gpxq_block_name, + optimizer_lr=args.learned_round_lr, + optimizer_scale_lr=args.learned_round_scale_lr, learn_scale=args.learned_round_scale, ) print("Learned round applied.") @@ -566,6 +568,16 @@ def parse_args(args): type=int, default=64, help='Group size for per_group input quantization. Default: 64.') + parser.add_argument( + '--learned-round-lr', + type=float, + default=5e-3, + help='Learning rate for learned round parameter optimization. Default: %(default)s') + parser.add_argument( + '--learned-round-scale-lr', + type=float, + default=5e-3, + help='Learning rate for scale optimization during round learning. Default: %(default)s') parser.add_argument( '--learned-round-iters', type=int, From 78254501ca5bb4dd3df4f8fd8ad90a20f2ee2f55 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 29 Nov 2024 09:31:55 +0000 Subject: [PATCH 27/48] Enable passing block name in vision entrypoint --- .../ptq/learned_round_utils.py | 18 ++++++++++++------ .../ptq/ptq_evaluate.py | 6 ++++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 4c9b05ea8..44abce8f4 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -52,8 +52,8 @@ config.IGNORE_MISSING_KEYS = True -def is_resnet_block(module: nn.Module, module_name: str) -> bool: - return (re.search(r"layer\d+", module_name) is not None) +def is_block(module: nn.Module, module_name: str, reg_exp: str = r"layer\d+") -> bool: + return (re.search(reg_exp, module_name) is not None) def is_layer(module: nn.Module, module_name: str) -> bool: @@ -62,7 +62,7 @@ def is_layer(module: nn.Module, module_name: str) -> bool: BLOCK_CHECK_MAP = { "layerwise": is_layer, - "blockwise": is_resnet_block,} + "blockwise": is_block,} class CacheVision(Cache, dict): @@ -139,6 +139,7 @@ def apply_learned_round( iters: int = 1000, learned_round: str = "hard_sigmoid_round", learned_round_loss: str = "regularised_mse", + block_name_attribute: str = r"layer\d+", optimizer: str = "adam", lr_scheduler: Optional[str] = None, optimizer_lr: float = 1e-3, @@ -158,12 +159,17 @@ def apply_learned_round( optimizer_class = parse_optimizer_class(optimizer) lr_scheduler_class = parse_lr_scheduler_class(lr_scheduler) - if learned_round_mode not in BLOCK_CHECK_MAP: - learned_round_mode = "layerwise" + # Parse method to retrieve de model blocks + if learned_round_mode == "layerwise": + block_check_fn = is_layer + elif learned_round_mode == "blockwise": + block_check_fn = functools.partial(is_block, reg_exp=block_name_attribute) + else: + block_check_fn = is_layer warnings.warn( f"{learned_round_mode} is not a valid learned round mode. Defaulting to layerwise.") - block_check_fn = BLOCK_CHECK_MAP[learned_round_mode] get_blocks_fn = functools.partial(get_blocks, block_check_fn=block_check_fn) + lr_scheduler_kwargs = { "start_factor": 1.0, "end_factor": 0.0, diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 348213bdb..ff221e5ba 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -164,6 +164,11 @@ def validate_args(args): type=str, choices=[None, 'linear_round', 'hard_sigmoid_round', 'sigmoid_round'], help='Learned round type (default: None)') +parser.add_argument( + '--learned-round-block-name', + type=str, + default="layer\d+", + help='Block name for learned round. It works only if FX is not needed (default: %(default)s)') parser.add_argument( '--learned-round-loss', default='regularised_mse', @@ -516,6 +521,7 @@ def main(): iters=args.learned_round_iters, learned_round=args.learned_round, learned_round_loss=args.learned_round_loss, + block_name_attribute=args.learned_round_block_name, optimizer=args.optimizer, lr_scheduler=args.learned_round_lr_scheduler, optimizer_lr=args.learned_round_lr, From d4f369c1d9b1d83915ba86b9ec4a67828abd2b4c Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 29 Nov 2024 18:21:43 +0000 Subject: [PATCH 28/48] Update vision requirements --- requirements/requirements-vision.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/requirements-vision.txt b/requirements/requirements-vision.txt index 0e74a8966..c86540c5c 100644 --- a/requirements/requirements-vision.txt +++ b/requirements/requirements-vision.txt @@ -1,2 +1,3 @@ +accelerate torchvision tqdm From 6558bd960b4fb23f1a90d6a71d46eb36a21a11e0 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 29 Nov 2024 18:55:11 +0000 Subject: [PATCH 29/48] Remove cache to dataset methods --- .../learned_round/learned_round_optimizer.py | 12 ------- .../learned_round/learned_round_parser.py | 9 ------ .../ptq/learned_round_utils.py | 7 ---- .../llm/llm_quant/learned_round_utils.py | 32 ------------------- 4 files changed, 60 deletions(-) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index 5db0bdd83..866474a3b 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -190,11 +190,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import warnings -from accelerate import Accelerator -from accelerate.utils import tqdm as tqdm_accelerate -from accelerate.utils.dataclasses import PrecisionType from accelerate.utils.operations import send_to_device -from datasets import Dataset import torch from torch import autocast from torch import nn @@ -284,14 +280,6 @@ def initialize_cache(self) -> None: def clear_cache(self) -> None: pass - @abstractmethod - def cache_to_dataset(self) -> Dataset: - pass - - @abstractmethod - def collate_fn(self, batch: Any) -> Any: - pass - class DataSaverHook: diff --git a/src/brevitas_examples/common/learned_round/learned_round_parser.py b/src/brevitas_examples/common/learned_round/learned_round_parser.py index c1e470331..0382b2b3a 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_parser.py +++ b/src/brevitas_examples/common/learned_round/learned_round_parser.py @@ -1,28 +1,19 @@ # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -import re from typing import Any, Callable, Dict, Optional, Tuple, Type, Union import warnings -from accelerate.utils.operations import send_to_device -from datasets import Dataset import torch -from torch import nn from torch.optim.lr_scheduler import LRScheduler from torch.optim.optimizer import Optimizer -from torch.utils.data.dataloader import DataLoader -from brevitas import config from brevitas.inject.enum import LearnedRoundImplType -from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.optim.sign_sgd import SignSGD -from brevitas.quant_tensor import QuantTensor from brevitas_examples.common.learned_round.learned_round_method import LearnedRound from brevitas_examples.common.learned_round.learned_round_method import LearnedRoundLoss from brevitas_examples.common.learned_round.learned_round_method import MSELoss from brevitas_examples.common.learned_round.learned_round_method import RegularisedMSELoss -from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer LEARNED_ROUND_MAP = { "linear_round": LearnedRoundImplType.IDENTITY, diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 44abce8f4..3df861a93 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -32,7 +32,6 @@ import warnings from accelerate.utils.operations import send_to_device -from datasets import Dataset import torch from torch import nn from torch.utils.data.dataloader import DataLoader @@ -108,12 +107,6 @@ def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: return self["inputs"][indices], self["output"][indices] - def cache_to_dataset(self) -> Dataset: - raise NotImplementedError("This method is still not available for CNNs.") - - def collate_fn(self, batch: Any) -> Any: - raise NotImplementedError("This method is still not available for CNNs.") - def __len__(self): return ( len(self["inputs"]) diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index ad6d43693..fa3ca0048 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -5,12 +5,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from accelerate.utils.operations import send_to_device -from datasets import Dataset import torch from torch import nn from torch.utils.data.dataloader import DataLoader -from transformers.models.llama.modeling_llama import LlamaDecoderLayer -from transformers.models.opt.modeling_opt import OPTDecoderLayer from brevitas.utils.python_utils import recurse_getattr from brevitas_examples.common.learned_round.learned_round_optimizer import Cache @@ -73,35 +70,6 @@ def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: outs = torch.cat([cache_outs[i] for i in indices], dim=0) return (args, kwargs), outs - def cache_to_dataset(self) -> Dataset: - inputs_list = list(zip(self["args"], self["kwargs"])) - return list(zip(inputs_list, self["output"])) - - def collate_fn(self, batch: Any) -> Any: - # Format of the dataset is ((args, kwargs), outs) - # See cache_to_dataset - inputs, outs = map(list, zip(*batch)) - args, kwargs_dict = map(list, zip(*inputs)) - # Positional arguments - args = tuple(torch.cat(arg_tensor, dim=0) for arg_tensor in zip(*args)) - # Keyword arguments - kwargs = {} - for curr_dict in kwargs_dict: - for key, value in curr_dict.items(): - if isinstance(value, torch.Tensor): - if key not in kwargs: - kwargs[key] = [] - kwargs[key].append(value) - else: - if key not in kwargs: - kwargs[key] = value - for key, value in kwargs.items(): - if isinstance(value, list) and len(value) > 0: - kwargs[key] = torch.cat(kwargs[key], dim=0) - # FP outputs - outs = torch.cat(outs, dim=0) - return ((args, kwargs), outs) - def __len__(self): return len(self["args"]) From 56ecaf489c218a2f7a3c7426bfde9cc3ec6c1b20 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 30 Nov 2024 13:00:00 +0100 Subject: [PATCH 30/48] Update sign_sgd.py --- src/brevitas/optim/sign_sgd.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/brevitas/optim/sign_sgd.py b/src/brevitas/optim/sign_sgd.py index bd26b40d4..c92d145f4 100644 --- a/src/brevitas/optim/sign_sgd.py +++ b/src/brevitas/optim/sign_sgd.py @@ -45,11 +45,10 @@ import torch from torch import Tensor -from torch.optim.optimizer import _default_to_fused_or_foreach -from torch.optim.optimizer import _differentiable_doc -from torch.optim.optimizer import _foreach_doc -from torch.optim.optimizer import _fused_doc -from torch.optim.optimizer import _maximize_doc +try: + from torch.optim.optimizer import _default_to_fused_or_foreach +except: + _default_to_fused_or_foreach = None from torch.optim.optimizer import _use_grad_for_differentiable from torch.optim.optimizer import Optimizer from torch.optim.sgd import SGD @@ -212,7 +211,7 @@ def sign_sgd( if foreach is None and fused is None: # why must we be explicit about an if statement for torch.jit.is_scripting here? # because JIT can't handle Optionals nor fancy conditionals when scripting - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and _default_to_fused_or_foreach is not None: fused, foreach = _default_to_fused_or_foreach( params, differentiable=False, use_fused=False ) From 183b3aeca9fc9fbbf99e781de8d06c7257790c6e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 30 Nov 2024 13:02:00 +0100 Subject: [PATCH 31/48] Fix indentation --- src/brevitas/optim/sign_sgd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/optim/sign_sgd.py b/src/brevitas/optim/sign_sgd.py index c92d145f4..36374d6f9 100644 --- a/src/brevitas/optim/sign_sgd.py +++ b/src/brevitas/optim/sign_sgd.py @@ -46,9 +46,9 @@ import torch from torch import Tensor try: - from torch.optim.optimizer import _default_to_fused_or_foreach + from torch.optim.optimizer import _default_to_fused_or_foreach except: - _default_to_fused_or_foreach = None + _default_to_fused_or_foreach = None from torch.optim.optimizer import _use_grad_for_differentiable from torch.optim.optimizer import Optimizer from torch.optim.sgd import SGD From 42462a51988c9bf9b92adf4b224a2f0f61cb619d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 30 Nov 2024 13:07:10 +0100 Subject: [PATCH 32/48] Precommit fix --- src/brevitas/optim/sign_sgd.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/brevitas/optim/sign_sgd.py b/src/brevitas/optim/sign_sgd.py index 36374d6f9..0afef124b 100644 --- a/src/brevitas/optim/sign_sgd.py +++ b/src/brevitas/optim/sign_sgd.py @@ -45,10 +45,11 @@ import torch from torch import Tensor + try: from torch.optim.optimizer import _default_to_fused_or_foreach except: - _default_to_fused_or_foreach = None + _default_to_fused_or_foreach = None from torch.optim.optimizer import _use_grad_for_differentiable from torch.optim.optimizer import Optimizer from torch.optim.sgd import SGD From 2c0b306b2f3f35f7c56cf4a95b8430021dc6bc31 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 2 Dec 2024 11:46:25 +0000 Subject: [PATCH 33/48] Remove references to LRScheduler for backwards compatibility --- .../learned_round/learned_round_optimizer.py | 7 ++-- .../learned_round/learned_round_parser.py | 34 ++++++++++++++----- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index 866474a3b..30f8c14b9 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -195,7 +195,6 @@ from torch import autocast from torch import nn from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import LRScheduler from torch.optim.optimizer import Optimizer from torch.optim.sgd import SGD from torch.utils.data.dataloader import DataLoader @@ -350,7 +349,7 @@ def __init__( *, optimizer_class: Type[Optimizer] = SignSGD, scale_optimizer_class: Type[Optimizer] = SGD, - lr_scheduler_class: Optional[Type[LRScheduler]] = LinearLR, + lr_scheduler_class: Optional[Type] = LinearLR, optimizer_lr: float = 5e-3, optimizer_scale_lr: float = 5e-3, batch_size: float = 8, @@ -405,12 +404,12 @@ def _optim_step(self, *optimizers: Optimizer) -> None: optimizer.step() optimizer.zero_grad() - def _lr_sched_step(self, *lr_schedulers: LRScheduler) -> None: + def _lr_sched_step(self, *lr_schedulers: Any) -> None: for lr_scheduler in lr_schedulers: if lr_scheduler: lr_scheduler.step() - def _step(self, optimizers: List[Optimizer], lr_schedulers: List[LRScheduler]) -> None: + def _step(self, optimizers: List[Optimizer], lr_schedulers: List[Any]) -> None: for optimizer in optimizers: if optimizer: optimizer.step() diff --git a/src/brevitas_examples/common/learned_round/learned_round_parser.py b/src/brevitas_examples/common/learned_round/learned_round_parser.py index 0382b2b3a..d9ece26d2 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_parser.py +++ b/src/brevitas_examples/common/learned_round/learned_round_parser.py @@ -5,7 +5,6 @@ import warnings import torch -from torch.optim.lr_scheduler import LRScheduler from torch.optim.optimizer import Optimizer from brevitas.inject.enum import LearnedRoundImplType @@ -44,11 +43,22 @@ def parse_optimizer_class(optimizer_str: str) -> Type[Optimizer]: optimizer_class = OPTIMIZER_MAP[optimizer_str] else: optimizer_keys = [ - optimizer_key for optimizer_key in torch.optim.__dict__.keys() if ( + optimizer_key for optimizer_key in torch.optim.__dict__.keys() + # Check for making sure that only valid Optimizer implementations are + # retrived, when matching with the string passed by the user + if ( + # Verify that the key stars with the one passed by the user optimizer_key.lower().startswith(optimizer_str.lower()) and - torch.optim.__dict__[optimizer_key] != Optimizer and + # Verify that key corresponds to a class isinstance(torch.optim.__dict__[optimizer_key], type) and - issubclass(torch.optim.__dict__[optimizer_key], Optimizer))] + # Make sure the abstract class is not used + optimizer_key != "Optimizer" and + # An optimizer implements zero_grad and step. Check that this + # is the case for the class retrieved from torch.optim + hasattr(torch.optim.__dict__[optimizer_key], 'step') and + callable(torch.optim.__dict__[optimizer_key].step) and + hasattr(torch.optim.__dict__[optimizer_key], 'zero_grad') and + callable(torch.optim.__dict__[optimizer_key].zero_grad))] if len(optimizer_keys) == 0: raise ValueError(f"{optimizer_str} is not a valid optimizer.") else: @@ -61,16 +71,24 @@ def parse_optimizer_class(optimizer_str: str) -> Type[Optimizer]: return optimizer_class -def parse_lr_scheduler_class(lr_scheduler_str: str) -> Type[LRScheduler]: +def parse_lr_scheduler_class(lr_scheduler_str: str) -> Type: if lr_scheduler_str in LR_SCHEDULER_MAP: lr_scheduler_class = LR_SCHEDULER_MAP[lr_scheduler_str] else: lr_scheduler_keys = [ - lr_scheduler_key for lr_scheduler_key in torch.optim.lr_scheduler.__dict__.keys() if ( + lr_scheduler_key for lr_scheduler_key in torch.optim.lr_scheduler.__dict__.keys() + # Check for making sure that only valid LRScheduler implementations are + # retrived, when matching with the string passed by the user + if ( lr_scheduler_key.lower().startswith(lr_scheduler_str.lower()) and - torch.optim.lr_scheduler.__dict__[lr_scheduler_key] != LRScheduler and + # Verify that key corresponds to a class isinstance(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], type) and - issubclass(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], LRScheduler))] + # Make sure the abstract class is not retrieved + lr_scheduler_key != "LRScheduler" and + # A learning rate scheduler implements zero_grad and step. Check that this + # is the case for the class retrieved from torch.optim.lr_scheduler + hasattr(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], 'step') and + callable(torch.optim.lr_scheduler.__dict__[lr_scheduler_key].step))] if len(lr_scheduler_keys) == 0: warnings.warn( f"There are no matches for LR scheduler {lr_scheduler_str}. " From 14e15db2280f71e0bec17a4be14157cd98d8542a Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 2 Dec 2024 14:00:01 +0000 Subject: [PATCH 34/48] Fix import failing tests --- src/brevitas/optim/sign_sgd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/optim/sign_sgd.py b/src/brevitas/optim/sign_sgd.py index 0afef124b..805e3ef6c 100644 --- a/src/brevitas/optim/sign_sgd.py +++ b/src/brevitas/optim/sign_sgd.py @@ -48,12 +48,12 @@ try: from torch.optim.optimizer import _default_to_fused_or_foreach + from torch.optim.optimizer import _use_grad_for_differentiable except: _default_to_fused_or_foreach = None -from torch.optim.optimizer import _use_grad_for_differentiable + _use_grad_for_differentiable = None from torch.optim.optimizer import Optimizer from torch.optim.sgd import SGD -from torch.utils._foreach_utils import _get_fused_kernels_supported_devices __all__ = ["SignSGD", "sign_sgd"] From 4e1fdabe85857a82492c11d930e19c40c753beae Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 2 Dec 2024 14:08:45 +0000 Subject: [PATCH 35/48] Fix for PyTorch 1.11 --- src/brevitas/optim/sign_sgd.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/brevitas/optim/sign_sgd.py b/src/brevitas/optim/sign_sgd.py index 805e3ef6c..1e1232cd8 100644 --- a/src/brevitas/optim/sign_sgd.py +++ b/src/brevitas/optim/sign_sgd.py @@ -48,10 +48,14 @@ try: from torch.optim.optimizer import _default_to_fused_or_foreach - from torch.optim.optimizer import _use_grad_for_differentiable except: _default_to_fused_or_foreach = None - _use_grad_for_differentiable = None +try: + from torch.optim.optimizer import _use_grad_for_differentiable +except: + # Ensure backward compatibility with PyTorch < 1.13.0 + _use_grad_for_differentiable = torch.no_grad + from torch.optim.optimizer import Optimizer from torch.optim.sgd import SGD From d748ce8676b425045cc045470c6d9cd73a9dd3e7 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 2 Dec 2024 14:18:25 +0000 Subject: [PATCH 36/48] Remove depedency from SGD --- src/brevitas/optim/sign_sgd.py | 58 +++++++++++++++++++++++++++++++--- 1 file changed, 54 insertions(+), 4 deletions(-) diff --git a/src/brevitas/optim/sign_sgd.py b/src/brevitas/optim/sign_sgd.py index 1e1232cd8..70090916a 100644 --- a/src/brevitas/optim/sign_sgd.py +++ b/src/brevitas/optim/sign_sgd.py @@ -48,8 +48,10 @@ try: from torch.optim.optimizer import _default_to_fused_or_foreach + from torch.utils._foreach_utils import _get_fused_kernels_supported_devices except: _default_to_fused_or_foreach = None + _get_fused_kernels_supported_devices = None try: from torch.optim.optimizer import _use_grad_for_differentiable except: @@ -57,12 +59,11 @@ _use_grad_for_differentiable = torch.no_grad from torch.optim.optimizer import Optimizer -from torch.optim.sgd import SGD __all__ = ["SignSGD", "sign_sgd"] -class SignSGD(SGD): +class SignSGD(Optimizer): """Implements signed stochastic gradient descent (optionally with momentum). .. math:: @@ -126,8 +127,14 @@ def __init__( differentiable: bool = False, fused: Optional[bool] = None, ): - super().__init__( - params=params, + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( lr=lr, momentum=momentum, dampening=dampening, @@ -138,6 +145,49 @@ def __init__( differentiable=differentiable, fused=fused, ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) + + if fused: + self._step_supports_amp_scaling = True + + fused_supported_devices = _get_fused_kernels_supported_devices() + if not all(p.device.type in fused_supported_devices and torch.is_floating_point(p) + for pg in self.param_groups + for p in pg["params"]): + raise RuntimeError( + "`fused=True` requires all the params to be floating point Tensors of " + f"supported devices: {fused_supported_devices}.") + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("differentiable", False) + group.setdefault("fused", False) + + def _init_group(self, group, params, grads, momentum_buffer_list): + has_sparse_grad = False + + for p in group["params"]: + if p.grad is not None: + params.append(p) + grads.append(p.grad) + if p.grad.is_sparse: + has_sparse_grad = True + + if group["momentum"] != 0: + state = self.state[p] + momentum_buffer_list.append(state.get("momentum_buffer")) + + return has_sparse_grad @_use_grad_for_differentiable def step(self, closure=None): From 47e2f20db8703348c2ce2ac003de381e86c7acaf Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 2 Dec 2024 14:55:30 +0000 Subject: [PATCH 37/48] Remove default values --- .../common/learned_round/learned_round_optimizer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index 30f8c14b9..d7763f957 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -194,9 +194,7 @@ import torch from torch import autocast from torch import nn -from torch.optim.lr_scheduler import LinearLR from torch.optim.optimizer import Optimizer -from torch.optim.sgd import SGD from torch.utils.data.dataloader import DataLoader from tqdm import tqdm @@ -205,7 +203,6 @@ from brevitas.graph.calibrate import disable_return_quant_tensor from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.calibrate import restore_return_quant_tensor -from brevitas.optim.sign_sgd import SignSGD from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase from brevitas.utils.torch_utils import StopFwdException from brevitas_examples.common.accelerate_utils.accelerate import offload_model @@ -346,10 +343,10 @@ def __init__( self, learned_round: LearnedRound, learned_round_loss_class: Type[LearnedRoundLoss], + optimizer_class: Type[Optimizer], *, - optimizer_class: Type[Optimizer] = SignSGD, - scale_optimizer_class: Type[Optimizer] = SGD, - lr_scheduler_class: Optional[Type] = LinearLR, + scale_optimizer_class: Optional[Type[Optimizer]] = None, + lr_scheduler_class: Optional[Type] = None, optimizer_lr: float = 5e-3, optimizer_scale_lr: float = 5e-3, batch_size: float = 8, @@ -363,6 +360,9 @@ def __init__( optimizer_kwargs: Optional[Dict] = None, lr_scheduler_kwargs: Optional[Dict] = None, ) -> None: + # Verify that an optimizer is passed for optimizing the scale if learn_scale=True + assert not (learn_scale and scale_optimizer_class is None), "An optimizer needs to be passed for the scale if learn_scale is set to True." + self.learned_round = learned_round self.optimizer_class = optimizer_class self.scale_optimizer_class = scale_optimizer_class From 5120cf98399f3eca67aacc58155c5e4c449b4e93 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 2 Dec 2024 15:31:51 +0000 Subject: [PATCH 38/48] Remove unused imports --- src/brevitas/core/function_wrapper/learned_round.py | 7 ++++++- tests/brevitas/optim/test_sign_sgd.py | 9 --------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py index cfb1cfa5c..cfc78d03f 100644 --- a/src/brevitas/core/function_wrapper/learned_round.py +++ b/src/brevitas/core/function_wrapper/learned_round.py @@ -11,6 +11,7 @@ import brevitas from brevitas import config +from brevitas.core.function_wrapper.ops_ste import TensorClampSte from brevitas.core.utils import SliceTensor from brevitas.function.ops_ste import floor_ste from brevitas.function.ops_ste import round_ste @@ -72,10 +73,14 @@ class LearnedRoundIdentity(brevitas.jit.ScriptModule): def __init__(self) -> None: super(LearnedRoundIdentity, self).__init__() + self.tensor_clamp = TensorClampSte() @brevitas.jit.script_method def forward(self, p: torch.Tensor) -> torch.Tensor: - return p + return self.tensor_clamp( + p, + min_val=torch.tensor(-0.5, device=p.device), + max_val=torch.tensor(+0.5, device=p.device)) @brevitas.jit.script_method def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: diff --git a/tests/brevitas/optim/test_sign_sgd.py b/tests/brevitas/optim/test_sign_sgd.py index b1e61cc3e..ee8b0b336 100644 --- a/tests/brevitas/optim/test_sign_sgd.py +++ b/tests/brevitas/optim/test_sign_sgd.py @@ -42,18 +42,11 @@ from copy import deepcopy from itertools import product -import math -import sys -from typing import List, Union -import unittest -from hypothesis import given import pytest import pytest_cases -from pytest_cases import fixture import torch from torch.nn import Parameter -import torch.nn as nn from torch.optim.lr_scheduler import LinearLR from brevitas.optim.sign_sgd import SignSGD @@ -61,8 +54,6 @@ torch.manual_seed(SEED) -from torch.testing._internal.common_optimizers import OptimizerInput - REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]]) REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010], [1.4573, -0.9074, -0.2708]]) From 6c7e72d4b6fe0105c72ff606db0bee3f04db62cb Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 2 Dec 2024 16:31:41 +0000 Subject: [PATCH 39/48] Remove import and allow test to fail --- tests/brevitas/optim/test_sign_sgd.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/optim/test_sign_sgd.py b/tests/brevitas/optim/test_sign_sgd.py index ee8b0b336..f91590626 100644 --- a/tests/brevitas/optim/test_sign_sgd.py +++ b/tests/brevitas/optim/test_sign_sgd.py @@ -43,12 +43,14 @@ from copy import deepcopy from itertools import product +from packaging.version import parse import pytest import pytest_cases import torch from torch.nn import Parameter from torch.optim.lr_scheduler import LinearLR +from brevitas import torch_version from brevitas.optim.sign_sgd import SignSGD from tests.conftest import SEED @@ -94,12 +96,16 @@ def test_sign_sgd_single_update(self, device, dtype, lr): assert torch.allclose(weights, initial_weights - lr * weight_sign_grad) - from torch.testing._internal.common_optimizers import optims - @device_dtype_parametrize @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS) @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args): + # PyTorch version previous to 2.3.1. might no have mv (addmv_impl_cpu) implemented for Half + if dtype == torch.float16 and device == "cpu" and torch_version < parse('2.3.1'): + pytest.xfail( + "PyTorch versions previous to 2.3.1. might no have mv (addmv_impl_cpu) implemented for Half" + ) + optim_cls = SignSGD weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) bias = Parameter(torch.randn((10), device=device, dtype=dtype)) From 8e41ce5c676e4b30da7f743ca1b6a40d0faff9a5 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 2 Dec 2024 16:44:08 +0000 Subject: [PATCH 40/48] Account for change in optimizer interface --- tests/brevitas/optim/test_sign_sgd.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/brevitas/optim/test_sign_sgd.py b/tests/brevitas/optim/test_sign_sgd.py index f91590626..7862626ce 100644 --- a/tests/brevitas/optim/test_sign_sgd.py +++ b/tests/brevitas/optim/test_sign_sgd.py @@ -53,6 +53,7 @@ from brevitas import torch_version from brevitas.optim.sign_sgd import SignSGD from tests.conftest import SEED +from tests.marker import requires_pt_ge torch.manual_seed(SEED) @@ -77,6 +78,7 @@ class TestOptimSignSGD: + @requires_pt_ge('2.1') @device_dtype_parametrize @pytest_cases.parametrize("lr", [0.1]) def test_sign_sgd_single_update(self, device, dtype, lr): @@ -96,6 +98,7 @@ def test_sign_sgd_single_update(self, device, dtype, lr): assert torch.allclose(weights, initial_weights - lr * weight_sign_grad) + @requires_pt_ge('2.1') @device_dtype_parametrize @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS) @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) @@ -134,6 +137,7 @@ def closure(): else: assert closure().item() < initial_value + @requires_pt_ge('2.1') @pytest.mark.skipif( torch.cuda.device_count() <= 1, reason="At least two GPUs are required for this test.") @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS) From 5e4217f83117db4fbd7cf3ff59ba346d925e1950 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 2 Dec 2024 17:22:28 +0000 Subject: [PATCH 41/48] Transformers version check for sdpa attention --- .../llm/llm_quant/prepare_for_quantize.py | 12 +++-- tests/brevitas_examples/test_llm.py | 48 +++++++++++++++++++ 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py index 2a71546e4..fe161883d 100644 --- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py +++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py @@ -1,18 +1,24 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + import warnings +from packaging import version import torch +import transformers from transformers.models.opt.modeling_opt import OPTAttention -from transformers.models.opt.modeling_opt import OPTSdpaAttention from brevitas.graph import ModuleToModuleByClass from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention QUANTIZABLE_MHA_MAP = { OPTAttention: (QuantizableOPTAttention, { - 'batch_first': True}), - OPTSdpaAttention: (QuantizableOPTAttention, { 'batch_first': True}),} +if version.parse('4.46.0') >= version.parse(transformers.__version__): + from transformers.models.opt.modeling_opt import OPTSdpaAttention + QUANTIZABLE_MHA_MAP[OPTSdpaAttention] = (QuantizableOPTAttention, {'batch_first': True}) + def replace_mha_with_quantizable_layers(model, dtype): rewriters = [] diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index e59973b95..7f0153f13 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -525,3 +525,51 @@ def test_small_models_torch_export(caplog, torch_export_args): filepath = args.export_prefix + ".pt" torchscript_model = torch.jit.load(filepath) os.remove(filepath) + + +@pytest_cases.fixture( + ids=[ + "llama", + "mistral",], + params=[ + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "learned_round": "linear_round", + "learned_round_iters": 1, + "gpxq_block_name": "model.layers", + "float_ppl": 33238.8984375 if transformers_version_ge('4.46.0') else 33238.8984375, + "quant_ppl": 33252.21484375 if transformers_version_ge('4.46.0') else 33252.21484375}, + { + "model": "hf-internal-testing/tiny-random-MistralForCausalLM", + "act_calibration": False, + "weight_bit_width": 4, + "input_bit_width": None, + "learned_round": "linear_round", + "learned_round_iters": 1, + "gpxq_block_name": "model.layers", + "float_ppl": 31275.958984375 if transformers_version_ge('4.46.0') else 31274.05078125, + "quant_ppl": 31337.4921875 if transformers_version_ge('4.46.0') else 33139.23046875},]) +def learned_round_ppl_args_and_ppl(default_run_args, request): + args = default_run_args + run_dict = request.param + float_ppl = run_dict["float_ppl"] + quant_ppl = run_dict["quant_ppl"] + del run_dict["float_ppl"] + del run_dict["quant_ppl"] + args.update(**run_dict) + yield args, float_ppl, quant_ppl + + +@pytest.mark.llm +@requires_pt_ge('2.2') +def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): + caplog.set_level(logging.INFO) + args, exp_float_ppl, exp_quant_ppl = learned_round_ppl_args_and_ppl + float_ppl, quant_ppl, model = validate_args_and_run_main(args) + float_ppl = float_ppl.detach().cpu().numpy() + quant_ppl = quant_ppl.detach().cpu().numpy() + assert allexact(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allexact(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" From 8ae222247dcb330e9e4cfa15f7aaee98da0ee53f Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 2 Dec 2024 17:34:37 +0000 Subject: [PATCH 42/48] Fix tests --- .../llm/llm_quant/prepare_for_quantize.py | 2 +- tests/brevitas_examples/test_llm.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py index fe161883d..ee2f0b3e6 100644 --- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py +++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py @@ -15,7 +15,7 @@ OPTAttention: (QuantizableOPTAttention, { 'batch_first': True}),} -if version.parse('4.46.0') >= version.parse(transformers.__version__): +if version.parse(transformers.__version__) >= version.parse('4.46.0'): from transformers.models.opt.modeling_opt import OPTSdpaAttention QUANTIZABLE_MHA_MAP[OPTSdpaAttention] = (QuantizableOPTAttention, {'batch_first': True}) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 7f0153f13..1dbbdb5ad 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -540,8 +540,8 @@ def test_small_models_torch_export(caplog, torch_export_args): "learned_round": "linear_round", "learned_round_iters": 1, "gpxq_block_name": "model.layers", - "float_ppl": 33238.8984375 if transformers_version_ge('4.46.0') else 33238.8984375, - "quant_ppl": 33252.21484375 if transformers_version_ge('4.46.0') else 33252.21484375}, + "float_ppl": 33238.8984375, + "quant_ppl": 33252.21484375}, { "model": "hf-internal-testing/tiny-random-MistralForCausalLM", "act_calibration": False, @@ -550,8 +550,8 @@ def test_small_models_torch_export(caplog, torch_export_args): "learned_round": "linear_round", "learned_round_iters": 1, "gpxq_block_name": "model.layers", - "float_ppl": 31275.958984375 if transformers_version_ge('4.46.0') else 31274.05078125, - "quant_ppl": 31337.4921875 if transformers_version_ge('4.46.0') else 33139.23046875},]) + "float_ppl": 31275.958984375, + "quant_ppl": 31337.4921875},]) def learned_round_ppl_args_and_ppl(default_run_args, request): args = default_run_args run_dict = request.param From eb498bcb32f1c2d59b52fcb90ca2edcfa76d1b2b Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 2 Dec 2024 17:41:18 +0000 Subject: [PATCH 43/48] Relax test assertion --- tests/brevitas_examples/test_llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 1dbbdb5ad..60dd33ac2 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -571,5 +571,5 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl): float_ppl, quant_ppl, model = validate_args_and_run_main(args) float_ppl = float_ppl.detach().cpu().numpy() quant_ppl = quant_ppl.detach().cpu().numpy() - assert allexact(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" - assert allexact(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" From 692770072ef883dd279e99f6a7b30c9bbc453f96 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 2 Dec 2024 23:43:29 +0100 Subject: [PATCH 44/48] Update test_sign_sgd.py --- tests/brevitas/optim/test_sign_sgd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/brevitas/optim/test_sign_sgd.py b/tests/brevitas/optim/test_sign_sgd.py index 7862626ce..10b5dfe5b 100644 --- a/tests/brevitas/optim/test_sign_sgd.py +++ b/tests/brevitas/optim/test_sign_sgd.py @@ -78,9 +78,9 @@ class TestOptimSignSGD: - @requires_pt_ge('2.1') @device_dtype_parametrize @pytest_cases.parametrize("lr", [0.1]) + @requires_pt_ge('2.1') def test_sign_sgd_single_update(self, device, dtype, lr): # Initialize weights and grads weights = Parameter(REFERENCE_WEIGHTS.to(device=device, dtype=dtype)) @@ -98,10 +98,10 @@ def test_sign_sgd_single_update(self, device, dtype, lr): assert torch.allclose(weights, initial_weights - lr * weight_sign_grad) - @requires_pt_ge('2.1') @device_dtype_parametrize @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS) @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) + @requires_pt_ge('2.1') def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args): # PyTorch version previous to 2.3.1. might no have mv (addmv_impl_cpu) implemented for Half if dtype == torch.float16 and device == "cpu" and torch_version < parse('2.3.1'): @@ -137,12 +137,12 @@ def closure(): else: assert closure().item() < initial_value - @requires_pt_ge('2.1') @pytest.mark.skipif( torch.cuda.device_count() <= 1, reason="At least two GPUs are required for this test.") @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS) @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) @pytest_cases.parametrize("dtype", [torch.float16, torch.float32]) + @requires_pt_ge('2.1') def test_forloop_goes_right_direction_multigpu( self, dtype, optimizer_kwargs, lr_scheduler_args): optim_cls = SignSGD From 632e39632efd85b0c8d556bb162286dcaa0e74d9 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 3 Dec 2024 00:19:52 +0100 Subject: [PATCH 45/48] Update test_sign_sgd.py --- tests/brevitas/optim/test_sign_sgd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/optim/test_sign_sgd.py b/tests/brevitas/optim/test_sign_sgd.py index 10b5dfe5b..5970ac262 100644 --- a/tests/brevitas/optim/test_sign_sgd.py +++ b/tests/brevitas/optim/test_sign_sgd.py @@ -51,7 +51,6 @@ from torch.optim.lr_scheduler import LinearLR from brevitas import torch_version -from brevitas.optim.sign_sgd import SignSGD from tests.conftest import SEED from tests.marker import requires_pt_ge @@ -80,8 +79,9 @@ class TestOptimSignSGD: @device_dtype_parametrize @pytest_cases.parametrize("lr", [0.1]) - @requires_pt_ge('2.1') + @requires_pt_ge('2.1') # TODO: revisit this def test_sign_sgd_single_update(self, device, dtype, lr): + from brevitas.optim.sign_sgd import SignSGD # Initialize weights and grads weights = Parameter(REFERENCE_WEIGHTS.to(device=device, dtype=dtype)) # Initialize tensors to compute expected result @@ -103,6 +103,7 @@ def test_sign_sgd_single_update(self, device, dtype, lr): @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) @requires_pt_ge('2.1') def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args): + from brevitas.optim.sign_sgd import SignSGD # PyTorch version previous to 2.3.1. might no have mv (addmv_impl_cpu) implemented for Half if dtype == torch.float16 and device == "cpu" and torch_version < parse('2.3.1'): pytest.xfail( @@ -145,6 +146,7 @@ def closure(): @requires_pt_ge('2.1') def test_forloop_goes_right_direction_multigpu( self, dtype, optimizer_kwargs, lr_scheduler_args): + from brevitas.optim.sign_sgd import SignSGD optim_cls = SignSGD # Learnable parameters weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype)) From 1a2863be9eb10860172bafeb2d37cca08ea82052 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 3 Dec 2024 00:33:23 +0100 Subject: [PATCH 46/48] Update learned_round.py --- src/brevitas/core/function_wrapper/learned_round.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py index cfc78d03f..07f5f5d0f 100644 --- a/src/brevitas/core/function_wrapper/learned_round.py +++ b/src/brevitas/core/function_wrapper/learned_round.py @@ -65,7 +65,8 @@ def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return floor_ste(x) + p -class LearnedRoundIdentity(brevitas.jit.ScriptModule): +# TODO: Restore JIT compatibility +class LearnedRoundIdentity(torch.nn.Module): """ Implementation for LearnedRound learned parameter Adapted from https://arxiv.org/abs/2309.05516 @@ -75,14 +76,12 @@ def __init__(self) -> None: super(LearnedRoundIdentity, self).__init__() self.tensor_clamp = TensorClampSte() - @brevitas.jit.script_method def forward(self, p: torch.Tensor) -> torch.Tensor: return self.tensor_clamp( p, min_val=torch.tensor(-0.5, device=p.device), max_val=torch.tensor(+0.5, device=p.device)) - @brevitas.jit.script_method def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return round_ste(x + p) From e5bc47c38b4ea581f68e6b405bc22faa6dcedcaa Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 3 Dec 2024 00:43:25 +0100 Subject: [PATCH 47/48] Update learned_round.py --- src/brevitas/core/function_wrapper/learned_round.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py index 07f5f5d0f..ffc69b7da 100644 --- a/src/brevitas/core/function_wrapper/learned_round.py +++ b/src/brevitas/core/function_wrapper/learned_round.py @@ -76,12 +76,14 @@ def __init__(self) -> None: super(LearnedRoundIdentity, self).__init__() self.tensor_clamp = TensorClampSte() + @brevitas.jit.ignore def forward(self, p: torch.Tensor) -> torch.Tensor: return self.tensor_clamp( p, min_val=torch.tensor(-0.5, device=p.device), max_val=torch.tensor(+0.5, device=p.device)) + @brevitas.jit.ignore def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return round_ste(x + p) From ba1344ff06d489b3f51b278415c8bcbcb422413f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 3 Dec 2024 00:14:55 +0000 Subject: [PATCH 48/48] Fix tests --- src/brevitas/core/function_wrapper/learned_round.py | 10 ++++------ tests/brevitas/optim/test_sign_sgd.py | 4 +++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py index ffc69b7da..8387c839c 100644 --- a/src/brevitas/core/function_wrapper/learned_round.py +++ b/src/brevitas/core/function_wrapper/learned_round.py @@ -65,8 +65,7 @@ def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return floor_ste(x) + p -# TODO: Restore JIT compatibility -class LearnedRoundIdentity(torch.nn.Module): +class LearnedRoundIdentity(brevitas.jit.ScriptModule): """ Implementation for LearnedRound learned parameter Adapted from https://arxiv.org/abs/2309.05516 @@ -75,15 +74,14 @@ class LearnedRoundIdentity(torch.nn.Module): def __init__(self) -> None: super(LearnedRoundIdentity, self).__init__() self.tensor_clamp = TensorClampSte() + self.upper_lower_bound = brevitas.jit.Attribute(0.5, float) - @brevitas.jit.ignore def forward(self, p: torch.Tensor) -> torch.Tensor: return self.tensor_clamp( p, - min_val=torch.tensor(-0.5, device=p.device), - max_val=torch.tensor(+0.5, device=p.device)) + min_val=torch.tensor(-self.upper_lower_bound).type_as(p), + max_val=torch.tensor(self.upper_lower_bound).type_as(p)) - @brevitas.jit.ignore def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return round_ste(x + p) diff --git a/tests/brevitas/optim/test_sign_sgd.py b/tests/brevitas/optim/test_sign_sgd.py index 5970ac262..2aca86c82 100644 --- a/tests/brevitas/optim/test_sign_sgd.py +++ b/tests/brevitas/optim/test_sign_sgd.py @@ -79,9 +79,10 @@ class TestOptimSignSGD: @device_dtype_parametrize @pytest_cases.parametrize("lr", [0.1]) - @requires_pt_ge('2.1') # TODO: revisit this + @requires_pt_ge('2.1') # TODO: revisit this def test_sign_sgd_single_update(self, device, dtype, lr): from brevitas.optim.sign_sgd import SignSGD + # Initialize weights and grads weights = Parameter(REFERENCE_WEIGHTS.to(device=device, dtype=dtype)) # Initialize tensors to compute expected result @@ -104,6 +105,7 @@ def test_sign_sgd_single_update(self, device, dtype, lr): @requires_pt_ge('2.1') def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args): from brevitas.optim.sign_sgd import SignSGD + # PyTorch version previous to 2.3.1. might no have mv (addmv_impl_cpu) implemented for Half if dtype == torch.float16 and device == "cpu" and torch_version < parse('2.3.1'): pytest.xfail(