diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index c8220d073..ea068474e 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -203,7 +203,6 @@ from torch.optim.optimizer import Optimizer from torch.optim.sgd import SGD from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataloader import RandomSampler from tqdm import tqdm from brevitas import config @@ -285,10 +284,6 @@ def initialize_cache(self) -> None: def clear_cache(self) -> None: pass - @abstractmethod - def reset_cache(self) -> None: - pass - @abstractmethod def cache_to_dataset(self) -> Dataset: pass @@ -699,7 +694,7 @@ def apply_learned_round( block_forward: Callable, data_loader: DataLoader, cache: Cache, - block_check_fn: Callable, + get_blocks_fn: Callable, model_prepare_fn: Optional[Callable] = None, model_finish_fn: Optional[Callable] = None, keep_gpu: bool = True) -> None: @@ -711,7 +706,7 @@ def apply_learned_round( self.learned_round.insert_learned_round_quantizers(model) # Retrieve blocks using the appropiate function to check blocks - blocks = get_blocks(model, block_check_fn) + blocks = get_blocks_fn(model) print(f"Total Iterations per block {self.iters}") print(f"Number of blocks {len(blocks)}") @@ -726,7 +721,6 @@ def apply_learned_round( model = offload_model(model) # Cache needs to be cleared before populating it with the inputs and outputs # to the block under optimization. - cache.clear_cache() self._populate_cache( cache, model, @@ -801,7 +795,7 @@ def apply_learned_round( # TODO: This call might not be needed, check_clear and reset_cache methods # Reset cache after optimisation - cache.reset_cache() + cache.clear_cache() # The original configuration of the model is restored after finishing the optimization if model_finish_fn is not None: diff --git a/src/brevitas_examples/common/learned_round/learned_round_parser.py b/src/brevitas_examples/common/learned_round/learned_round_parser.py index dfc3d7a13..c1e470331 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_parser.py +++ b/src/brevitas_examples/common/learned_round/learned_round_parser.py @@ -80,7 +80,6 @@ def parse_lr_scheduler_class(lr_scheduler_str: str) -> Type[LRScheduler]: torch.optim.lr_scheduler.__dict__[lr_scheduler_key] != LRScheduler and isinstance(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], type) and issubclass(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], LRScheduler))] - print(lr_scheduler_keys) if len(lr_scheduler_keys) == 0: warnings.warn( f"There are no matches for LR scheduler {lr_scheduler_str}. " diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 499949b6e..8994bbc25 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -26,6 +26,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import functools import re from typing import Any, Callable, Dict, Optional, Tuple, Union import warnings @@ -39,6 +40,8 @@ from brevitas import config from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.quant_tensor import QuantTensor +from brevitas_examples.common.learned_round.learned_round_optimizer import Cache +from brevitas_examples.common.learned_round.learned_round_optimizer import get_blocks from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer from brevitas_examples.common.learned_round.learned_round_parser import parse_learned_round from brevitas_examples.common.learned_round.learned_round_parser import \ @@ -62,7 +65,7 @@ def is_layer(module: nn.Module, module_name: str) -> bool: "blockwise": is_resnet_block,} -class CacheVision(dict): +class CacheVision(Cache, dict): def __init__(self) -> None: super().__init__() @@ -97,12 +100,6 @@ def clear_cache(self) -> None: self["inputs"] = [] self["output"] = [] - def reset_cache(self) -> None: - del self["inputs"] - del self["output"] - self["inputs"] = [] - self["output"] = [] - def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]: if isinstance(self["inputs"], list): self["inputs"] = torch.cat(self["inputs"], dim=self.batch_dim) @@ -166,6 +163,7 @@ def apply_learned_round( warnings.warn( f"{learned_round_mode} is not a valid learned round mode. Defaulting to layerwise.") block_check_fn = BLOCK_CHECK_MAP[learned_round_mode] + get_blocks_fn = functools.partial(get_blocks, block_check_fn=block_check_fn) lr_scheduler_kwargs = { "start_factor": 1.0, "end_factor": 0.0, @@ -192,6 +190,6 @@ def apply_learned_round( block_forward=cnn_block_forward, data_loader=calibration_loader, cache=cache, - block_check_fn=block_check_fn, + get_blocks_fn=get_blocks_fn, keep_gpu=True, ) diff --git a/src/brevitas_examples/llm/benchmark/llm_benchmark.py b/src/brevitas_examples/llm/benchmark/llm_benchmark.py deleted file mode 100644 index c21036be5..000000000 --- a/src/brevitas_examples/llm/benchmark/llm_benchmark.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -import argparse -from functools import partial -from itertools import product -import os -from types import SimpleNamespace - -import pandas as pd -import torch.backends.cudnn as cudnn -import torch.nn.parallel -import torch.optim -import torch.utils.data -import torch.utils.data.distributed - -from brevitas import __version__ as brevitas_version -from brevitas import config -from brevitas import torch_version -from brevitas_examples.imagenet_classification.ptq.utils import get_gpu_index -# LLM example depends on optimum-amd, which requires PyTorch>=2.2 -from brevitas_examples.llm.main import main as main_llm -from brevitas_examples.llm.main import validate - -config.IGNORE_MISSING_KEYS = True - - -def parse_type(v, default_type): - if v == 'None': - return None - else: - return default_type(v) - - -def parse_bool(v): - if isinstance(v, bool): - return v - if v.lower() in ('yes', 'true', 't', 'y'): - return True - elif v.lower() in ('no', 'false', 'f', 'n'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -class hashabledict(dict): - - def __hash__(self): - return hash(tuple(sorted(self.items()))) - - -def unique(sequence): - seen = set() - return [x for x in sequence if not (x in seen or seen.add(x))] - - -LLM_PPL_MAP = { - 'facebook/opt-125m': None, - 'meta-llama/Llama-2-7b-hf': None,} - -OPTIONS_DEFAULT = { - 'model': list(LLM_PPL_MAP.keys()), # HF model name. Default: facebook/opt-125m. - 'seed': [0], # Seed for sampling the calibration data. Default: 0. - 'nsamples': [128], # Number of calibration data samples. Default: 128. - 'seqlen': [2048], # Sequence length. Default: 2048. - 'eval': [True], # Eval model PPL on the chosen Dataset. - 'dataset': ['wikitext2'], # Dataset to use for quantization (default: wikitext2) - 'gpxq_block_name': [None], # Block name for faster GPxQ optimization. Default: None - 'weight_bit_width': [8], # Weight bit width. Default: 8. - 'weight_param_method': ['stats'], # How scales/zero-point are determined. Default: stats. - 'weight_scale_precision': ['float_scale' - ], # Whether scale is a float value or a po2. Default: po2. - 'weight_quant_type': ['sym'], # Weight quantization type. Default: asym. - 'weight_quant_format': ['int'], # Weight quantization type. Default: int. - 'weight_quant_granularity': [ - 'per_group'], # Granularity for scales/zero-point of weights. Default: per_group. - 'scale_rounding_func_type': [None], # Rounding function to use with Po2 scale. Default: None. - 'weight_group_dim': [ - None], # Override default group_dim for groupsize quantization. Default: layer-dependant - 'weight_group_size': [128], # Group size for per_group weight quantization. Default: 128. - 'quantize_weight_zero_point': [False], # Quantize weight zero-point. - 'input_bit_width': [None], # Input bit width. Default: None (disables input quantization). - 'input_quant_format': ['int'], # Input quantization type. Default: int. - 'input_param_method': ['stats'], # How scales/zero-point are determined. Default: stats. - 'input_scale_precision': ['float_scale' - ], # Whether input scale is a float value or a po2. Default: float. - 'input_scale_type': ['static'], # Whether input scale is a static value or a dynamic value. - 'input_quant_type': ['asym'], # Input quantization type. Default: asym. - 'input_quant_granularity': [ - 'per_tensor'], # Granularity for scales/zero-point of inputs. Default: per_tensor. - 'input_group_size': [64], # Group size for per_group input quantization. Default: 64. - 'quantize_input_zero_point': [False], # Quantize input zero-point. - 'quantize_last_layer': [False], # Quantize last nn.Linear layer. - 'gptq': [False], # Apply GPTQ. - 'gpfq': [False], # Apply GPFQ. - 'gpxq_act_order': [False], # Apply GPxQ activation ordering. - 'gpxq_use_quant_activations': [False], # Use quantized activations in GPxQ. - 'gpxq_create_weight_orig': [False], # Create weight_orig in GPxQ. - 'gpxq_max_accumulator_bit_width': [None], # Maximum accumulator bit width for GPxQ using AXE. - 'gpxq_max_accumulator_tile_size': [None], # Maximum accumulator tile size for GPxQ using AXE. - 'act_calibration': [False], # Apply activation calibration. - 'bias_corr': [False], # Apply bias correction. - 'ln_affine_merge': [False], # Merge LN affine params. - 'no_quantize': [False], # Disable quantization. - 'no_float16': [False], # Disable float16 as base datatype and switch to float32. - 'replace_mha': [False], # Replace HuggingFace Attention with a quantizable version - 'weight_equalization': [ - False], # Apply weight equalization. Relevant to ReLU based models (e.g. OPT). - 'act_equalization': [None], # Apply activation equalization (SmoothQuant). - 'load_awq': [None], # Load the awq search results. - 'export_target': [None], # Model export. - 'export_prefix': [None], # Path prefix to use for the various export flows. - 'checkpoint_name': [None], # Filename to save checkpoint. - 'fuse_sequences': [False], # Whether to merge the dataset sequences. - 'learned_round': [None, - "linear_round"], # Whether to use learned round. If `None`, RTN is used. -} - -parser = argparse.ArgumentParser(description='PyTorch LLM PTQ Validation') -parser.add_argument('idx', type=int) -for option_name, option_value in OPTIONS_DEFAULT.items(): - if isinstance(option_value[0], bool): - type_args = parse_bool - else: - type_args = partial(parse_type, default_type=type(option_value[0])) - parser.add_argument(f'--{option_name}', default=option_value, nargs="+", type=type_args) - - -def main(): - args = parser.parse_args() - - # Generate all possible configurations, including invalid ones - options = {k: getattr(args, k) for k, _ in OPTIONS_DEFAULT.items()} - combinations = list(product(*options.values())) - configs = [] - for combination in combinations: - config_namespace = SimpleNamespace( - **{k: v for k, v in zip(OPTIONS_DEFAULT.keys(), combination)}) - try: - validate(config_namespace) - configs.append(hashabledict(**config_namespace.__dict__)) - except AssertionError: - # Invalid configuration - pass - - configs = unique(configs) - - if args.idx > len(configs) - 1: - return - - config_namespace = SimpleNamespace(**configs[args.idx]) - args.gpu = get_gpu_index(args.idx) - print("Iter {}, GPU {}".format(args.idx, args.gpu)) - - try: - float_ppl, quant_ppl, _ = main_llm(config_namespace) - - # Results are saved in CSV - column_names = [k.replace('_', ' ').capitalize() for k in config_namespace.__dict__.keys() - ] + [ - 'FP perplexity', 'Quant perplexity', 'Torch version', 'Brevitas version'] - values = [v for _, v in config_namespace.__dict__.items()] + [ - float_ppl, quant_ppl, torch_version, brevitas_version] - llm_df = pd.DataFrame([values], columns=column_names) - - folder = './multirun/' + str(args.idx) - os.makedirs(folder, exist_ok=True) - llm_df.to_csv(os.path.join(folder, 'RESULTS_LLM.csv'), index=False) - - except Exception as E: - print("Exception at index {}: {}".format(args.idx, E)) - - -if __name__ == '__main__': - main() diff --git a/src/brevitas_examples/llm/benchmark/post_processing.py b/src/brevitas_examples/llm/benchmark/post_processing.py deleted file mode 100644 index ab33b15dd..000000000 --- a/src/brevitas_examples/llm/benchmark/post_processing.py +++ /dev/null @@ -1,32 +0,0 @@ -import os - -import pandas as pd - - -def main(): - main_dir = './multirun' - - evals = next(os.walk(main_dir))[1] - df = None - for eval in evals: - full_path = os.path.join(main_dir, eval, 'RESULTS_LLM.csv') - if not os.path.exists(full_path): - continue - if df is None: - df = pd.read_csv(full_path) - else: - single_df = pd.read_csv(full_path) - df = pd.concat([df, single_df]) - df = df.sort_values(by=list(df.columns)) - df.to_csv('RESULTS_LLM.csv', index=False, mode='w') - - grouped_df = df.groupby([ - 'Model', 'Weight bit width', 'Weight quant granularity', 'Learned round']) - idx = grouped_df['Quant perplexity'].transform(max) == df['Quant perplexity'] - best_config_df = df[idx] - best_config_df = best_config_df.sort_values(by=['Model', 'Quant perplexity']) - best_config_df.to_csv('RESULTS_LLM_BEST_CONFIGS.csv', index=False, mode='w') - - -if __name__ == '__main__': - main() diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index 099b62a2e..6dad361ab 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -1,6 +1,7 @@ # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +import functools from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from accelerate.utils.operations import send_to_device @@ -11,6 +12,8 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.models.opt.modeling_opt import OPTDecoderLayer +from brevitas.utils.python_utils import recurse_getattr +from brevitas_examples.common.learned_round.learned_round_optimizer import Cache from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer from brevitas_examples.common.learned_round.learned_round_parser import parse_learned_round from brevitas_examples.common.learned_round.learned_round_parser import \ @@ -19,16 +22,14 @@ from brevitas_examples.common.learned_round.learned_round_parser import parse_optimizer_class -class CacheLLM(dict): +class CacheLLM(Cache, dict): def __init__(self) -> None: super().__init__() - self.store_kwargs = True def store_inputs(self, args, kwargs) -> None: self["args"].append(args) - if self.store_kwargs: - self["kwargs"].append(kwargs) + self["kwargs"].append(kwargs) def store_output(self, output) -> None: if isinstance(output, (tuple, list)): @@ -41,17 +42,9 @@ def initialize_cache(self) -> None: self["output"] = [] def clear_cache(self) -> None: - del self["args"] - del self["output"] - self["args"] = [] - self["output"] = [] - self.store_kwargs = len(self["kwargs"]) == 0 - - def reset_cache(self) -> None: del self["args"] del self["kwargs"] del self["output"] - self.store_kwargs = True self["args"] = [] self["kwargs"] = [] self["output"] = [] @@ -141,8 +134,8 @@ def llm_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor: return out -def llm_block_check_fn(module: nn.Module, module_name: str) -> bool: - return isinstance(module, LlamaDecoderLayer) or isinstance(module, OPTDecoderLayer) +def get_blocks(model: nn.Module, block_name_attribute: str) -> List[nn.Module]: + return recurse_getattr(model, block_name_attribute) def apply_learned_round( @@ -151,8 +144,8 @@ def apply_learned_round( iters: int = 200, learned_round: str = "linear_round", learned_round_loss: str = "mse", + block_name_attribute: str = "layers", optimizer: str = "sign_sgd", - lr_scheduler: Optional[str] = "linear", optimizer_lr: float = 5e-3, batch_size: int = 8, learn_scale: bool = False, @@ -160,6 +153,7 @@ def apply_learned_round( use_amp: bool = True, amp_dtype: torch.dtype = torch.float16, loss_scaling_factor: float = 1000, + lr_scheduler: Optional[str] = "linear", optimizer_kwargs: Optional[Dict] = None, lr_scheduler_kwargs: Optional[Dict] = None, learned_round_loss_kwargs: Optional[Dict] = None, @@ -170,6 +164,8 @@ def apply_learned_round( optimizer_class = parse_optimizer_class(optimizer) lr_scheduler_class = parse_lr_scheduler_class(lr_scheduler) + llm_block_check_fn = functools.partial(get_blocks, block_name_attribute=block_name_attribute) + lr_scheduler_kwargs = { "start_factor": 1.0, "end_factor": 0.0, @@ -197,7 +193,7 @@ def apply_learned_round( block_forward=llm_block_forward, data_loader=calibration_loader, cache=cache, - block_check_fn=llm_block_check_fn, + get_blocks_fn=llm_block_check_fn, model_prepare_fn=llm_learned_round_prepare_fn, model_finish_fn=llm_learned_round_finish_fn, keep_gpu=False, diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 574d1280d..ec15fe1f2 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -375,6 +375,7 @@ def main(args): model, calibration_loader, iters=args.learned_round_iters, + block_name_attribute=args.gpxq_block_name, learn_scale=args.learned_round_scale, ) print("Learned round applied.")