From aad2021f6d3715c9e28ca03a87078da6d259b11b Mon Sep 17 00:00:00 2001
From: Jan Lasek
Date: Wed, 30 Oct 2024 13:46:04 +0100
Subject: [PATCH 1/9] PTQ support in NeMo CLI

Signed-off-by: Jan Lasek
---
 nemo/collections/llm/__init__.py |   3 +-
 nemo/collections/llm/api.py      | 113 ++++++++++++++++++++++++++++++-
 2 files changed, 112 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py
index 3fe20173cba2..a355a963ae6a 100644
--- a/nemo/collections/llm/__init__.py
+++ b/nemo/collections/llm/__init__.py
@@ -206,7 +206,7 @@
 try:
     import nemo_run as run
 
-    from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, train, validate
+    from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, ptq, train, validate
     from nemo.collections.llm.recipes import *  # noqa
 
     __all__.extend(
@@ -218,6 +218,7 @@
             "validate",
             "finetune",
             "generate",
+            "ptq",
         ]
     )
 except ImportError as error:
diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index 4f47f5c4bc73..3f2237ca5d23 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -25,6 +25,7 @@
 from typing_extensions import Annotated
 
 import nemo.lightning as nl
+from nemo.collections.llm.quantization import ExportConfig, QuantizationConfig
 from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io
 from nemo.lightning.base import NEMO_MODELS_CACHE
 from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
@@ -77,7 +78,7 @@ def train(
         >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
         >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
         >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
-        >>> train(model, data, trainer, tokenizer="data")
+        >>> llm.train(model, data, trainer, tokenizer="data")
         PosixPath('/path/to/log_dir')
     """
     app_state = _setup(
@@ -179,7 +180,7 @@ def finetune(
         >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
         >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
         >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
-        >>> finetune(model, data, trainer, peft=llm.peft.LoRA()])
+        >>> llm.finetune(model, data, trainer, peft=llm.peft.LoRA())
         PosixPath('/path/to/log_dir')
     """
 
@@ -230,7 +231,7 @@ def validate(
         >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
         >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
         >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
-        >>> validate(model, data, trainer, tokenizer="data")
+        >>> llm.validate(model, data, trainer, tokenizer="data")
         PosixPath('/path/to/log_dir')
     """
     app_state = _setup(
@@ -249,6 +250,112 @@ def validate(
     return app_state.exp_dir
 
 
+@run.cli.factory
+@run.autoconvert
+def default_quantization(
+    algorithm: str = "fp8",
+    awq_block_size: int = 128,
+    sq_alpha: float = 0.5,
+    enable_kv_cache: Optional[bool] = None,
+) -> QuantizationConfig:
+    """Default quantization configuration."""
+    return QuantizationConfig(
+        algorithm=algorithm,
+        awq_block_size=awq_block_size,
+        sq_alpha=sq_alpha,
+        enable_kv_cache=enable_kv_cache,
+    )
+
+
+@run.cli.factory
+@run.autoconvert
+def default_export(
+    path: str,
+    dtype: Union[str, int] = "bf16",
+    decoder_type: Optional[str] = None,
+    inference_tensor_parallel: int = 1,
+    inference_pipeline_parallel: int = 1,
+) -> ExportConfig:
+    """Default export configuration."""
+    return ExportConfig(
+        path=path,
+        dtype=dtype,
+        decoder_type=decoder_type,
+        inference_tensor_parallel=inference_tensor_parallel,
+        inference_pipeline_parallel=inference_pipeline_parallel,
+    )
+
+
+@run.cli.entrypoint(name="ptq", namespace="llm")
+def ptq(
+    nemo_checkpoint: str,
+    # TODO: Maybe also create calibration_config for parallel and data settings?
+    calib_tp: int = 1,
+    calib_pp: int = 1,
+    dataset_size: int = 512,
+    batch_size: int = 64,
+    seq_len: int = 128,
+    quantization_config: QuantizationConfig = default_quantization(),
+    export_config: ExportConfig = default_export(None),
+):
+    """
+    Applies Post-Training Quantization (PTQ) to a model using the specified quantization and export configs. It runs
+    calibration on a small dataset to collect scaling factors for low-precision GEMMs used by the quantization method.
+
+    This function produces a TensorRT-LLM checkpoint ready for deployment using nemo.export and nemo.deploy modules
+    or directly using the TensorRT-LLM library.
+
+    Args:
+        nemo_checkpoint (str): The path to the model to be quantized.
+        calib_tp (int): Calibration tensor parallelism.
+        calib_pp (int): Calibration pipeline parallelism.
+        dataset_size (int): Number of samples to run calibration on.
+        batch_size (int): Batch size for calibration.
+        seq_len (int): Length of calibration samples (in tokens).
+        quantization_config (QuantizationConfig): Configuration for the quantization algorithm.
+        export_config (ExportConfig): Export configuration for the TensorRT-LLM checkpoint.
+
+    Returns:
+        Path: The directory path where the quantized model is saved.
+
+    Example:
+        >>> from nemo.collections.llm.quantization import ExportConfig, QuantizationConfig
+        >>> nemo_checkpoint = "/opt/checkpoints/LLAMA3-8B-fp16"
+        >>> quantization_config = QuantizationConfig(algorithm="fp8")
+        >>> export_config = ExportConfig(path="/opt/checkpoints/LLAMA3-8B-fp8")
+        >>> llm.ptq(nemo_checkpoint, quantization_config=quantization_config, export_config=export_config)
+        '/opt/checkpoints/LLAMA3-8B-fp8'
+    """
+    if export_config.path is None:
+        raise ValueError("The export_config.path needs to be specified, got None.")
+
+    from nemo.collections.llm import quantization
+
+    quantizer = quantization.Quantizer(quantization_config, export_config)
+    model = quantization.load_with_modelopt_layer_spec(nemo_checkpoint, calib_tp, calib_pp)
+
+    get_dataloader = quantization.create_data_iterator_getter(
+        model,
+        dataset=quantization_config.calibration_dataset,
+        seq_len=quantization_config.calibration_seq_len,
+        batch_size=quantization_config.calibration_batch_size,
+        calibration_size=quantization_config.calibration_dataset_size,
+    )
+
+    forward_loop = quantizer.create_megatron_forward_loop(
+        get_dataloader,
+        num_batches=dataset_size // batch_size,
+        seq_length=seq_len,
+        micro_batch_size=batch_size,
+    )
+
+    model = quantizer.quantize(model, forward_loop)
+
+    quantizer.export(model)
+
+    return export_config.path
+
+
 def get_trtllm_deployable(
     nemo_checkpoint,
     model_type,
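To make the calibration workload implied by the defaults above concrete, `num_batches = dataset_size // batch_size` yields 8 forward passes over batches of 64 samples of 128 tokens each. A small illustrative sketch (values copied from the `ptq` signature):

```python
# Calibration workload implied by the default ptq() arguments (illustrative only).
dataset_size, batch_size, seq_len = 512, 64, 128

num_batches = dataset_size // batch_size  # 8 calibration forward passes
tokens_seen = dataset_size * seq_len      # 65_536 calibration tokens in total
print(num_batches, tokens_seen)           # -> 8 65536
```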
From 5bd8f683da7ae116a9912e4743bb08143a1a163c Mon Sep 17 00:00:00 2001
From: Jan Lasek
Date: Wed, 30 Oct 2024 13:47:58 +0100
Subject: [PATCH 2/9] Naming engine vs checkpoint

Signed-off-by: Jan Lasek
---
 nemo/collections/llm/quantization/quantizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py
index 15367cb25aba..0d1dc96ae7e3 100644
--- a/nemo/collections/llm/quantization/quantizer.py
+++ b/nemo/collections/llm/quantization/quantizer.py
@@ -73,7 +73,7 @@ class QuantizationConfig:
 
 @dataclass
 class ExportConfig:
-    """Inference configuration for the quantized TensorRT-LLM engine"""
+    """Inference configuration for the quantized TensorRT-LLM checkpoint."""
 
     path: str
     dtype: Union[str, int] = "bf16"
From 962fb695db0178604d812f82d3419e28d415188e Mon Sep 17 00:00:00 2001
From: Jan Lasek
Date: Wed, 30 Oct 2024 15:55:10 +0100
Subject: [PATCH 3/9] Apply review remarks

Signed-off-by: Jan Lasek
---
 nemo/collections/llm/api.py | 33 ++++++++++-----------------------
 1 file changed, 10 insertions(+), 23 deletions(-)

diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index 3f2237ca5d23..ee05818d6ad3 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -257,6 +257,10 @@ def default_quantization(
     awq_block_size: int = 128,
     sq_alpha: float = 0.5,
     enable_kv_cache: Optional[bool] = None,
+    calibration_dataset: str = "cnn_dailymail",
+    calibration_dataset_size: int = 512,
+    calibration_batch_size: int = 64,
+    calibration_seq_len: int = 128,
 ) -> QuantizationConfig:
     """Default quantization configuration."""
     return QuantizationConfig(
@@ -264,6 +268,10 @@ def default_quantization(
         awq_block_size=awq_block_size,
         sq_alpha=sq_alpha,
         enable_kv_cache=enable_kv_cache,
+        calibration_dataset=calibration_dataset,
+        calibration_dataset_size=calibration_dataset_size,
+        calibration_batch_size=calibration_batch_size,
+        calibration_seq_len=calibration_seq_len,
     )
 
 
@@ -289,12 +297,8 @@ def default_export(
 @run.cli.entrypoint(name="ptq", namespace="llm")
 def ptq(
     nemo_checkpoint: str,
-    # TODO: Maybe also create calibration_config for parallel and data settings?
     calib_tp: int = 1,
     calib_pp: int = 1,
-    dataset_size: int = 512,
-    batch_size: int = 64,
-    seq_len: int = 128,
     quantization_config: QuantizationConfig = default_quantization(),
     export_config: ExportConfig = default_export(None),
 ):
@@ -309,9 +313,6 @@ def ptq(
         nemo_checkpoint (str): The path to the model to be quantized.
         calib_tp (int): Calibration tensor parallelism.
         calib_pp (int): Calibration pipeline parallelism.
-        dataset_size (int): Number of samples to run calibration on.
-        batch_size (int): Batch size for calibration.
-        seq_len (int): Length of calibration samples (in tokens).
         quantization_config (QuantizationConfig): Configuration for the quantization algorithm.
         export_config (ExportConfig): Export configuration for the TensorRT-LLM checkpoint.
 
@@ -332,24 +333,10 @@ def ptq(
     from nemo.collections.llm import quantization
 
     quantizer = quantization.Quantizer(quantization_config, export_config)
-    model = quantization.load_with_modelopt_layer_spec(nemo_checkpoint, calib_tp, calib_pp)
-
-    get_dataloader = quantization.create_data_iterator_getter(
-        model,
-        dataset=quantization_config.calibration_dataset,
-        seq_len=quantization_config.calibration_seq_len,
-        batch_size=quantization_config.calibration_batch_size,
-        calibration_size=quantization_config.calibration_dataset_size,
-    )
 
-    forward_loop = quantizer.create_megatron_forward_loop(
-        get_dataloader,
-        num_batches=dataset_size // batch_size,
-        seq_length=seq_len,
-        micro_batch_size=batch_size,
-    )
+    model = quantization.load_with_modelopt_layer_spec(nemo_checkpoint, calib_tp, calib_pp)
 
-    model = quantizer.quantize(model, forward_loop)
+    model = quantizer.quantize(model)
 
     quantizer.export(model)
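With the calibration knobs folded into `QuantizationConfig` by this patch, overriding them programmatically reduces to constructing the dataclass. A minimal sketch, using only field names visible in the `default_quantization` factory above (the values are illustrative):

```python
from nemo.collections.llm.quantization import QuantizationConfig

# INT8 SmoothQuant with a larger calibration run than the default 512 samples.
int8_sq_config = QuantizationConfig(
    algorithm="int8_sq",
    sq_alpha=0.5,
    calibration_dataset="cnn_dailymail",
    calibration_dataset_size=1024,
    calibration_batch_size=32,
    calibration_seq_len=512,
)
```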
From 88bee68ad686155653ceefa119e9a907509d07e6 Mon Sep 17 00:00:00 2001
From: Jan Lasek
Date: Thu, 7 Nov 2024 12:31:02 +0100
Subject: [PATCH 4/9] Print log message with export path

Signed-off-by: Jan Lasek
---
 nemo/collections/llm/api.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index ee05818d6ad3..b291eafb4d82 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -340,6 +340,9 @@ def ptq(
 
     quantizer.export(model)
 
+    console = Console()
+    console.print(f"[green]✓ PTQ succeeded, quantized checkpoint exported to {export_config.path}[/green]")
+
     return export_config.path
 
 

From f3d90c7b0eb116e2e72f5c0213f2d7d40d49b0f1 Mon Sep 17 00:00:00 2001
From: Jan Lasek
Date: Thu, 7 Nov 2024 12:55:47 +0100
Subject: [PATCH 5/9] Use pathlib.Path for API consistency

Signed-off-by: Jan Lasek
---
 nemo/collections/llm/api.py                    | 6 +++---
 nemo/collections/llm/quantization/quantizer.py | 6 +++++-
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index b291eafb4d82..0af5a051f744 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -278,7 +278,7 @@ def default_quantization(
 @run.cli.factory
 @run.autoconvert
 def default_export(
-    path: str,
+    path: Union[Path, str] = "./qnemo",
     dtype: Union[str, int] = "bf16",
     decoder_type: Optional[str] = None,
     inference_tensor_parallel: int = 1,
@@ -300,8 +300,8 @@ def ptq(
     calib_tp: int = 1,
     calib_pp: int = 1,
     quantization_config: QuantizationConfig = default_quantization(),
-    export_config: ExportConfig = default_export(None),
-):
+    export_config: ExportConfig = default_export(),
+) -> Path:
     """
     Applies Post-Training Quantization (PTQ) to a model using the specified quantization and export configs. It runs
     calibration on a small dataset to collect scaling factors for low-precision GEMMs used by the quantization method.
diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py
index 0d1dc96ae7e3..21e56afdf7b2 100644
--- a/nemo/collections/llm/quantization/quantizer.py
+++ b/nemo/collections/llm/quantization/quantizer.py
@@ -14,6 +14,7 @@
 
 import os
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Optional, Union
 
 import torch
@@ -75,12 +76,15 @@ class QuantizationConfig:
 class ExportConfig:
     """Inference configuration for the quantized TensorRT-LLM checkpoint."""
 
-    path: str
+    path: Union[Path, str]
     dtype: Union[str, int] = "bf16"
     decoder_type: Optional[str] = None
     inference_tensor_parallel: int = 1
     inference_pipeline_parallel: int = 1
 
+    def __post_init__(self):
+        self.path = Path(self.path)
+
 
 def get_modelopt_decoder_type(config: llm.GPTConfig) -> str:
     """Infers the modelopt decoder type from GPTConfig class"""
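The `__post_init__` hook added here lets callers keep passing plain strings while downstream code relies on `pathlib.Path` semantics. A quick sketch of the behaviour this patch introduces (the file name in the last line is hypothetical):

```python
from pathlib import Path

from nemo.collections.llm.quantization import ExportConfig

config = ExportConfig(path="./qnemo")  # a plain str is accepted at the call site...
assert isinstance(config.path, Path)   # ...but __post_init__ normalizes it to a Path
print(config.path / "config.json")     # so path arithmetic works downstream
```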
From f28a839eac603924edae4c2dd725b88f066fb06d Mon Sep 17 00:00:00 2001
From: Jan Lasek
Date: Thu, 7 Nov 2024 13:23:18 +0100
Subject: [PATCH 6/9] Update PTQ docstring

Signed-off-by: Jan Lasek
---
 nemo/collections/llm/api.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index 0af5a051f744..0e332cb2b662 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -309,6 +309,16 @@ def ptq(
     This function produces a TensorRT-LLM checkpoint ready for deployment using nemo.export and nemo.deploy modules
     or directly using the TensorRT-LLM library.
 
+    The function can be used through the NeMo CLI in the following way:
+
+    ```bash
+    # Run calibration using tensor parallel set to 8 and export quantized checkpoint with tensor parallel equal 2
+    nemo llm ptq nemo_checkpoint=/models/Llama-3-70B export_config.path=/models/Llama-3-70B-FP8 calib_tp=8 export_config.inference_tensor_parallel=2
+
+    # Choose different quantization method, for example, INT8 SmoothQuant
+    nemo llm ptq nemo_checkpoint=/models/Llama-3-8B export_config.path=/models/Llama-3-8B-INT8_SQ quantization_config.algorithm=int8_sq
+    ```
+
     Args:
         nemo_checkpoint (str): The path to the model to be quantized.
         calib_tp (int): Calibration tensor parallelism.
@@ -317,15 +327,7 @@ def ptq(
         export_config (ExportConfig): Export configuration for the TensorRT-LLM checkpoint.
 
     Returns:
-        Path: The directory path where the quantized model is saved.
-
-    Example:
-        >>> from nemo.collections.llm.quantization import ExportConfig, QuantizationConfig
-        >>> nemo_checkpoint = "/opt/checkpoints/LLAMA3-8B-fp16"
-        >>> quantization_config = QuantizationConfig(algorithm="fp8")
-        >>> export_config = ExportConfig(path="/opt/checkpoints/LLAMA3-8B-fp8")
-        >>> llm.ptq(nemo_checkpoint, quantization_config=quantization_config, export_config=export_config)
-        '/opt/checkpoints/LLAMA3-8B-fp8'
+        Path: The path where the quantized checkpoint has been saved after calibration.
""" if export_config.path is None: raise ValueError("The export_config.path needs to be specified, got None.") From 2c604bc8672df4582c6c54996be1df5d5be3797f Mon Sep 17 00:00:00 2001 From: Jan Lasek Date: Thu, 7 Nov 2024 15:31:17 +0100 Subject: [PATCH 7/9] Fix pylint complains on too long lines Signed-off-by: Jan Lasek --- nemo/collections/llm/api.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 0e332cb2b662..f9b403989691 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -62,7 +62,8 @@ def train( resume (Optional[Union[AutoResume, Resume]]): Resume training from a checkpoint. optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer from the model will be used. - tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec. + tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' + or an instance of TokenizerSpec. export (Optional[str]): Filename to save the exported checkpoint after training. model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied. @@ -218,7 +219,8 @@ def validate( resume (Optional[AutoResume]): Resume from a checkpoint for validation. optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer from the model will be used. - tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec. + tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' + or an instance of TokenizerSpec. model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied. 
 
     Returns:
@@ -313,10 +315,15 @@ def ptq(
 
     ```bash
     # Run calibration using tensor parallel set to 8 and export quantized checkpoint with tensor parallel equal 2
-    nemo llm ptq nemo_checkpoint=/models/Llama-3-70B export_config.path=/models/Llama-3-70B-FP8 calib_tp=8 export_config.inference_tensor_parallel=2
+    nemo llm ptq nemo_checkpoint=/models/Llama-3-70B \
+        export_config.path=/models/Llama-3-70B-FP8 \
+        calib_tp=8 \
+        export_config.inference_tensor_parallel=2
 
     # Choose different quantization method, for example, INT8 SmoothQuant
-    nemo llm ptq nemo_checkpoint=/models/Llama-3-8B export_config.path=/models/Llama-3-8B-INT8_SQ quantization_config.algorithm=int8_sq
+    nemo llm ptq nemo_checkpoint=/models/Llama-3-8B \
+        export_config.path=/models/Llama-3-8B-INT8_SQ \
+        quantization_config.algorithm=int8_sq
     ```
 
     Args:
         nemo_checkpoint (str): The path to the model to be quantized.
         calib_tp (int): Calibration tensor parallelism.
@@ -447,7 +454,8 @@ def deploy(
     if triton_port == rest_service_port:
         logging.error("REST service port and Triton server port cannot use the same port.")
         return
-    # Store triton ip, port and other args relevant for REST API in config.json to be accessible by rest_model_api.py
+    # Store Triton IP, port and other args relevant for REST API
+    # in config.json to be accessible by rest_model_api.py
     store_args_to_json(triton_http_address, triton_port, triton_request_timeout, openai_format_response)
 
     triton_deployable = get_trtllm_deployable(

From d93db283b4ea218f26f8748e17125dbbde05ce77 Mon Sep 17 00:00:00 2001
From: Jan Lasek
Date: Fri, 8 Nov 2024 12:18:32 +0100
Subject: [PATCH 8/9] Add missing docstrings

Signed-off-by: Jan Lasek
---
 nemo/collections/llm/quantization/quantizer.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py
index 21e56afdf7b2..c53bd711b31c 100644
--- a/nemo/collections/llm/quantization/quantizer.py
+++ b/nemo/collections/llm/quantization/quantizer.py
@@ -87,7 +87,7 @@ def __post_init__(self):
 
 
 def get_modelopt_decoder_type(config: llm.GPTConfig) -> str:
-    """Infers the modelopt decoder type from GPTConfig class"""
+    """Infers the modelopt decoder type from GPTConfig class."""
     mapping = [
         (llm.Baichuan2Config, "baichuan"),
         (llm.ChatGLMConfig, "chatglm"),
@@ -111,17 +111,17 @@ def get_modelopt_decoder_type(config: llm.GPTConfig) -> str:
 
 
 class Quantizer:
-    """Post-training quantization (PTQ) and TRT-LLM export of NeMo 2.0 checkpoints.
+    """Post-training quantization (PTQ) and TensorRT-LLM export of NeMo 2.0 checkpoints.
 
     PTQ converts selected model layers to low-precision format (e.g., INT4, FP8) for efficient serving.
     The process consists of several steps:
 
         1. Loading a NeMo model from disk using an appropriate parallelism strategy
         2. Calibrating the model to obtain appropriate algorithm-specific scaling factors
-        3. Producing output directory
+        3. Producing an output directory with a quantized checkpoint and a tokenizer
 
     The output directory produced is intended to be consumed by the TensorRT-LLM toolbox
-    for efficient inference. This can be achieved using NeMo inference containers.
+    for efficient inference. This can be achieved using the nemo.export.tensorrt_llm module.
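The two helpers documented toward the end of this patch split the calibration data plumbing: `get_calib_data_iter` yields batches of raw text, and `create_data_iterator_getter` wraps them into token batches for the Megatron forward loop. A rough usage sketch of the first helper, assuming the import path of the file being patched (requires the HuggingFace `datasets` package):

```python
# Import path assumed from the diff above; the function lives in quantizer.py.
from nemo.collections.llm.quantization.quantizer import get_calib_data_iter

# Pull two batches of two raw-text samples each from cnn_dailymail.
for batch in get_calib_data_iter(data="cnn_dailymail", batch_size=2, calib_size=4):
    print(len(batch), type(batch[0]))  # each batch is a list of truncated strings
```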
""" def __init__(self, quantization_config: QuantizationConfig, export_config: ExportConfig): @@ -233,6 +233,7 @@ def quantize(self, model: llm.GPTModel, forward_loop=None): def create_megatron_forward_loop( self, get_dataloader, num_batches, seq_length=None, micro_batch_size=None, decoder_seq_length=None ): + """Create a forward loop for over a given data iterator.""" from megatron.core.pipeline_parallel.schedules import get_forward_backward_func forward_backward_func = get_forward_backward_func() @@ -264,6 +265,7 @@ def loop(model): return loop def export(self, model: llm.GPTModel) -> None: + """Export model to a TensorRT-LLM checkpoint.""" assert self.export_config is not None, "Export config is not set" # TODO: Add sample generate # TODO: Support megatron_amp_O2 @@ -295,7 +297,7 @@ def export(self, model: llm.GPTModel) -> None: def get_calib_data_iter( data: str = "cnn_dailymail", batch_size: int = 64, calib_size: int = 512, max_sequence_length: int = 512 ): - """Creates a sample data iterator for calibration""" + """Creates a sample data iterator for calibration.""" if data == "wikitext": dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") text_column = "text" @@ -315,6 +317,7 @@ def get_calib_data_iter( def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): + """Create a function that provides iterator over a given dataset.""" def _iterator(): CHARACTERS_PER_TOKEN = 4 From 2a08c9e1a84c657d54072beed209e13babb6fe63 Mon Sep 17 00:00:00 2001 From: janekl Date: Fri, 8 Nov 2024 11:20:44 +0000 Subject: [PATCH 9/9] Apply isort and black reformatting Signed-off-by: janekl --- nemo/collections/llm/quantization/quantizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index c53bd711b31c..eb7b89a7c133 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -318,6 +318,7 @@ def get_calib_data_iter( def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): """Create a function that provides iterator over a given dataset.""" + def _iterator(): CHARACTERS_PER_TOKEN = 4