diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py
index 2051f844d888..cedd2f975673 100644
--- a/nemo/collections/llm/__init__.py
+++ b/nemo/collections/llm/__init__.py
@@ -208,7 +208,7 @@
 try:
     import nemo_run as run
 
-    from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, train, validate
+    from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, ptq, train, validate
     from nemo.collections.llm.recipes import *  # noqa
 
     __all__.extend(
@@ -220,6 +220,7 @@
             "validate",
             "finetune",
             "generate",
+            "ptq",
         ]
     )
 except ImportError as error:
diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index 4f47f5c4bc73..f9b403989691 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -25,6 +25,7 @@
 from typing_extensions import Annotated
 
 import nemo.lightning as nl
+from nemo.collections.llm.quantization import ExportConfig, QuantizationConfig
 from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io
 from nemo.lightning.base import NEMO_MODELS_CACHE
 from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
@@ -61,7 +62,8 @@ def train(
         resume (Optional[Union[AutoResume, Resume]]): Resume training from a checkpoint.
         optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer
             from the model will be used.
-        tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec.
+        tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model'
+            or an instance of TokenizerSpec.
         export (Optional[str]): Filename to save the exported checkpoint after training.
         model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied.
 
@@ -77,7 +79,7 @@ def train(
         >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
         >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
         >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
-        >>> train(model, data, trainer, tokenizer="data")
+        >>> llm.train(model, data, trainer, tokenizer="data")
         PosixPath('/path/to/log_dir')
     """
     app_state = _setup(
@@ -179,7 +181,7 @@ def finetune(
         >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
         >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
         >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
-        >>> finetune(model, data, trainer, peft=llm.peft.LoRA()])
+        >>> llm.finetune(model, data, trainer, peft=llm.peft.LoRA())
         PosixPath('/path/to/log_dir')
     """
 
@@ -217,7 +219,8 @@ def validate(
         resume (Optional[AutoResume]): Resume from a checkpoint for validation.
         optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer
             from the model will be used.
-        tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec.
+        tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model'
+            or an instance of TokenizerSpec.
         model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied.
 
     Returns:
@@ -230,7 +233,7 @@ def validate(
         >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
         >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
         >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
-        >>> validate(model, data, trainer, tokenizer="data")
+        >>> llm.validate(model, data, trainer, tokenizer="data")
         PosixPath('/path/to/log_dir')
     """
     app_state = _setup(
@@ -249,6 +252,109 @@ def validate(
     return app_state.exp_dir
 
 
+@run.cli.factory
+@run.autoconvert
+def default_quantization(
+    algorithm: str = "fp8",
+    awq_block_size: int = 128,
+    sq_alpha: float = 0.5,
+    enable_kv_cache: Optional[bool] = None,
+    calibration_dataset: str = "cnn_dailymail",
+    calibration_dataset_size: int = 512,
+    calibration_batch_size: int = 64,
+    calibration_seq_len: int = 128,
+) -> QuantizationConfig:
+    """Default quantization configuration."""
+    return QuantizationConfig(
+        algorithm=algorithm,
+        awq_block_size=awq_block_size,
+        sq_alpha=sq_alpha,
+        enable_kv_cache=enable_kv_cache,
+        calibration_dataset=calibration_dataset,
+        calibration_dataset_size=calibration_dataset_size,
+        calibration_batch_size=calibration_batch_size,
+        calibration_seq_len=calibration_seq_len,
+    )
+
+
+@run.cli.factory
+@run.autoconvert
+def default_export(
+    path: Union[Path, str] = "./qnemo",
+    dtype: Union[str, int] = "bf16",
+    decoder_type: Optional[str] = None,
+    inference_tensor_parallel: int = 1,
+    inference_pipeline_parallel: int = 1,
+) -> ExportConfig:
+    """Default export configuration."""
+    return ExportConfig(
+        path=path,
+        dtype=dtype,
+        decoder_type=decoder_type,
+        inference_tensor_parallel=inference_tensor_parallel,
+        inference_pipeline_parallel=inference_pipeline_parallel,
+    )
+
+
+@run.cli.entrypoint(name="ptq", namespace="llm")
+def ptq(
+    nemo_checkpoint: str,
+    calib_tp: int = 1,
+    calib_pp: int = 1,
+    quantization_config: QuantizationConfig = default_quantization(),
+    export_config: ExportConfig = default_export(),
+) -> Path:
+    """
+    Applies Post-Training Quantization (PTQ) to a model using the specified quantization and export configs. It runs
+    calibration on a small dataset to collect scaling factors for the low-precision GEMMs used by the desired method.
+
+    This function produces a TensorRT-LLM checkpoint ready for deployment using the nemo.export and nemo.deploy
+    modules, or directly using the TensorRT-LLM library.
+
+    The function can be used through the NeMo CLI in the following way:
+
+    ```bash
+    # Run calibration with tensor parallelism 8 and export the quantized checkpoint with tensor parallelism 2
+    nemo llm ptq nemo_checkpoint=/models/Llama-3-70B \
+        export_config.path=/models/Llama-3-70B-FP8 \
+        calib_tp=8 \
+        export_config.inference_tensor_parallel=2
+
+    # Choose a different quantization method, for example, INT8 SmoothQuant
+    nemo llm ptq nemo_checkpoint=/models/Llama-3-8B \
+        export_config.path=/models/Llama-3-8B-INT8_SQ \
+        quantization_config.algorithm=int8_sq
+    ```
+
+    Args:
+        nemo_checkpoint (str): The path to the model to be quantized.
+        calib_tp (int): Calibration tensor parallelism.
+        calib_pp (int): Calibration pipeline parallelism.
+        quantization_config (QuantizationConfig): Configuration for the quantization algorithm.
+        export_config (ExportConfig): Export configuration for the TensorRT-LLM checkpoint.
+
+    Returns:
+        Path: The path where the quantized checkpoint has been saved after calibration.
+ """ + if export_config.path is None: + raise ValueError("The export_config.path needs to be specified, got None.") + + from nemo.collections.llm import quantization + + quantizer = quantization.Quantizer(quantization_config, export_config) + + model = quantization.load_with_modelopt_layer_spec(nemo_checkpoint, calib_tp, calib_pp) + + model = quantizer.quantize(model) + + quantizer.export(model) + + console = Console() + console.print(f"[green]✓ PTQ succeded, quantized checkpoint exported to {export_config.path}[/green]") + + return export_config.path + + def get_trtllm_deployable( nemo_checkpoint, model_type, @@ -348,7 +454,8 @@ def deploy( if triton_port == rest_service_port: logging.error("REST service port and Triton server port cannot use the same port.") return - # Store triton ip, port and other args relevant for REST API in config.json to be accessible by rest_model_api.py + # Store triton ip, port and other args relevant for REST API + # in config.json to be accessible by rest_model_api.py store_args_to_json(triton_http_address, triton_port, triton_request_timeout, openai_format_response) triton_deployable = get_trtllm_deployable( diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 15367cb25aba..eb7b89a7c133 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -14,6 +14,7 @@ import os from dataclasses import dataclass +from pathlib import Path from typing import Optional, Union import torch @@ -73,17 +74,20 @@ class QuantizationConfig: @dataclass class ExportConfig: - """Inference configuration for the quantized TensorRT-LLM engine""" + """Inference configuration for the quantized TensorRT-LLM checkpoint.""" - path: str + path: Union[Path, str] dtype: Union[str, int] = "bf16" decoder_type: Optional[str] = None inference_tensor_parallel: int = 1 inference_pipeline_parallel: int = 1 + def __post_init__(self): + self.path = Path(self.path) + def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: - """Infers the modelopt decoder type from GPTConfig class""" + """Infers the modelopt decoder type from GPTConfig class.""" mapping = [ (llm.Baichuan2Config, "baichuan"), (llm.ChatGLMConfig, "chatglm"), @@ -107,17 +111,17 @@ def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: class Quantizer: - """Post-training quantization (PTQ) and TRT-LLM export of NeMo 2.0 checkpoints. + """Post-training quantization (PTQ) and TensorRT-LLM export of NeMo 2.0 checkpoints. PTQ converts selected model layers to low-precision format (e.g., INT4, FP8) for efficient serving. The process consist of several steps: 1. Loading a Nemo model from disk using appropriate parallelism strategy 2. Calibrating the model to obtain appropriate algorithm-specific scaling factors - 3. Producing output directory + 3. Producing an output directory with a quantized checkpoint and a tokenizer The output directory produced is intended to be consumed by TensorRT-LLM toolbox - for efficient inference. This can be achieved using NeMo inference containers. + for efficient inference. This can be achieved using nemo.export.tensorrt_llm module. 
""" def __init__(self, quantization_config: QuantizationConfig, export_config: ExportConfig): @@ -229,6 +233,7 @@ def quantize(self, model: llm.GPTModel, forward_loop=None): def create_megatron_forward_loop( self, get_dataloader, num_batches, seq_length=None, micro_batch_size=None, decoder_seq_length=None ): + """Create a forward loop for over a given data iterator.""" from megatron.core.pipeline_parallel.schedules import get_forward_backward_func forward_backward_func = get_forward_backward_func() @@ -260,6 +265,7 @@ def loop(model): return loop def export(self, model: llm.GPTModel) -> None: + """Export model to a TensorRT-LLM checkpoint.""" assert self.export_config is not None, "Export config is not set" # TODO: Add sample generate # TODO: Support megatron_amp_O2 @@ -291,7 +297,7 @@ def export(self, model: llm.GPTModel) -> None: def get_calib_data_iter( data: str = "cnn_dailymail", batch_size: int = 64, calib_size: int = 512, max_sequence_length: int = 512 ): - """Creates a sample data iterator for calibration""" + """Creates a sample data iterator for calibration.""" if data == "wikitext": dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") text_column = "text" @@ -311,6 +317,8 @@ def get_calib_data_iter( def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): + """Create a function that provides iterator over a given dataset.""" + def _iterator(): CHARACTERS_PER_TOKEN = 4