Feat (llm/learned_round): fast block update #1110

Merged · 12 commits · Dec 5, 2024
src/brevitas_examples/common/learned_round/learned_round_optimizer.py (120 changes: 100 additions & 20 deletions)

@@ -586,7 +586,8 @@ def apply_learned_round(
             get_blocks_fn: Callable,
             model_prepare_fn: Optional[Callable] = None,
             model_finish_fn: Optional[Callable] = None,
-            keep_gpu: bool = True) -> None:
+            keep_gpu: bool = True,
+            partial_update: bool = False) -> None:
 
         # Perform any needed preprocessing before rounding optimisation, e.g. disabling caching in LLMs
         model_dict = None if model_prepare_fn is None else model_prepare_fn(model)
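For orientation, here is a self-contained toy of the control flow the new partial_update flag enables. Plain nn.Linear blocks and a Python list stand in for Brevitas' model blocks and Cache object, and the rounding optimisation itself is elided; this is a sketch of the pattern, not the library's API.

import torch
import torch.nn as nn

# Toy stand-ins (hypothetical): three "blocks", a tiny calibration set and a
# list acting as the cache of inputs to the block under optimisation.
blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
data = [torch.randn(4, 8) for _ in range(2)]
cache = []
partial_update = True

for block_idx, block in enumerate(blocks):
    if block_idx == 0 or not partial_update:
        # Expensive path: normally a full forward pass over the whole model
        # repopulates the cache with this block's inputs (_populate_cache).
        cache = [x.clone() for x in data]
    # ... rounding optimisation of `block` against `cache` happens here ...
    if partial_update and block_idx + 1 < len(blocks):
        # Cheap path: push the cached inputs through the current block only,
        # yielding the next block's inputs without another full forward pass.
        with torch.no_grad():
            cache = [block(x) for x in cache]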
@@ -602,26 +603,28 @@ def apply_learned_round(
         # Initialize cache to store partial inputs and outputs for each block
         cache.initialize_cache()
 
+        floating_point_datasets = []
+
         # Iterate over blocks and optimise the rounding parameters within each of them
         for block_idx, block in enumerate(blocks):
             # Distribute the model across devices to run a forward pass to capture
             # inputs/outputs to the given block
-            model = offload_model(model)
-            # Cache needs to be cleared before populating it with the inputs and outputs
-            # to the block under optimization.
-            self._populate_cache(
-                cache,
-                model,
-                model_forward,
-                block,
-                data_loader,
-                keep_gpu=keep_gpu,
-                capture_quant_input=True,
-                capture_quant_output=False,
-            )
-            # Remove hooks needed to offload the model blocks to cpu
-            remove_hooks(model)
+            if block_idx == 0 or not partial_update:
+                cache.clear_cache()
+                model = offload_model(model)
+                # Cache needs to be cleared before populating it with the inputs and outputs
+                # to the block under optimization.
+                self._populate_cache(
+                    cache,
+                    model,
+                    model_forward,
+                    block,
+                    data_loader,
+                    keep_gpu=keep_gpu,
+                    capture_quant_input=True,
+                    capture_quant_output=False,
+                )
+                # Remove hooks needed to offload the model blocks to cpu
+                remove_hooks(model)
 
             # Retrieve scales
             scale_params = return_scale_parameters(block)

[Reviewer comment on the added floating_point_datasets line] floating_point_datasets is no longer used after the changes, right?
@@ -678,9 +681,86 @@ def apply_learned_round(
             # Move the block back to CPU
             block.cpu()
 
-            # Reset cache after optimisation
-            cache.clear_cache()
+            if block_idx + 1 < len(blocks) and partial_update:
+                cache, floating_point_datasets = self.skip_full_execution(block, blocks[block_idx + 1], floating_point_datasets, block_forward, cache)
 
         # The original configuration of the model is restored after finishing the optimization
         if model_finish_fn is not None:
             model_finish_fn(model, model_dict)
+
+    def skip_full_execution(self, block, next_block, floating_point_datasets, block_forward, cache):
+
+        # We need to propagate two datasets: a floating-point dataset to compute the float
+        # outputs, and a quantized dataset to create the quantized input of the next block
+
+        # First, we disable quantization
+        disable_quant_class = DisableEnableQuantization()
+        disable_quant_class.disable_act_quantization(block, False)
+        disable_quant_class.disable_param_quantization(block, False)
+        return_quant_tensor_state = disable_return_quant_tensor(block)
+
+        # If we don't have a floating_point_dataset, we retrieve it from the cache.
+        # The idea is that the cache contains the input to the very first block, and there
+        # is nothing quantized before that. This is a moderately strong assumption
+        if len(floating_point_datasets) <= 0:
+            for i in range(len(cache)):
+                (args, kwargs), _ = cache.sample_batch([i])
+                floating_point_datasets.append((args, kwargs))
+
+        # Then, we compute the floating point output of the current block
+        next_float_input = []
+        with torch.no_grad():
+            for args, kwargs in floating_point_datasets:
+                out = block_forward(block, (args, kwargs))
+                out = send_to_device(out, 'cpu')
+                next_float_input.append((out,))
+
+        # We use this new output to generate a new temporary dataloader for the next block
+        # and to update our floating_point_dataset
+        new_data_loader = []
+        for i in range(len(cache)):
+            (args, kwargs), _ = cache.sample_batch([i])
+            new_data_loader.append((next_float_input[i], kwargs))
+
+            _, fp_dataset_kwargs = floating_point_datasets[i]
+            floating_point_datasets[i] = (next_float_input[i], fp_dataset_kwargs)
+
+        # Temporary cache
+        tmp_cache = copy.deepcopy(cache)
+        tmp_cache.clear_cache()
+
+        # We compute the floating point output of the upcoming block
+        next_block.cuda()
+        save_inputs_output(
+            next_block,
+            block_forward,
+            next_block,
+            new_data_loader,
+            tmp_cache,
+            store_inputs=False,
+            store_output=True,
+            keep_gpu=False,
+            disable_quant=True,
+        )
+        next_block.cpu()
+
+        cache['output'] = tmp_cache['output']
+
+        # Re-enable quantization
+        disable_quant_class.enable_act_quantization(block, False)
+        disable_quant_class.enable_param_quantization(block, False)
+        restore_return_quant_tensor(block, return_quant_tensor_state)
+
+        # Finally (!), we compute the quantized input of the next block
+        block.eval()
+        block.cuda()
+        next_quant_input = []
+        with torch.no_grad():
+            for i in range(len(cache)):
+                (args, kwargs), _ = cache.sample_batch([i])
+                out = block_forward(block, (args, kwargs))
+                out = send_to_device(out, 'cpu')
+                next_quant_input.append((out,))
+        cache['args'] = copy.deepcopy(next_quant_input)
+        block.cpu()
+
+        return cache, floating_point_datasets
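The two-dataset propagation above can be distilled into a minimal, runnable sketch. Everything here is a stand-in assumed for illustration: fake_quant is a crude uniform fake-quantiser in place of Brevitas quantization (which the real code toggles via DisableEnableQuantization), and plain Linear layers replace the model blocks.

import copy

import torch
import torch.nn as nn

def fake_quant(w: torch.Tensor, bits: int = 4) -> torch.Tensor:
    # Crude uniform fake-quantiser, only here to make the sketch runnable.
    scale = w.abs().max() / (2 ** (bits - 1) - 1)
    q = torch.round(w / scale).clamp(-2 ** (bits - 1), 2 ** (bits - 1) - 1)
    return q * scale

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
float_inputs = [torch.randn(4, 8) for _ in range(2)]  # floating-point dataset
quant_inputs = [x.clone() for x in float_inputs]      # quantized dataset

with torch.no_grad():
    for block in blocks:
        # Float side: advance the float dataset through the unquantized block.
        # Its outputs are the next block's float inputs; pushing those through
        # the next unquantized block yields its float targets (cache['output']).
        float_inputs = [block(x) for x in float_inputs]
        # Quant side: advance the quantized dataset through a quantized copy of
        # the block, producing the next block's quantized inputs (cache['args']).
        qblock = copy.deepcopy(block)
        qblock.weight.copy_(fake_quant(qblock.weight))
        quant_inputs = [qblock(x) for x in quant_inputs]

The payoff is that each block update costs two single-block forward passes instead of a forward pass through the whole model, which is presumably where the "fast block update" of the PR title comes from.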
src/brevitas_examples/llm/llm_quant/learned_round_utils.py (2 changes: 1 addition & 1 deletion)

@@ -166,4 +166,4 @@ def apply_learned_round(
         model_prepare_fn=llm_learned_round_prepare_fn,
         model_finish_fn=llm_learned_round_finish_fn,
         keep_gpu=False,
-    )
+        partial_update=True)