Address comments
pablomlago committed Nov 27, 2024
1 parent f1a188b commit a9e9b52
Showing 7 changed files with 22 additions and 241 deletions.
12 changes: 3 additions & 9 deletions src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -203,7 +203,6 @@
from torch.optim.optimizer import Optimizer
from torch.optim.sgd import SGD
from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataloader import RandomSampler
from tqdm import tqdm

from brevitas import config
@@ -285,10 +284,6 @@ def initialize_cache(self) -> None:
def clear_cache(self) -> None:
pass

-@abstractmethod
-def reset_cache(self) -> None:
-pass
-
@abstractmethod
def cache_to_dataset(self) -> Dataset:
pass
@@ -699,7 +694,7 @@ def apply_learned_round(
block_forward: Callable,
data_loader: DataLoader,
cache: Cache,
-block_check_fn: Callable,
+get_blocks_fn: Callable,
model_prepare_fn: Optional[Callable] = None,
model_finish_fn: Optional[Callable] = None,
keep_gpu: bool = True) -> None:
@@ -711,7 +706,7 @@ def apply_learned_round(
self.learned_round.insert_learned_round_quantizers(model)

# Retrieve blocks using the appropiate function to check blocks
-blocks = get_blocks(model, block_check_fn)
+blocks = get_blocks_fn(model)

print(f"Total Iterations per block {self.iters}")
print(f"Number of blocks {len(blocks)}")
@@ -726,7 +721,6 @@ def apply_learned_round(
model = offload_model(model)
# Cache needs to be cleared before populating it with the inputs and outputs
# to the block under optimization.
-cache.clear_cache()
self._populate_cache(
cache,
model,
@@ -801,7 +795,7 @@ def apply_learned_round(

# TODO: This call might not be needed, check_clear and reset_cache methods
# Reset cache after optimisation
-cache.reset_cache()
+cache.clear_cache()

# The original configuration of the model is restored after finishing the optimization
if model_finish_fn is not None:
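
Note on the refactor above: apply_learned_round no longer takes a block_check_fn predicate and calls get_blocks internally; the caller now passes a ready-made get_blocks_fn callable, and the separate reset_cache hook is folded into clear_cache. A minimal sketch of how a caller can adapt a predicate to the new contract (the predicate and traversal below are illustrative stand-ins, not the Brevitas implementation):

import functools
from typing import Callable, List
from torch import nn

def get_blocks(model: nn.Module, block_check_fn: Callable[[nn.Module, str], bool]) -> List[nn.Module]:
    # Illustrative traversal: collect every module accepted by the predicate.
    blocks = []
    def _visit(module: nn.Module, name: str) -> None:
        if block_check_fn(module, name):
            blocks.append(module)
        else:
            for child_name, child in module.named_children():
                _visit(child, f"{name}.{child_name}" if name else child_name)
    _visit(model, "")
    return blocks

def is_block(module: nn.Module, module_name: str) -> bool:
    # Hypothetical predicate: treat every nn.Sequential as an optimizable block.
    return isinstance(module, nn.Sequential)

# Old API: apply_learned_round(..., block_check_fn=is_block, ...)
# New API: bind the predicate once and pass a single-argument callable instead.
get_blocks_fn = functools.partial(get_blocks, block_check_fn=is_block)
# blocks = get_blocks_fn(model)  # equivalent to the former get_blocks(model, is_block)
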
1 change: 0 additions & 1 deletion src/brevitas_examples/common/learned_round/learned_round_parser.py
@@ -80,7 +80,6 @@ def parse_lr_scheduler_class(lr_scheduler_str: str) -> Type[LRScheduler]:
torch.optim.lr_scheduler.__dict__[lr_scheduler_key] != LRScheduler and
isinstance(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], type) and
issubclass(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], LRScheduler))]
-print(lr_scheduler_keys)
if len(lr_scheduler_keys) == 0:
warnings.warn(
f"There are no matches for LR scheduler {lr_scheduler_str}. "
14 changes: 6 additions & 8 deletions src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
@@ -26,6 +26,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

+import functools
import re
from typing import Any, Callable, Dict, Optional, Tuple, Union
import warnings
@@ -39,6 +40,8 @@
from brevitas import config
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
from brevitas.quant_tensor import QuantTensor
+from brevitas_examples.common.learned_round.learned_round_optimizer import Cache
+from brevitas_examples.common.learned_round.learned_round_optimizer import get_blocks
from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer
from brevitas_examples.common.learned_round.learned_round_parser import parse_learned_round
from brevitas_examples.common.learned_round.learned_round_parser import \
@@ -62,7 +65,7 @@ def is_layer(module: nn.Module, module_name: str) -> bool:
"blockwise": is_resnet_block,}


-class CacheVision(dict):
+class CacheVision(Cache, dict):

def __init__(self) -> None:
super().__init__()
@@ -97,12 +100,6 @@ def clear_cache(self) -> None:
self["inputs"] = []
self["output"] = []

-def reset_cache(self) -> None:
-del self["inputs"]
-del self["output"]
-self["inputs"] = []
-self["output"] = []
-
def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]:
if isinstance(self["inputs"], list):
self["inputs"] = torch.cat(self["inputs"], dim=self.batch_dim)
@@ -166,6 +163,7 @@ def apply_learned_round(
warnings.warn(
f"{learned_round_mode} is not a valid learned round mode. Defaulting to layerwise.")
block_check_fn = BLOCK_CHECK_MAP[learned_round_mode]
+get_blocks_fn = functools.partial(get_blocks, block_check_fn=block_check_fn)
lr_scheduler_kwargs = {
"start_factor": 1.0,
"end_factor": 0.0,
@@ -192,6 +190,6 @@ def apply_learned_round(
block_forward=cnn_block_forward,
data_loader=calibration_loader,
cache=cache,
-block_check_fn=block_check_fn,
+get_blocks_fn=get_blocks_fn,
keep_gpu=True,
)
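
With reset_cache removed from the shared interface, CacheVision (and CacheLLM below) only need a single clearing hook and now inherit from the common Cache base class. A rough sketch of the slimmed-down, dict-backed pattern, inferred from the methods visible in this diff (the real Cache ABC lives in learned_round_optimizer.py and declares further hooks such as cache_to_dataset):

from abc import ABC, abstractmethod

class Cache(ABC):
    # Subset of the interface visible in this commit; reset_cache is gone.
    @abstractmethod
    def initialize_cache(self) -> None:
        pass

    @abstractmethod
    def clear_cache(self) -> None:
        pass

class MinimalVisionCache(Cache, dict):
    # Hypothetical stand-in for CacheVision: dict storage plus the Cache hooks.
    def initialize_cache(self) -> None:
        self["inputs"] = []
        self["output"] = []

    def clear_cache(self) -> None:
        del self["inputs"], self["output"]
        self["inputs"], self["output"] = [], []

cache = MinimalVisionCache()
cache.initialize_cache()
cache.clear_cache()  # the optimizer now calls this both before and after each block
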
175 changes: 0 additions & 175 deletions src/brevitas_examples/llm/benchmark/llm_benchmark.py

This file was deleted.

32 changes: 0 additions & 32 deletions src/brevitas_examples/llm/benchmark/post_processing.py

This file was deleted.

28 changes: 12 additions & 16 deletions src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -1,6 +1,7 @@
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

+import functools
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from accelerate.utils.operations import send_to_device
@@ -11,6 +12,8 @@
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.opt.modeling_opt import OPTDecoderLayer

+from brevitas.utils.python_utils import recurse_getattr
+from brevitas_examples.common.learned_round.learned_round_optimizer import Cache
from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer
from brevitas_examples.common.learned_round.learned_round_parser import parse_learned_round
from brevitas_examples.common.learned_round.learned_round_parser import \
@@ -19,16 +22,14 @@
from brevitas_examples.common.learned_round.learned_round_parser import parse_optimizer_class


-class CacheLLM(dict):
+class CacheLLM(Cache, dict):

def __init__(self) -> None:
super().__init__()
-self.store_kwargs = True

def store_inputs(self, args, kwargs) -> None:
self["args"].append(args)
-if self.store_kwargs:
-    self["kwargs"].append(kwargs)
+self["kwargs"].append(kwargs)

def store_output(self, output) -> None:
if isinstance(output, (tuple, list)):
@@ -41,17 +42,9 @@ def initialize_cache(self) -> None:
self["output"] = []

def clear_cache(self) -> None:
del self["args"]
del self["output"]
self["args"] = []
self["output"] = []
self.store_kwargs = len(self["kwargs"]) == 0

def reset_cache(self) -> None:
del self["args"]
del self["kwargs"]
del self["output"]
-self.store_kwargs = True
self["args"] = []
self["kwargs"] = []
self["output"] = []
@@ -141,8 +134,8 @@ def llm_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor:
return out


-def llm_block_check_fn(module: nn.Module, module_name: str) -> bool:
-return isinstance(module, LlamaDecoderLayer) or isinstance(module, OPTDecoderLayer)
+def get_blocks(model: nn.Module, block_name_attribute: str) -> List[nn.Module]:
+return recurse_getattr(model, block_name_attribute)


def apply_learned_round(
@@ -151,15 +144,16 @@ def apply_learned_round(
iters: int = 200,
learned_round: str = "linear_round",
learned_round_loss: str = "mse",
block_name_attribute: str = "layers",
optimizer: str = "sign_sgd",
lr_scheduler: Optional[str] = "linear",
optimizer_lr: float = 5e-3,
batch_size: int = 8,
learn_scale: bool = False,
use_best_model: bool = True,
use_amp: bool = True,
amp_dtype: torch.dtype = torch.float16,
loss_scaling_factor: float = 1000,
lr_scheduler: Optional[str] = "linear",
optimizer_kwargs: Optional[Dict] = None,
lr_scheduler_kwargs: Optional[Dict] = None,
learned_round_loss_kwargs: Optional[Dict] = None,
@@ -170,6 +164,8 @@ def apply_learned_round(
optimizer_class = parse_optimizer_class(optimizer)
lr_scheduler_class = parse_lr_scheduler_class(lr_scheduler)

+llm_block_check_fn = functools.partial(get_blocks, block_name_attribute=block_name_attribute)
+
lr_scheduler_kwargs = {
"start_factor": 1.0,
"end_factor": 0.0,
@@ -197,7 +193,7 @@ def apply_learned_round(
block_forward=llm_block_forward,
data_loader=calibration_loader,
cache=cache,
-block_check_fn=llm_block_check_fn,
+get_blocks_fn=llm_block_check_fn,
model_prepare_fn=llm_learned_round_prepare_fn,
model_finish_fn=llm_learned_round_finish_fn,
keep_gpu=False,
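
The LLM flow above now selects blocks by attribute path (block_name_attribute, e.g. "layers") instead of isinstance checks against LlamaDecoderLayer/OPTDecoderLayer. A small sketch of the idea with a local stand-in for recurse_getattr (the real helper comes from brevitas.utils.python_utils and may differ in details):

import functools
from torch import nn

def recurse_getattr_sketch(obj, attr_path: str):
    # Resolve a dotted attribute path, e.g. "model.decoder.layers" on an HF wrapper.
    return functools.reduce(getattr, attr_path.split("."), obj)

class TinyDecoder(nn.Module):
    # Hypothetical model whose transformer blocks live under a "layers" attribute.
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])

model = TinyDecoder()
get_blocks_fn = functools.partial(recurse_getattr_sketch, attr_path="layers")
blocks = get_blocks_fn(model)  # the two decoder blocks, as a ModuleList
assert len(blocks) == 2
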
1 change: 1 addition & 0 deletions src/brevitas_examples/llm/main.py
@@ -375,6 +375,7 @@ def main(args):
model,
calibration_loader,
iters=args.learned_round_iters,
+block_name_attribute=args.gpxq_block_name,
learn_scale=args.learned_round_scale,
)
print("Learned round applied.")
