PTQ via NeMo-Run CLI #10984

Open
wants to merge 11 commits into base: main
3 changes: 2 additions & 1 deletion nemo/collections/llm/__init__.py
@@ -208,7 +208,7 @@
try:
import nemo_run as run

from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, train, validate
from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, ptq, train, validate
from nemo.collections.llm.recipes import * # noqa

__all__.extend(
@@ -220,6 +220,7 @@
"validate",
"finetune",
"generate",
"ptq",
]
)
except ImportError as error:
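
With this change, `ptq` is re-exported from the public `llm` namespace whenever the optional `nemo_run` import succeeds. A quick sanity check, assuming `nemo_run` is installed so the guarded import above does not fall into the `ImportError` branch:

```python
# Hedged sanity check: confirm the new entrypoint is exposed on the public llm API.
# Assumes nemo_run is installed; otherwise the try/except above skips these exports.
from nemo.collections import llm

assert "ptq" in llm.__all__
print(llm.ptq)  # the function registered as the `nemo llm ptq` CLI entrypoint
```
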
119 changes: 113 additions & 6 deletions nemo/collections/llm/api.py
@@ -25,6 +25,7 @@
from typing_extensions import Annotated

import nemo.lightning as nl
from nemo.collections.llm.quantization import ExportConfig, QuantizationConfig
from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io
from nemo.lightning.base import NEMO_MODELS_CACHE
from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
@@ -61,7 +62,8 @@ def train(
resume (Optional[Union[AutoResume, Resume]]): Resume training from a checkpoint.
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer
from the model will be used.
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec.
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model'
or an instance of TokenizerSpec.
export (Optional[str]): Filename to save the exported checkpoint after training.
model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied.

@@ -77,7 +79,7 @@
>>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> train(model, data, trainer, tokenizer="data")
>>> llm.train(model, data, trainer, tokenizer="data")
PosixPath('/path/to/log_dir')
"""
app_state = _setup(
@@ -179,7 +181,7 @@ def finetune(
>>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> finetune(model, data, trainer, peft=llm.peft.LoRA()])
>>> llm.finetune(model, data, trainer, peft=llm.peft.LoRA())
PosixPath('/path/to/log_dir')
"""

@@ -217,7 +219,8 @@ def validate(
resume (Optional[AutoResume]): Resume from a checkpoint for validation.
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer
from the model will be used.
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec.
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model'
or an instance of TokenizerSpec.
model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied.

Returns:
@@ -230,7 +233,7 @@
>>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> validate(model, data, trainer, tokenizer="data")
>>> llm.validate(model, data, trainer, tokenizer="data")
PosixPath('/path/to/log_dir')
"""
app_state = _setup(
@@ -249,6 +252,109 @@ def validate(
return app_state.exp_dir


@run.cli.factory
@run.autoconvert
def default_quantization(
algorithm: str = "fp8",
awq_block_size: int = 128,
sq_alpha: float = 0.5,
enable_kv_cache: Optional[bool] = None,
calibration_dataset: str = "cnn_dailymail",
calibration_dataset_size: int = 512,
calibration_batch_size: int = 64,
calibration_seq_len: int = 128,
) -> QuantizationConfig:
"""Default quantization configuration."""
return QuantizationConfig(
algorithm=algorithm,
awq_block_size=awq_block_size,
sq_alpha=sq_alpha,
enable_kv_cache=enable_kv_cache,
calibration_dataset=calibration_dataset,
calibration_dataset_size=calibration_dataset_size,
calibration_batch_size=calibration_batch_size,
calibration_seq_len=calibration_seq_len,
)


@run.cli.factory
@run.autoconvert
def default_export(
path: Union[Path, str] = "./qnemo",
dtype: Union[str, int] = "bf16",
decoder_type: Optional[str] = None,
inference_tensor_parallel: int = 1,
inference_pipeline_parallel: int = 1,
) -> ExportConfig:
"""Default export configuration."""
return ExportConfig(
path=path,
dtype=dtype,
decoder_type=decoder_type,
inference_tensor_parallel=inference_tensor_parallel,
inference_pipeline_parallel=inference_pipeline_parallel,
)


@run.cli.entrypoint(name="ptq", namespace="llm")
def ptq(
nemo_checkpoint: str,
calib_tp: int = 1,
calib_pp: int = 1,
quantization_config: QuantizationConfig = default_quantization(),
export_config: ExportConfig = default_export(),
) -> Path:
"""
Applies Post-Training Quantization (PTQ) to a model using the specified quantization and export configs. It runs
calibration on a small dataset to collect the scaling factors for the low-precision GEMMs used by the desired
quantization method.

This function produces a TensorRT-LLM checkpoint ready for deployment using the nemo.export and nemo.deploy
modules, or directly using the TensorRT-LLM library.

The function can be used through the NeMo CLI in the following way:

```bash
# Run calibration with tensor parallelism 8 and export the quantized checkpoint with tensor parallelism 2
nemo llm ptq nemo_checkpoint=/models/Llama-3-70B \
export_config.path=/models/Llama-3-70B-FP8 \
calib_tp=8 \
export_config.inference_tensor_parallel=2

# Choose a different quantization method, for example, INT8 SmoothQuant
nemo llm ptq nemo_checkpoint=/models/Llama-3-8B \
export_config.path=/models/Llama-3-8B-INT8_SQ \
quantization_config.algorithm=int8_sq
```

Args:
nemo_checkpoint (str): The path to the model to be quantized.
calib_tp (int): Calibration tensor parallelism.
calib_pp (int): Calibration pipeline parallelism.
quantization_config (QuantizationConfig): Configuration for quantization algorithm.
export_config (ExportConfig): Export configuration for TensorRT-LLM checkpoint.

Returns:
Path: The path where the quantized checkpoint has been saved after calibration.
"""
if export_config.path is None:
raise ValueError("The export_config.path needs to be specified, got None.")

from nemo.collections.llm import quantization

quantizer = quantization.Quantizer(quantization_config, export_config)

model = quantization.load_with_modelopt_layer_spec(nemo_checkpoint, calib_tp, calib_pp)

model = quantizer.quantize(model)

quantizer.export(model)

console = Console()
console.print(f"[green]✓ PTQ succeeded, quantized checkpoint exported to {export_config.path}[/green]")

return export_config.path
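
For reference, a minimal sketch of calling the same entrypoint from Python rather than the CLI. The paths are placeholders, and it assumes the entrypoint remains directly callable as a plain function and that `QuantizationConfig`'s fields default to the values used by `default_quantization` above:

```python
# Hedged sketch: programmatic PTQ mirroring the CLI examples in the docstring above.
from nemo.collections import llm
from nemo.collections.llm.quantization import ExportConfig, QuantizationConfig

quantized_path = llm.ptq(
    nemo_checkpoint="/models/Llama-3-8B",  # placeholder checkpoint path
    calib_tp=1,
    calib_pp=1,
    quantization_config=QuantizationConfig(algorithm="int8_sq"),  # other fields assumed to keep their defaults
    export_config=ExportConfig(path="/models/Llama-3-8B-INT8_SQ"),  # placeholder output path
)
print(quantized_path)  # Path to the exported TensorRT-LLM checkpoint
```
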


def get_trtllm_deployable(
nemo_checkpoint,
model_type,
@@ -348,7 +454,8 @@ def deploy(
if triton_port == rest_service_port:
logging.error("REST service port and Triton server port cannot use the same port.")
return
# Store triton ip, port and other args relevant for REST API in config.json to be accessible by rest_model_api.py
# Store triton ip, port and other args relevant for REST API
# in config.json to be accessible by rest_model_api.py
store_args_to_json(triton_http_address, triton_port, triton_request_timeout, openai_format_response)

triton_deployable = get_trtllm_deployable(
22 changes: 15 additions & 7 deletions nemo/collections/llm/quantization/quantizer.py
@@ -14,6 +14,7 @@

import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import torch
@@ -73,17 +74,20 @@ class QuantizationConfig:

@dataclass
class ExportConfig:
"""Inference configuration for the quantized TensorRT-LLM engine"""
"""Inference configuration for the quantized TensorRT-LLM checkpoint."""

path: str
path: Union[Path, str]
dtype: Union[str, int] = "bf16"
decoder_type: Optional[str] = None
inference_tensor_parallel: int = 1
inference_pipeline_parallel: int = 1

def __post_init__(self):
self.path = Path(self.path)


def get_modelopt_decoder_type(config: llm.GPTConfig) -> str:
"""Infers the modelopt decoder type from GPTConfig class"""
"""Infers the modelopt decoder type from GPTConfig class."""
mapping = [
(llm.Baichuan2Config, "baichuan"),
(llm.ChatGLMConfig, "chatglm"),
@@ -107,17 +111,17 @@ def get_modelopt_decoder_type(config: llm.GPTConfig) -> str:


class Quantizer:
"""Post-training quantization (PTQ) and TRT-LLM export of NeMo 2.0 checkpoints.
"""Post-training quantization (PTQ) and TensorRT-LLM export of NeMo 2.0 checkpoints.

PTQ converts selected model layers to low-precision format (e.g., INT4, FP8) for efficient serving.
The process consists of several steps:

1. Loading a NeMo model from disk using an appropriate parallelism strategy
2. Calibrating the model to obtain appropriate algorithm-specific scaling factors
3. Producing output directory
3. Producing an output directory with a quantized checkpoint and a tokenizer

The output directory produced is intended to be consumed by the TensorRT-LLM toolbox
for efficient inference. This can be achieved using NeMo inference containers.
for efficient inference. This can be achieved using nemo.export.tensorrt_llm module.
"""

def __init__(self, quantization_config: QuantizationConfig, export_config: ExportConfig):
@@ -229,6 +233,7 @@ def quantize(self, model: llm.GPTModel, forward_loop=None):
def create_megatron_forward_loop(
self, get_dataloader, num_batches, seq_length=None, micro_batch_size=None, decoder_seq_length=None
):
"""Create a forward loop for over a given data iterator."""
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func

forward_backward_func = get_forward_backward_func()
@@ -260,6 +265,7 @@ def loop(model):
return loop

def export(self, model: llm.GPTModel) -> None:
"""Export model to a TensorRT-LLM checkpoint."""
assert self.export_config is not None, "Export config is not set"
# TODO: Add sample generate
# TODO: Support megatron_amp_O2
@@ -291,7 +297,7 @@ def export(self, model: llm.GPTModel) -> None:
def get_calib_data_iter(
data: str = "cnn_dailymail", batch_size: int = 64, calib_size: int = 512, max_sequence_length: int = 512
):
"""Creates a sample data iterator for calibration"""
"""Creates a sample data iterator for calibration."""
if data == "wikitext":
dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
text_column = "text"
@@ -311,6 +317,8 @@


def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size):
"""Create a function that provides iterator over a given dataset."""

def _iterator():
CHARACTERS_PER_TOKEN = 4

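
The calibration helpers at the bottom of this file can also be exercised on their own. A small, hedged sketch, assuming `get_calib_data_iter` yields batches of raw text (the `datasets` library will download the corpus on first use):

```python
# Hedged sketch: peek at a few calibration batches produced by get_calib_data_iter.
from nemo.collections.llm.quantization.quantizer import get_calib_data_iter

data_iter = get_calib_data_iter(data="cnn_dailymail", batch_size=4, calib_size=8, max_sequence_length=512)
for i, batch in enumerate(data_iter):
    # Each batch is assumed to be a list of text samples truncated to max_sequence_length characters.
    print(f"batch {i}: {len(batch)} samples, first sample preview: {batch[0][:60]!r}")
```
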