
Commit

Update flag and readme
Giuseppe5 committed Dec 5, 2024
1 parent 23de059 · commit 2a1e324
Showing 3 changed files with 7 additions and 14 deletions.
@@ -690,14 +690,8 @@ def apply_learned_round(

     def skip_full_execution(self, block, next_block, floating_point_datasets, block_forward, cache):

-        # We need to propagate two datasets, one is a floating point dataset to compute float out
-        # The second is a quantized dataset to create the quantized input of the next blocks
-
-        # First, we disable quantization
-        disable_quant_class = DisableEnableQuantization()
-        disable_quant_class.disable_act_quantization(block, False)
-        disable_quant_class.disable_param_quantization(block, False)
-        return_quant_tensor_state = disable_return_quant_tensor(block)
+        # We need to compute two inputs, one is a floating point one to compute float out
+        # The second is a quantized one to create the quantized input of the next blocks

         # If we don't have a floating_point_dataset, we retrieve it from the cache
         # The idea is that the cache contains the input to the very first block, and there is nothing
@@ -737,11 +731,6 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_forward, cache):

         cache['output'] = tmp_cache['output']

-        # Re-enable quantization
-        disable_quant_class.enable_act_quantization(block, False)
-        disable_quant_class.enable_param_quantization(block, False)
-        restore_return_quant_tensor(block, return_quant_tensor_state)
-
         # Finally (!), we compute the quantized input of the next block
         block.eval()
         block.cuda()
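The deleted lines bracketed the block's float forward pass: quantization was switched off before computing the floating-point reference output and switched back on afterwards. Below is a minimal sketch of that disable/re-enable pattern as a context manager. The helper itself is hypothetical and not part of this commit; the calls are the ones shown in the diff, while the `brevitas.graph.calibrate` import path is an assumption based on recent Brevitas layouts.

```python
# Hypothetical helper illustrating the bracketing the removed lines did inline.
from contextlib import contextmanager

from brevitas.graph.calibrate import (  # assumed import path
    DisableEnableQuantization, disable_return_quant_tensor, restore_return_quant_tensor)


@contextmanager
def quantization_disabled(block):
    """Run `block` in float mode, restoring its quantization state on exit."""
    toggler = DisableEnableQuantization()
    toggler.disable_act_quantization(block, False)
    toggler.disable_param_quantization(block, False)
    return_quant_tensor_state = disable_return_quant_tensor(block)
    try:
        yield block
    finally:
        toggler.enable_act_quantization(block, False)
        toggler.enable_param_quantization(block, False)
        restore_return_quant_tensor(block, return_quant_tensor_state)
```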
4 changes: 4 additions & 0 deletions src/brevitas_examples/llm/README.md
@@ -54,6 +54,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                [--export-prefix EXPORT_PREFIX]
                [--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
                [--learned-round {None,linear_round}]
+               [--learned-round-fast-update]

 options:
   -h, --help            show this help message and exit
@@ -196,5 +197,8 @@ options:
   --learned-round {None,linear_round}
                         Whether to use learned round. If `None`, RTN is used
                         (default: None)
+  --learned-round-fast-update
+                        Whether to use fast update with learned round.
+                        Prototype (default: False)

 ```
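With this README addition, enabling the prototype fast update from the command line is a matter of passing the new switch alongside a learned-round mode. For example (the model choice is illustrative, not taken from this commit):

```sh
python main.py \
  --model facebook/opt-125m \
  --learned-round linear_round \
  --learned-round-fast-update
```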
2 changes: 1 addition & 1 deletion src/brevitas_examples/llm/main.py
@@ -709,7 +709,7 @@ def parse_args(args):
     parser.add_argument(
         '--learned-round-fast-update',
         default=False,
-        type=bool,
+        action="store_true",
         help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
     return parser.parse_args(args)

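The main.py change fixes a classic argparse pitfall: `type=bool` applies Python's `bool()` to the raw command-line string, and any non-empty string, including `'False'`, is truthy, so the flag could never be turned off once supplied, and it could not be passed bare without an argument. `action="store_true"` turns it into a genuine presence flag. A standalone demonstration (the flag names here are made up for illustration):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--broken-flag', default=False, type=bool)
parser.add_argument('--fixed-flag', default=False, action='store_true')

args = parser.parse_args(['--broken-flag', 'False', '--fixed-flag'])
print(args.broken_flag)  # True: bool('False') is True, the surprising old behavior
print(args.fixed_flag)   # True: set simply by the flag's presence
```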
