diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index e03bfa9f1..3fe864a4d 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -690,14 +690,8 @@ def apply_learned_round(
 
     def skip_full_execution(self, block, next_block, floating_point_datasets, block_forward, cache):
-        # We need to propagate two datasets, one is a floating point dataset to compute float out
-        # The second is a quantized dataset to create the quantized input of the next blocks
-
-        # First, we disable quantization
-        disable_quant_class = DisableEnableQuantization()
-        disable_quant_class.disable_act_quantization(block, False)
-        disable_quant_class.disable_param_quantization(block, False)
-        return_quant_tensor_state = disable_return_quant_tensor(block)
+        # We need to compute two inputs: a floating-point one to compute the float output,
+        # and a quantized one to create the quantized input of the next blocks.
 
         # If we don't have a floating_point_dataset, we retrieve it from the cache
         # The idea is that the cache contains the input to the very first block, and there is nothing
@@ -737,11 +731,6 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_
 
         cache['output'] = tmp_cache['output']
 
-        # Re-enable quantization
-        disable_quant_class.enable_act_quantization(block, False)
-        disable_quant_class.enable_param_quantization(block, False)
-        restore_return_quant_tensor(block, return_quant_tensor_state)
-
         # Finally (!), we compute the quantized input of the next block
         block.eval()
         block.cuda()
diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index a97ad548b..64c82c80a 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -54,6 +54,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                [--export-prefix EXPORT_PREFIX]
                [--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
                [--learned-round {None,linear_round}]
+               [--learned-round-fast-update]
 
 options:
   -h, --help            show this help message and exit
@@ -196,5 +197,8 @@ options:
   --learned-round {None,linear_round}
                         Whether to use learned round. If `None`, RTN is used
                         (default: None)
+  --learned-round-fast-update
+                        Whether to use fast update with learned round.
+                        Prototype (default: False)
 ```
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 203a86e89..c8c76d4a1 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -709,7 +709,7 @@ def parse_args(args):
     parser.add_argument(
         '--learned-round-fast-update',
         default=False,
-        type=bool,
+        action="store_true",
         help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
 
     return parser.parse_args(args)
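For context on the first hunk: the deleted block followed Brevitas' disable/restore quantization idiom around a float-mode forward pass. A minimal sketch of that idiom, reconstructed from the removed lines (the `float_forward` helper and the `brevitas.graph.calibrate` import path are assumptions, not taken from the diff), looks like:

```python
# Sketch of the disable/restore idiom this PR removes, reconstructed from the
# deleted lines. The import path and helper name below are assumed.
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor


def float_forward(block, inp):
    # Temporarily run `block` in floating point by switching off activation
    # and parameter quantization, then restore the original state.
    disable_quant = DisableEnableQuantization()
    disable_quant.disable_act_quantization(block, False)
    disable_quant.disable_param_quantization(block, False)
    return_quant_tensor_state = disable_return_quant_tensor(block)
    try:
        return block(inp)
    finally:
        disable_quant.enable_act_quantization(block, False)
        disable_quant.enable_param_quantization(block, False)
        restore_return_quant_tensor(block, return_quant_tensor_state)
```

Judging from the new comments, the toggling is no longer needed because the float path is now fed from the separate `floating_point_datasets` rather than being recomputed by disabling quantization on the block.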
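The `main.py` hunk fixes a standard argparse pitfall: `type=bool` applies `bool()` to the raw argument string, so any non-empty value, including the literal string "False", parses as `True`, and the flag cannot be passed bare. `action="store_true"` gives the intended on/off switch. A standalone repro of the difference (not the project's parser):

```python
# Demonstrates the argparse pitfall fixed in main.py: type=bool vs store_true.
import argparse

broken = argparse.ArgumentParser()
broken.add_argument('--learned-round-fast-update', default=False, type=bool)
print(broken.parse_args(['--learned-round-fast-update', 'False']))
# Namespace(learned_round_fast_update=True)  <- bool('False') is True

fixed = argparse.ArgumentParser()
fixed.add_argument('--learned-round-fast-update', action='store_true')
print(fixed.parse_args(['--learned-round-fast-update']))
# Namespace(learned_round_fast_update=True)
print(fixed.parse_args([]))
# Namespace(learned_round_fast_update=False)
```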