Feat (llm/learned_round): fast block update (#1110)
Giuseppe5 authored Dec 5, 2024
1 parent 2caee25 commit 72b7f66
Showing 5 changed files with 105 additions and 40 deletions.
@@ -586,7 +586,8 @@ def apply_learned_round(
get_blocks_fn: Callable,
model_prepare_fn: Optional[Callable] = None,
model_finish_fn: Optional[Callable] = None,
keep_gpu: bool = True) -> None:
keep_gpu: bool = True,
fast_update: bool = False) -> None:

# Perform any needed preprocessing before rounding optimisation, e.g. disabling caching in LLMs
model_dict = None if model_prepare_fn is None else model_prepare_fn(model)
@@ -602,26 +603,27 @@ def apply_learned_round(

# Initialize cache to store partial inputs and outputs for each block
cache.initialize_cache()

# Iterate over blocks and optimise the rounding parameters within each of them
for block_idx, block in enumerate(blocks):
# Distribute the model across devices to run a forward pass to capture
# inputs/outputs to the given block
model = offload_model(model)
# Cache needs to be cleared before populating it with the inputs and outputs
# to the block under optimization.
self._populate_cache(
cache,
model,
model_forward,
block,
data_loader,
keep_gpu=keep_gpu,
capture_quant_input=True,
capture_quant_output=False,
)
# Remove hooks needed to offload the model blocks to cpu
remove_hooks(model)
if block_idx == 0 or not fast_update:
cache.clear_cache()
model = offload_model(model)
# Cache needs to be cleared before populating it with the inputs and outputs
# to the block under optimization.
self._populate_cache(
cache,
model,
model_forward,
block,
data_loader,
keep_gpu=keep_gpu,
capture_quant_input=True,
capture_quant_output=False,
)
# Remove hooks needed to offload the model blocks to cpu
remove_hooks(model)

# Retrieve scales
scale_params = return_scale_parameters(block)
@@ -678,9 +680,60 @@ def apply_learned_round(
# Move the block back to CPU
block.cpu()

# Reset cache after optimisation
cache.clear_cache()
if block_idx + 1 < len(blocks) and fast_update:
cache = self.skip_full_execution(block, blocks[block_idx + 1], block_forward, cache)

# The original configuration of the model is restored after finishing the optimization
if model_finish_fn is not None:
model_finish_fn(model, model_dict)

def skip_full_execution(self, block, next_block, block_forward, cache):

# We need to compute two sets of activations: a floating-point one to compute the float output of the next block,
# and a quantized one to create the quantized input of the next block

# We use the cache output to generate a new temporary dataloader for the next block
tmp_data_loader = []
for i in range(len(cache)):
(args, kwargs), output = cache.sample_batch([i])

tmp_data_loader.append(((output,), kwargs))

# Temporary cache
tmp_cache = type(cache)()

# We compute the floating point output of the upcoming block
if torch.cuda.is_available():
next_block.cuda()
save_inputs_output(
next_block,
block_forward,
next_block,
tmp_data_loader,
tmp_cache,
store_inputs=False,
store_output=True,
keep_gpu=False,
disable_quant=True,
)
next_block.cpu()

cache['output'] = tmp_cache['output']

# Finally (!), we compute the quantized input of the next block
block.eval()
if torch.cuda.is_available():
block.cuda()
next_quant_input = []
pbar = tqdm(range(len(cache)), desc='', leave=False)
with torch.no_grad():
for i in pbar:
(args, kwargs), _ = cache.sample_batch([i])
out = block_forward(block, (args, kwargs))
out = send_to_device(out, 'cpu')
next_quant_input.append((out,))
pbar.close()
cache['args'] = next_quant_input
block.cpu()

return cache
@@ -69,6 +69,7 @@ class CacheVision(Cache, dict):
def __init__(self) -> None:
super().__init__()
self.batch_dim = 0
self.initialize_cache()

def store_inputs(self, args, kwargs) -> None:
input_batch = args[0]
4 changes: 4 additions & 0 deletions src/brevitas_examples/llm/README.md
@@ -54,6 +54,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--export-prefix EXPORT_PREFIX]
[--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
[--learned-round {None,linear_round}]
[--learned-round-fast-update]

options:
-h, --help show this help message and exit
@@ -196,5 +197,8 @@ options:
--learned-round {None,linear_round}
Whether to use learned round. If `None`, RTN is used
(default: None)
--learned-round-fast-update
Whether to use fast update with learned round.
Prototype (default: False)

```
41 changes: 21 additions & 20 deletions src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -23,6 +23,7 @@ class CacheLLM(Cache, dict):

def __init__(self) -> None:
super().__init__()
self.initialize_cache()

def store_inputs(self, args, kwargs) -> None:
self["args"].append(args)
@@ -107,25 +108,25 @@ def get_blocks(model: nn.Module, block_name_attribute: str) -> List[nn.Module]:


def apply_learned_round(
model: nn.Module,
calibration_loader: DataLoader,
iters: int = 200,
learned_round: str = "linear_round",
learned_round_loss: str = "mse",
block_name_attribute: str = "layers",
optimizer: str = "sign_sgd",
batch_size: int = 8,
learn_scale: bool = False,
use_best_model: bool = True,
amp_dtype: torch.dtype = torch.float16,
loss_scaling_factor: float = 1000,
lr_scheduler: Optional[str] = "linear",
optimizer_kwargs: Optional[Dict] = None,
lr_scheduler_kwargs: Optional[Dict] = None,
learned_round_loss_kwargs: Optional[Dict] = None,
scale_optimizer_class: Optional[str] = None,
scale_optimizer_kwargs: Optional[Dict] = None,
) -> None:
model: nn.Module,
calibration_loader: DataLoader,
iters: int = 200,
learned_round: str = "linear_round",
learned_round_loss: str = "mse",
block_name_attribute: str = "layers",
optimizer: str = "sign_sgd",
batch_size: int = 8,
learn_scale: bool = False,
use_best_model: bool = True,
amp_dtype: torch.dtype = torch.float16,
loss_scaling_factor: float = 1000,
lr_scheduler: Optional[str] = "linear",
optimizer_kwargs: Optional[Dict] = None,
lr_scheduler_kwargs: Optional[Dict] = None,
learned_round_loss_kwargs: Optional[Dict] = None,
scale_optimizer_class: Optional[str] = None,
scale_optimizer_kwargs: Optional[Dict] = None,
fast_update: bool = False) -> None:
# Parse strings to obtain the arguments for the optimizer
learned_round = parse_learned_round(learned_round)
learned_round_loss_class = parse_learned_round_loss_class(learned_round_loss)
@@ -166,4 +167,4 @@ def apply_learned_round(
model_prepare_fn=llm_learned_round_prepare_fn,
model_finish_fn=llm_learned_round_finish_fn,
keep_gpu=False,
)
fast_update=fast_update)
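
For context, a minimal sketch of how the new argument might be used from Python, assuming a `model` and `calibration_loader` have already been prepared elsewhere. The call mirrors the `apply_learned_round` signature shown above, and `fast_update` is still a prototype option:

```python
from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round

# Sketch only: `model` and `calibration_loader` are assumed to be defined elsewhere.
apply_learned_round(
    model,
    calibration_loader,
    learned_round="linear_round",
    fast_update=True,  # reuse each optimized block's outputs instead of re-running the full model
)
```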
8 changes: 7 additions & 1 deletion src/brevitas_examples/llm/main.py
@@ -385,7 +385,8 @@ def main(args):
scale_optimizer_class='sgd',
optimizer_kwargs={'lr': args.learned_round_lr},
scale_optimizer_kwargs={
'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum})
'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum},
fast_update=args.learned_round_fast_update)
print("Learned round applied.")

model = offload_model(model)
@@ -705,6 +706,11 @@ def parse_args(args):
default=None,
choices=[None, 'linear_round'],
help='Whether to use learned round. If `None`, RTN is used (default: %(default)s)')
parser.add_argument(
'--learned-round-fast-update',
default=False,
action="store_true",
help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
return parser.parse_args(args)
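
A similarly hedged sketch of driving the feature through this CLI entry point. `parse_args` and `main` come from this file; the model identifier below is only a placeholder:

```python
# Sketch only: the model name below is a placeholder, not taken from this commit.
from brevitas_examples.llm.main import main, parse_args

args = parse_args([
    "--model", "facebook/opt-125m",
    "--learned-round", "linear_round",
    "--learned-round-fast-update",
])
assert args.learned_round_fast_update  # store_true flag, defaults to False
main(args)
```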

