
Commit: review
Giuseppe5 committed Dec 4, 2024
1 parent f0fc191 commit fbb9109
Showing 3 changed files with 36 additions and 41 deletions.
@@ -587,7 +587,7 @@ def apply_learned_round(
model_prepare_fn: Optional[Callable] = None,
model_finish_fn: Optional[Callable] = None,
keep_gpu: bool = True,
-partial_update: bool = False) -> None:
+fast_update: bool = False) -> None:

# Perform any needed preprocessing before rounding optimisation, e.g. disabling caching in LLMs
model_dict = None if model_prepare_fn is None else model_prepare_fn(model)
@@ -607,8 +607,8 @@ def apply_learned_round(
# Iterate over blocks and optimise the rounding parameters within each of them
for block_idx, block in enumerate(blocks):
# Distribute the model across devices to run a forward pass to capture
-# inputs/outputs to the given block#
-if block_idx == 0 and partial_update:
+# inputs/outputs to the given block
+if block_idx == 0 or not fast_update:
cache.clear_cache()
model = offload_model(model)
# Cache needs to be cleared before populating it with the inputs and outputs
@@ -681,7 +681,7 @@ def apply_learned_round(
# Move the block back to CPU
block.cpu()

-if block_idx + 1 < len(blocks) and partial_update:
+if block_idx + 1 < len(blocks) and fast_update:
cache, floating_point_datasets = self.skip_full_execution(block, blocks[block_idx+1], floating_point_datasets, block_forward, cache)

# The original configuration of the model is restored after finishing the optimization
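To summarize the behaviour these hunks introduce: with `fast_update` enabled, the cache-clearing and full offloaded forward pass are only performed for the first block, and every later block is fed from the previous block's outputs via `skip_full_execution`. A self-contained toy sketch of that idea (stand-in `nn.Linear` blocks and random data, not the Brevitas implementation):

    # Toy illustration of the fast-update idea: only the first block's inputs come
    # from a real calibration pass; each later block is fed the previous block's
    # outputs, so no further end-to-end forward passes are needed.
    import torch
    import torch.nn as nn

    blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
    calibration_batches = [torch.randn(8, 16) for _ in range(3)]

    cache = list(calibration_batches)  # inputs to block 0
    for block in blocks:
        # ... the rounding parameters of `block` would be optimised on `cache` here ...
        with torch.no_grad():
            outputs = [block(x) for x in cache]
        cache = outputs  # cached inputs for the next block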
@@ -707,32 +707,21 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_
(args, kwargs), _ = cache.sample_batch([i])
floating_point_datasets.append((args, kwargs))

-# Then, we compute the floating point output of the current block
-next_float_input = []
-block.cuda()
-pbar = tqdm(floating_point_datasets, desc='', leave=False)
-with torch.no_grad():
-for args, kwargs in pbar:
-out = block_forward(block, (args, kwargs))
-out = send_to_device(out, 'cpu')
-next_float_input.append((out,))
-pbar.close()
-block.cpu()
# We use this new output to generate a new temporary dataloader for the next block
# and to update our floating_point_dataset
new_data_loader = []
for i in range(len(cache)):
-(args, kwargs), _ = cache.sample_batch([i])
-new_data_loader.append((next_float_input[i], kwargs))
+(args, kwargs), output = cache.sample_batch([i])
+new_data_loader.append(((output,), kwargs))

-_, fp_dataset_kwargs = floating_point_datasets[i]
-floating_point_datasets[i] = (next_float_input[i], fp_dataset_kwargs)
+floating_point_datasets[i] = ((output,), kwargs)

# Temporary cache
tmp_cache = type(cache)()

# We compute the floating point output of the upcoming block
-next_block.cuda()
+if torch.cuda.is_available():
+    next_block.cuda()
save_inputs_output(
next_block,
block_forward,
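Two details worth noting in this hunk: the temporary dataloader for the next block is now built directly from outputs already stored in the cache (instead of recomputing a floating-point forward pass over the current block), and the `next_block.cuda()` call is now guarded by `torch.cuda.is_available()` so the path also works on CPU-only machines. A rough, self-contained sketch of the cache-reuse step (toy cache structure, not the actual cache class):

    import torch

    # Toy stand-in for the cache: ((args, kwargs), output) entries, as captured
    # from a block's forward pass.
    toy_cache = [(((torch.randn(2, 16),), {"attention_mask": None}), torch.randn(2, 16))
                 for _ in range(3)]

    new_data_loader = []
    floating_point_datasets = []
    for (args, kwargs), output in toy_cache:
        # The stored output of the current block becomes the positional input of
        # the next block; kwargs (e.g. attention masks) are carried over as-is.
        new_data_loader.append(((output,), kwargs))
        floating_point_datasets.append(((output,), kwargs))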
40 changes: 20 additions & 20 deletions src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -110,25 +110,25 @@ def get_blocks(model: nn.Module, block_name_attribute: str) -> List[nn.Module]:


def apply_learned_round(
-model: nn.Module,
-calibration_loader: DataLoader,
-iters: int = 200,
-learned_round: str = "linear_round",
-learned_round_loss: str = "mse",
-block_name_attribute: str = "layers",
-optimizer: str = "sign_sgd",
-batch_size: int = 8,
-learn_scale: bool = False,
-use_best_model: bool = True,
-amp_dtype: torch.dtype = torch.float16,
-loss_scaling_factor: float = 1000,
-lr_scheduler: Optional[str] = "linear",
-optimizer_kwargs: Optional[Dict] = None,
-lr_scheduler_kwargs: Optional[Dict] = None,
-learned_round_loss_kwargs: Optional[Dict] = None,
-scale_optimizer_class: Optional[str] = None,
-scale_optimizer_kwargs: Optional[Dict] = None,
-) -> None:
+model: nn.Module,
+calibration_loader: DataLoader,
+iters: int = 200,
+learned_round: str = "linear_round",
+learned_round_loss: str = "mse",
+block_name_attribute: str = "layers",
+optimizer: str = "sign_sgd",
+batch_size: int = 8,
+learn_scale: bool = False,
+use_best_model: bool = True,
+amp_dtype: torch.dtype = torch.float16,
+loss_scaling_factor: float = 1000,
+lr_scheduler: Optional[str] = "linear",
+optimizer_kwargs: Optional[Dict] = None,
+lr_scheduler_kwargs: Optional[Dict] = None,
+learned_round_loss_kwargs: Optional[Dict] = None,
+scale_optimizer_class: Optional[str] = None,
+scale_optimizer_kwargs: Optional[Dict] = None,
+fast_update: bool = False) -> None:
# Parse strings to obtain the arguments for the optimizer
learned_round = parse_learned_round(learned_round)
learned_round_loss_class = parse_learned_round_loss_class(learned_round_loss)
@@ -169,4 +169,4 @@ def apply_learned_round(
model_prepare_fn=llm_learned_round_prepare_fn,
model_finish_fn=llm_learned_round_finish_fn,
keep_gpu=False,
-partial_update=True)
+partial_update=fast_update)
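With the pass-through above, the LLM entry point accepts the new keyword; a hedged usage sketch (assumes the Brevitas examples package is importable and that `model` and `calibration_loader` have already been prepared as elsewhere in `main.py`):

    from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round

    # Argument values are illustrative; only `fast_update` is new in this commit.
    apply_learned_round(
        model,
        calibration_loader,
        iters=200,
        learned_round="linear_round",
        optimizer="sign_sgd",
        batch_size=8,
        fast_update=True)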
8 changes: 7 additions & 1 deletion src/brevitas_examples/llm/main.py
@@ -385,7 +385,8 @@ def main(args):
scale_optimizer_class='sgd',
optimizer_kwargs={'lr': args.learned_round_lr},
scale_optimizer_kwargs={
-'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum})
+'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum},
+fast_update=args.learned_round_fast_update)
print("Learned round applied.")

model = offload_model(model)
@@ -705,6 +706,11 @@ def parse_args(args):
default=None,
choices=[None, 'linear_round'],
help='Whether to use learned round. If `None`, RTN is used (default: %(default)s)')
+parser.add_argument(
+'--learned-round-fast-update',
+default=False,
+type=bool,
+help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
return parser.parse_args(args)
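One reviewer-style note on the new argument: `type=bool` in argparse converts any non-empty string to `True`, so passing `--learned-round-fast-update False` on the command line would still enable the flag. A small self-contained sketch of the more conventional `store_true` pattern (a suggestion only, not what the diff does):

    import argparse

    parser = argparse.ArgumentParser()
    # store_true gives a plain on/off switch: absent -> False, present -> True.
    parser.add_argument(
        '--learned-round-fast-update',
        action='store_true',
        help='Whether to use fast update with learned round. Prototype (default: %(default)s)')

    print(parser.parse_args([]).learned_round_fast_update)  # False
    print(parser.parse_args(['--learned-round-fast-update']).learned_round_fast_update)  # True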


