[https://nvbugs/5378031] [feat] Hopper W4A8 MoE supports ModelOpt ckpt for PyT backend #6200
Conversation
Walkthrough
Support for a new weight loading mode.
Changes
Sequence Diagram(s)
sequenceDiagram
participant Test as Test Runner
participant Deepseek as Deepseekv3MoE
participant MoE as create_moe
participant Quant as WInt4AFP8FusedMoEMethod
Test->>Deepseek: Initialize with quantization config
Deepseek->>MoE: create_moe(weight_loading_mode)
MoE->>Quant: load_expert_weights_to_dst(..., weight_loading_mode)
Quant->>Quant: load_expert_w3_w1_weight(..., weight_loading_mode)
Quant->>Quant: load_expert_w2_weight(..., weight_loading_mode)
Quant->>Quant: load_quant_scales(..., weight_loading_mode)
Test->>Test: Validate outputs for selected weight_loading_mode
/bot run
Actionable comments posted: 1
🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/linear.py (1)
1024-1257: LGTM! Well-implemented W4A8 AWQ quantization method. The implementation correctly follows the ModelOpt W4A8 AWQ quantization flow:
- Multiplies pre_quant_scale to input
- Quantizes input to FP8 using input_scale
- Unpacks weights and multiplies by weight_scales (int4 → fp16)
- Divides by weight_scale_2 (fp16 → fp8 for GEMM)
- Applies GEMM in FP8
- Rescales using alpha (input_scale * weight_scale_2)
The weight scale handling and alpha computation are correct.
Consider extracting common logic between W4A16_AWQ_LinearMethod and W4A8_AWQ_LinearMethod into a base class to reduce code duplication in future refactoring efforts.
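For readers unfamiliar with this flow, here is a minimal PyTorch sketch of the steps listed above; the tensor names, shapes, and the fp32 emulation of the FP8 GEMM are illustrative assumptions, not the module's actual implementation:

```python
import torch

def w4a8_awq_reference(x, w_int4, weight_scales, pre_quant_scale,
                       input_scale, weight_scale_2, group_size=128):
    """Illustrative W4A8 AWQ reference (assumed shapes):
    x:             [n, in_features] activation in fp16/bf16
    w_int4:        [out_features, in_features] unpacked int4 values stored as int8
    weight_scales: [out_features, in_features // group_size] per-group scales
    """
    # 1) Multiply pre_quant_scale into the activation.
    x = x * pre_quant_scale
    # 2) Quantize the activation to FP8 with the per-tensor input_scale.
    x_fp8 = (x / input_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    # 3) Dequantize int4 -> fp16 using the per-group weight_scales.
    w = w_int4.to(torch.float16) * weight_scales.repeat_interleave(group_size, dim=1)
    # 4) Divide by the per-tensor weight_scale_2 so the weight fits the FP8 range.
    w_fp8 = (w / weight_scale_2).to(torch.float8_e4m3fn)
    # 5) GEMM "in FP8" (emulated in fp32 here), then rescale the output by alpha.
    alpha = input_scale * weight_scale_2
    out = x_fp8.to(torch.float32) @ w_fp8.to(torch.float32).t()
    return (out * alpha).to(torch.float16)
```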
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp (3 hunks)
- cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h (1 hunks)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (3 hunks)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (2 hunks)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (9 hunks)
- tensorrt_llm/_torch/modules/linear.py (10 hunks)
- tests/unittest/_torch/modules/test_fused_moe.py (7 hunks)
- tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py (1 hunks)
- tests/unittest/_torch/thop/test_w4a16_gemm.py (0 hunks)
- tests/unittest/_torch/thop/test_w4a16_linear.py (3 hunks)
- tests/unittest/_torch/thop/test_w4a8_linear.py (1 hunks)
💤 Files with no reviewable changes (1)
- tests/unittest/_torch/thop/test_w4a16_gemm.py
🧰 Additional context used
🧬 Code Graph Analysis (3)
tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py (2)
- tests/unittest/utils/util.py (2): woq_assert_near_eq (382-392), woq_groupwise_gt_matmul (395-399)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2): FinegrainedMixedDtypeGemm (678-717), finegrained_mixed_dtype_gemm (722-764)
tests/unittest/_torch/thop/test_w4a16_linear.py (1)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2): FinegrainedMixedDtypeGemm (678-717), finegrained_mixed_dtype_gemm (722-764)
tensorrt_llm/_torch/modules/linear.py (2)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (7): finegrained_mixed_dtype_gemm (722-764), _ (216-255), _ (334-342), _ (423-433), _ (605-632), _ (665-675), _ (884-971)
- tensorrt_llm/quantization/mode.py (2): is_int4_weight_only_per_group (129-130), QuantAlgo (23-44)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py
120-120: Line too long (135 > 120)
(E501)
1105-1105: Line too long (132 > 120)
(E501)
1204-1204: Line too long (143 > 120)
(E501)
1205-1205: Line too long (165 > 120)
(E501)
1251-1251: Line too long (142 > 120)
(E501)
tensorrt_llm/_torch/modules/fused_moe/quantization.py
786-786: Line too long (132 > 120)
(E501)
880-880: Line too long (132 > 120)
(E501)
🔇 Additional comments (33)
tensorrt_llm/_torch/modules/fused_moe/interface.py (1)
12-18: LGTM: Clean enum extension with proper documentation. The addition of CUSTOM_W4A8 to the MoEWeightLoadingMode enum is well implemented, with:
- Clear descriptive comments for all enum values
- Sequential value assignment (0, 1, 2)
- Specific reference to the quantization script that generates these weights
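For orientation only, a sketch of what the extended enum might look like; the member values and the elided pre-existing member are assumptions based on the review notes, and the real definition lives in tensorrt_llm/_torch/modules/fused_moe/interface.py:

```python
from enum import IntEnum

class MoEWeightLoadingMode(IntEnum):
    # ModelOpt-style checkpoint with individual per-expert w1/w2/w3 weights.
    VANILLA = 0
    # (one pre-existing member elided here; assumed value 1 in the real file)
    # Checkpoints exported by examples/quantization/quantize_mixed_precision_moe.py.
    CUSTOM_W4A8 = 2
```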
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)
57-58: LGTM: Proper import addition. The import of MoEWeightLoadingMode is correctly placed and follows the existing import structure.
tests/unittest/_torch/thop/test_w4a16_linear.py (3)
6-7: LGTM: Import updated correctly. The import change from the old W4A16GemmRunner to FinegrainedMixedDtypeGemm is appropriate for the refactored API.
20-23: LGTM: SM version check and skip message updated appropriately. The changes correctly use the new class constant and update the skip message to reflect the broader W4A16/W4A8 support.
77-85: LGTM: Operator call correctly updated for new API. The parameter mappings are correct:
- input=x.contiguous() (clear and explicit)
- scales=weight_scale.type(x.dtype) (was weight_scale)
- group_size=GROUP_SIZE (was GROUP_SIZE)
- has_zero_point=has_zero (was has_zero)
- output_dtype=x.dtype (new required parameter)
- zeros=None (unchanged)
The use of keyword arguments improves clarity and maintainability.
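For intuition about what those keyword arguments mean numerically, here is a hedged pure-PyTorch reference in the spirit of the tests' woq_groupwise_gt_matmul helper; it is a ground-truth dequantize-then-matmul, not the custom operator itself, and the assumed layouts are noted in the comments:

```python
import torch

def groupwise_gemm_reference(x, w_q, scales, group_size,
                             zeros=None, has_zero_point=False,
                             output_dtype=torch.float16, alpha=None):
    # Assumed layouts: w_q is [in_features, out_features] (already-unpacked int4
    # values in int8 storage); scales/zeros are [in_features // group_size, out_features].
    w = w_q.to(torch.float32)
    if has_zero_point and zeros is not None:
        w = w - zeros.to(torch.float32).repeat_interleave(group_size, dim=0)
    w = w * scales.to(torch.float32).repeat_interleave(group_size, dim=0)
    out = x.to(torch.float32) @ w
    if alpha is not None:  # the W4A8 AWQ path rescales the FP8 GEMM output
        out = out * alpha
    return out.to(output_dtype)
```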
tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py (4)
11-47: Comprehensive test parameterization covers key scenarios. The test parameters appropriately cover:
- Various matrix dimensions (small to large)
- Different group sizes (64, 128)
- Multiple activation dtypes (bfloat16, float16)
- All quantization flags combinations
- Both W4A16 and W4A8 AWQ modes
78-82: Conditional activation type logic is correct. The logic correctly sets activation_type to torch.float8_e4m3fn for W4A8 AWQ mode and uses the original activation_dtype for W4A16 mode.
102-114: Operator call correctly implements new API. The finegrained_mixed_dtype_gemm call properly uses:
- Keyword arguments for clarity
- Conditional FP8 input conversion for W4A8 AWQ
- Explicit output_dtype parameter
- Optional alpha parameter for W4A8 AWQ
- Comprehensive parameter coverage
116-122: Reference computation and validation logic is sound. The test correctly:
- Applies FP8 alpha scaling for W4A8 AWQ mode
- Uses groupwise ground truth matrix multiplication
- Validates with appropriate tolerance (2)
tests/unittest/_torch/thop/test_w4a8_linear.py (4)
41-46: W4A8-specific parameter setup is correct. The test correctly includes W4A8-specific parameters:
- weight_scale uses torch.float16 (appropriate for W4A8)
- weight_scale_2 and input_scale for FP8 scaling
- Parameter types and device placement are consistent
49-66: Linear layer configuration and weight loading is proper. The test correctly:
- Uses the QuantAlgo.W4A8_AWQ quantization algorithm
- Loads all required W4A8 parameters, including the new scaling factors
- Clones the weight tensor to avoid modification issues
70-75: Weight preprocessing uses correct activation type. The preprocessing correctly uses torch.float8_e4m3fn as the activation type for W4A8, and the weight comparison validates that the preprocessing worked correctly.
84-98: Reference computation correctly implements W4A8 flow. The reference implementation properly:
- Applies pre-quantization scaling
- Uses static quantization to FP8 E4M3
- Computes the combined alpha scaling factor
- Calls the new GEMM operator with correct parameters
- Adjusts weight scales by dividing by weight_scale_2
cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h (1)
27-44: LGTM! Well-structured API extension. The class renaming and API extensions are appropriate:
- The new name finegrainedMixedDtypeGemmRunner better reflects the mixed-dtype capabilities
- Adding the outputDtype parameter provides flexibility for different output precisions
- The alpha parameter with a default value maintains backward compatibility
tests/unittest/_torch/modules/test_fused_moe.py (4)
648-649: Good test coverage for both weight loading modes. The parameterization properly tests both the ModelOpt checkpoint format (VANILLA) and the custom W4A8 format.
667-678: Appropriate handling of weight scale key differences. The lookup table correctly maps weight_scale_inv for CUSTOM_W4A8 mode, reflecting the different weight scale representations between the two formats.
770-776: Correct unpacking logic for different packing strategies. The unpacking function properly handles the different packing dimensions:
- VANILLA: transposed unpacking for output dimension packing
- CUSTOM_W4A8: direct unpacking and transpose for input dimension packing
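A compact sketch of the two layouts, assuming two signed 4-bit values per int8 byte (the real code relies on the TRT-LLM packer/unpacker ops, and the nibble order here is illustrative):

```python
import torch

def unpack_int4_pairs(packed: torch.Tensor) -> torch.Tensor:
    """Split each int8 byte into two signed 4-bit values along the last dim."""
    lo = (packed.to(torch.int8) << 4) >> 4   # low nibble, sign-extended
    hi = packed.to(torch.int8) >> 4          # high nibble
    return torch.stack((lo, hi), dim=-1).flatten(-2)

def unpack_weights(w_packed: torch.Tensor, vanilla: bool) -> torch.Tensor:
    if vanilla:
        # VANILLA (ModelOpt): pairs packed along the output dimension -> transpose, then unpack.
        return unpack_int4_pairs(w_packed.t().contiguous())
    # CUSTOM_W4A8: pairs packed along the input dimension -> unpack, then transpose.
    return unpack_int4_pairs(w_packed).t().contiguous()
```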
878-883: Comprehensive validation checks. Good additions:
- NaN checks ensure numerical stability
- Non-zero output validation confirms the computation produces meaningful results
- Debug prints help troubleshooting
cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp (4)
47-143: Well-implemented constructor with proper dtype validation. The constructor correctly:
- Validates activation/output dtype combinations
- Supports Float8_e4m3fn activation with Half/BFloat16 outputs
- Enforces matching dtypes for Half/BFloat16 activations
- Provides clear error messages for unsupported configurations
216-228: Correct output dtype handling. The output tensor creation properly uses mOutputDtype instead of the activation dtype, enabling mixed-precision GEMM operations.
260-261: Appropriate alpha parameter handling. The alpha value is correctly converted from double to float and passed to the underlying GEMM runner.
274-280: Correct Torch library registration. The registration properly exposes the renamed class with its updated constructor signature and methods.
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)
678-717: Well-implemented Python wrapper for the C++ runner. The class correctly:
- Uses appropriate instance key for caching runners
- Wraps the renamed C++ class
- Handles optional alpha parameter with proper default value
720-764: Clean operator implementation with proper validation. The new operator:
- Has a more descriptive name reflecting its mixed-dtype capabilities
- Properly validates zeros tensor presence when required
- Maintains AutoTuner integration for performance optimization
- Correctly handles all parameters including the new alpha value
tensorrt_llm/_torch/modules/linear.py (5)
50-55: LGTM! Clean implementation of tensor parallel mode toggling. The flip method correctly switches between ROW and COLUMN parallel modes, which is necessary for sharding activation scales along the dimension orthogonal to weight sharding.
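A minimal sketch of such a toggle (the actual TensorParallelMode class in tensorrt_llm/_torch/modules/linear.py may differ in member values and signature):

```python
from enum import Enum
from typing import Optional

class TensorParallelMode(str, Enum):
    ROW = "row"
    COLUMN = "column"

    @classmethod
    def flip(cls, mode: Optional["TensorParallelMode"]) -> Optional["TensorParallelMode"]:
        # Activation-side tensors (e.g. pre_quant_scale) are sharded along the
        # dimension orthogonal to the weight sharding, hence the flip.
        if mode is None:
            return None
        return cls.COLUMN if mode == cls.ROW else cls.ROW
```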
119-126: LGTM! Correct handling of both W4A16 and W4A8 quantization methods. The code properly selects the activation dtype based on the quantization method (FP16 for W4A16, FP8 for W4A8) and applies the necessary preprocessing for mixed-precision GEMM.
902-903: LGTM! Correct weight scale tensor layout and improved GEMM operator usage. The changes properly transpose weight scale tensors to match the expected layout and use the new finegrained_mixed_dtype_gemm operator with explicit named parameters. The use of TensorParallelMode.flip for pre_quant_scale sharding is appropriate since it applies to activations.
Also applies to: 918-931, 962-982, 998-1022
1276-1279: LGTM! Correct extension of quantization method selection. The function properly returns W4A8_AWQ_LinearMethod when the quantization mode is int4 weight-only per-group and the algorithm is W4A8_AWQ.
1403-1408: LGTM! Consistent property implementation. The has_w4a8_awq property correctly identifies W4A8 AWQ quantization configuration, maintaining consistency with the existing has_w4a16_awq property.
tensorrt_llm/_torch/modules/fused_moe/quantization.py (4)
99-102: LGTM! Correct handling of CUSTOM_W4A8 weight loading mode. The CUSTOM_W4A8 mode appropriately uses the same weight key structure as VANILLA mode with individual w1, w2, and w3 weights.
583-584: LGTM! Improved documentation of scaling factors. The updated comments clearly explain the purpose of fc31_act_scale (reciprocal of per-channel pre_quant_scale * per-tensor input_scale) and alpha (weight_scale_2 * input_scale for rescaling the GEMM output).
Also applies to: 614-615
679-710: LGTM! Clean implementation of architecture-specific weight loading. The code correctly handles different weight formats:
- SM89: Uses existing preprocessing logic
- SM90 VANILLA: Unpacks int4 weights, transposes, and repacks for the expected layout
- SM90 CUSTOM_W4A8: No-op, indicating weights are pre-formatted correctly
The extraction of packer/unpacker operations improves code clarity.
Also applies to: 725-755
759-929: LGTM! Comprehensive handling of different checkpoint formats. The implementation correctly handles two checkpoint formats:
Custom W4A8 format:
- Uses "weight_scale_inv" parameter name
- Fuses per-tensor input_scale with per-channel pre_quant_scale
- Fuses weight_scale_2 into alpha
Vanilla (ModelOpt) format:
- Uses "weight_scale" parameter name
- Keeps scales separate and combines them during loading
- Divides weight scales by weight_scale_2_max before storage
The scale computations are mathematically correct for both formats.
Force-pushed from e080b5c to 636b31e
Actionable comments posted: 0
🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (3)
583-583: Improve comment clarity for activation scale parameter. The comment could be more precise about the mathematical operation being performed.
- # Multiply act with reciprocal of per-channel pre_quant_scale * per-tensor input_scale
+ # Scale activation by reciprocal of (per-channel pre_quant_scale * per-tensor input_scale)
614-614: Improve comment clarity for weight scale parameter. Similar to the previous comment, this could be more mathematically precise.
- # Multiply W@X with per-tensor weight_scale_2 * per-tensor input_scale.
+ # Scale matrix multiplication result by (per-tensor weight_scale_2 * per-tensor input_scale)
786-786: Fix line length violations. Two lines exceed the 120-character limit as flagged by static analysis.
- # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
+ # In vanilla ckpt (ModelOpt format), per-tensor input_scale and
+ # per-channel pre_quant_scale are separately stored
Also applies to: 880-880
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (2 hunks)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (9 hunks)
- tests/unittest/_torch/modules/test_fused_moe.py (7 hunks)
✅ Files skipped from review due to trivial changes (1)
- tensorrt_llm/_torch/modules/fused_moe/interface.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/quantization.py
786-786: Line too long (132 > 120)
(E501)
880-880: Line too long (132 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (18)
tests/unittest/_torch/modules/test_fused_moe.py (9)
24-24: LGTM: Import addition is appropriate. The MoEWeightLoadingMode import aligns with the new parameterized test functionality.
648-649: LGTM: Test parameterization correctly implemented. The test function signature update and parameterization with both VANILLA and CUSTOM_W4A8 modes provides comprehensive coverage for the new weight loading functionality.
667-677: LGTM: Dynamic weight attribute lookup table is well-designed. The lookup table approach elegantly handles the different weight scale attribute names between the two modes, making the test code more maintainable.
681-692: LGTM: Weight shape logic correctly differentiates packing strategies. The conditional weight shapes properly reflect the different packing strategies:
- VANILLA: packs 4-bit weights along the output dimension
- CUSTOM_W4A8: packs 4-bit weights along the input dimension
This matches the implementation differences described in the PR summary.
694-697: LGTM: Weight initialization uses appropriate data type and range. Using torch.int8 with range [-128, 127] for the quantized weights is correct for 4-bit packed representations.
752-753: LGTM: Model configuration correctly includes weight loading mode. The weight_loading_mode parameter is properly passed to the CutlassFusedMoE constructor, ensuring the test validates the actual runtime configuration.
770-778: LGTM: Weight unpacking logic correctly handles both modes. The unpack_weights helper function properly handles the different tensor layouts:
- VANILLA: transpose then unpack
- CUSTOM_W4A8: unpack then transpose
This correctly inverts the packing strategies used during weight preparation.
810-839: LGTM: Quantization process helper function is well-structured. The process_layer function encapsulates the complex quantization logic clearly, with proper handling of optional scale parameters that are only used in VANILLA mode.
874-883: LGTM: Enhanced output validation with comprehensive checks. The additional validation checks are valuable:
- NaN detection in both outputs
- Non-empty result verification
- Debug output for troubleshooting
This significantly improves test reliability and debugging capabilities.
tensorrt_llm/_torch/modules/fused_moe/quantization.py (9)
99-102: LGTM: Conditional check correctly includes new weight loading mode. The addition of MoEWeightLoadingMode.CUSTOM_W4A8 to the existing conditional maintains backward compatibility while enabling the new functionality.
679-681: LGTM: Pack/unpack operations are properly defined. The tensor pack/unpack operations are correctly assigned and will be used consistently throughout the weight loading process.
694-709: LGTM: SM90 weight preprocessing logic is well-structured. The conditional logic properly handles the different weight formats:
- VANILLA: ModelOpt format requiring a transpose/unpack/transpose/pack sequence
- CUSTOM_W4A8: custom format requiring no preprocessing
The comments clearly explain the tensor shape transformations.
739-754: LGTM: Consistent weight preprocessing for w2 weights. The w2 weight preprocessing follows the same pattern as the w3_w1 weights, maintaining consistency in the implementation.
759-764: LGTM: Scale naming abstraction is clean. The dynamic scale name selection based on weight loading mode creates a clean abstraction that reduces code duplication in the scale loading logic.
850-851: LGTM: Scale division for VANILLA mode is properly conditioned. The weight scale division by weight_scale_2 is correctly applied only in VANILLA mode, maintaining the different scaling semantics between the two checkpoint formats.
871-902: LGTM: FC2 scale computation mirrors FC31 logic. The fc2 scale computation correctly mirrors the fc31 (w3_w1) scale computation pattern, with appropriate handling of per-channel vs per-tensor scales for both modes.
918-919: LGTM: Consistent scale division pattern for w2 weights. The w2 scale division follows the same pattern as the w3_w1 scales, maintaining consistency in the VANILLA mode scale handling.
778-827: Verified VANILLA mode scale computation is correct. I've compared the code in quantization.py against the reference logic in tests/unittest/_torch/modules/test_fused_moe.py. In both implementations you:
- Compute pre_quant_scale_max = max(a1, a3) and input_scale_max = max(p3, p1).
- Compute weight_scale_2_max = max(q3, q1).
- In production code:
  fc31_act_scale = pre_quant_scale_max * (1 / input_scale_max)
  fc31_alpha = weight_scale_2_max * input_scale_max
- In the test's process_layer, this corresponds exactly to multiplying activations by pre_quant_scale then dividing by input_scale, and later multiplying outputs by input_scale and weight_scale_2.
No discrepancies found; no changes required here.
Force-pushed from 636b31e to 42da122
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/unittest/_torch/modules/test_fused_moe.py (2)
810-839: Consider cleaning up debug prints. The process_layer function correctly handles the quantization logic with conditional scaling. However, consider removing the commented debug print statements for production code.
- # print(f"DEBUG: {layer_name}.max() {abs(act.max())}")
- # print(f"DEBUG: {layer_name}_pre_quant.max() {abs(act.max())}")
- # print(f"DEBUG: {layer_name}_quantized.max() {abs(act.max())}")
- # print(f"DEBUG: {layer_name}_output_input_scale.max() {abs(output.max())}")
- # print(f"DEBUG: {layer_name}_output_weight_scale_2.max() {abs(output.max())}")
874-883: Improve validation approach and consider output verbosity. The enhanced validation with NaN checks and non-empty assertions is valuable, but consider these improvements:
- The debug prints will clutter test output; consider conditional printing based on test failure
- Consider keeping torch.testing.assert_close as the primary assertion method for better error messages
- success = torch.allclose(output, ref_output, rtol=1e-2, atol=0.1)
- print(f"ref_output: {ref_output}")
- print(f"output: {output}")
- # assert that result does not contain NaN
  ref_has_nan = torch.isnan(ref_output).any()
  out_has_nan = torch.isnan(output).any()
  assert not ref_has_nan, "ref_output contains NaN"
  assert not out_has_nan, "output contains NaN"
  assert torch.nonzero(output).numel() != 0 and torch.nonzero(ref_output).numel() != 0
- assert success
+
+ # Use torch.testing.assert_close for better error messages
+ try:
+     torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
+ except AssertionError:
+     print(f"ref_output: {ref_output}")
+     print(f"output: {output}")
+     raise
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (2 hunks)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (9 hunks)
- tests/unittest/_torch/modules/test_fused_moe.py (7 hunks)
✅ Files skipped from review due to trivial changes (1)
- tensorrt_llm/_torch/modules/fused_moe/interface.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/quantization.py
786-786: Line too long (132 > 120)
(E501)
880-880: Line too long (132 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (17)
tests/unittest/_torch/modules/test_fused_moe.py (7)
24-24: LGTM: Import addition is correct. The MoEWeightLoadingMode import is necessary for the new parameterized test functionality.
648-649: LGTM: Parametrization correctly adds support for both weight loading modes. The test now properly covers both VANILLA and CUSTOM_W4A8 modes as intended for the ModelOpt checkpoint format support.
667-677: LGTM: Lookup table provides clean abstraction for mode-specific keys. The dynamic key selection based on weight_loading_mode is well implemented and maintains code readability.
694-743: LGTM: Weight initialization correctly handles both modes. The weight initialization logic properly uses the lookup table and creates appropriately shaped tensors for each mode. The code is clean and maintainable.
752-753: LGTM: Constructor parameter addition is correct. The weight_loading_mode parameter is properly passed to the CutlassFusedMoE constructor.
681-692: Weight shape logic consistent with quantization implementation. All inspected shape definitions in tests/unittest/_torch/modules/test_fused_moe.py align with the packing logic in tensorrt_llm/_torch/modules/fused_moe/quantization.py for both VANILLA and CUSTOM_W4A8 modes. No discrepancies found; no further changes needed.
770-776: Unpacking logic verified. The unpack_weights helper in the test correctly mirrors the operator usage in the quantization module:
- VANILLA: both the implementation (SM90 branch) and test perform a transpose before unpacking.
- CUSTOM_W4A8: both the SM89 branch and test unpack first and then transpose.
No changes needed here.
tensorrt_llm/_torch/modules/fused_moe/quantization.py (10)
99-102: LGTM: Correctly extends weight loading mode support. The addition of MoEWeightLoadingMode.CUSTOM_W4A8 to the conditional check properly extends support while maintaining existing functionality.
583-593: LGTM: New activation scale parameters are well-defined. The fc31_act_scale and fc2_act_scale parameters are properly registered with appropriate shapes and dtypes. The comments clearly explain their purpose.
614-625: LGTM: Alpha parameters are correctly defined. The fc31_alpha and fc2_alpha parameters are properly registered with appropriate dtypes and clear documentation of their purpose.
738-754: LGTM: Consistent weight loading logic for w2 weights. The load_expert_w2_weight method correctly implements the same processing pattern as the w3_w1_weight method, ensuring consistency across weight types.
759-764: LGTM: Clean approach to handle different scale key names. The conditional selection of weight_scale_name based on the loading mode provides a clean abstraction for the different key naming conventions.
871-902: LGTM: Consistent fc2 scale handling. The fc2 scale loading follows the same pattern as fc31, maintaining consistency in the implementation approach across different layer types.
850-851: LGTM: Conditional weight scale normalization is correct. The weight scale normalization is correctly applied only in VANILLA mode, consistent with the different scale handling approaches between modes.
918-919: LGTM: Consistent normalization logic for fc2 weights. The fc2 weight scale normalization follows the same conditional pattern, maintaining consistency with the fc31 handling.
693-709: Verify SM90 weight processing sequence. The SM90 branches for VANILLA (transpose→unpack→transpose→pack) and CUSTOM_W4A8 (pass-through) appear consistent with the ModelOpt quantization strategy, but we currently have no automated validation for SM90. Please:
- Confirm that, for both w2 and w31 shards in tensorrt_llm/_torch/modules/fused_moe/quantization.py (lines ~693-709), the sequence transpose → unpack → transpose → pack yields the correct final layout and dtype.
- Add unit tests in tests/unittest/_torch/modules/test_fused_moe.py covering:
  - SM90 + MoEWeightLoadingMode.VANILLA produces the expected shapes/data.
  - SM90 + MoEWeightLoadingMode.CUSTOM_W4A8 retains the original shard unchanged.
778-827: Scale computation logic confirmed correct. After cross-referencing the fc31 logic with the patterns in linear.py, quantization_utils.py, and the unit tests in test_fused_moe.py, the two branches correctly implement:
- CUSTOM_W4A8:
  • fc31_act_scale = 1 / input_scale_max
  • fc31_alpha = input_scale_max
- VANILLA:
  • fc31_act_scale = pre_quant_scale_max / input_scale_max
  • fc31_alpha = weight_scale_2_max * input_scale_max
These match the intended quantization formulas (act_scale = pre_quant_scale · input_scale⁻¹ and alpha = input_scale · weight_scale_2) used elsewhere in the codebase and tests. No changes required.
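Condensed into a small sketch (variable names follow the review; the *_max tensors are the per-expert maxima described above, and this is a restatement of the verified formulas rather than new logic):

```python
def compute_fc31_scales(mode, input_scale_max,
                        pre_quant_scale_max=None, weight_scale_2_max=None):
    if mode == "CUSTOM_W4A8":
        # pre_quant_scale and weight_scale_2 are already fused into the checkpoint.
        fc31_act_scale = 1.0 / input_scale_max
        fc31_alpha = input_scale_max
    else:  # VANILLA (ModelOpt): scales stored separately, combined at load time.
        fc31_act_scale = pre_quant_scale_max / input_scale_max
        fc31_alpha = weight_scale_2_max * input_scale_max
    return fc31_act_scale, fc31_alpha
```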
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/unittest/_torch/modules/test_fused_moe.py (1)
874-883: Enhanced output validation with comprehensive checks. The validation improvements are excellent:
- Manual allclose check allows for additional custom validations
- NaN checks help detect numerical issues in the quantization pipeline
- Non-empty assertions ensure meaningful output data
- Comprehensive error detection for debugging
Consider making the debug prints conditional to avoid cluttering test output:
- print(f"ref_output: {ref_output}")
- print(f"output: {output}")
+ if not success:
+     print(f"ref_output: {ref_output}")
+     print(f"output: {output}")
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tests/unittest/_torch/modules/test_fused_moe.py (7 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (8)
tests/unittest/_torch/modules/test_fused_moe.py (8)
24-24: LGTM: Import addition supports new test functionality. The MoEWeightLoadingMode import is correctly added to support the new test parameterization for different weight loading modes.
648-649: LGTM: Test parameterization correctly covers both weight loading modes. The addition of the weight_loading_mode parameter ensures comprehensive testing of both VANILLA and CUSTOM_W4A8 modes, aligning with the PR's goal to support the ModelOpt checkpoint format.
667-677: LGTM: Clean lookup table implementation for mode-specific weight keys. The LUT approach elegantly handles the different naming conventions between VANILLA and CUSTOM_W4A8 modes, particularly for the weight scale keys where CUSTOM_W4A8 uses "weight_scale_inv" instead of "weight_scale".
681-692: LGTM: Correct implementation of mode-specific weight packing strategies. The weight shape logic correctly handles the different packing approaches:
- VANILLA mode packs 4-bit weight pairs along the output dimension
- CUSTOM_W4A8 mode packs 4-bit weight pairs along the input dimension
This properly reflects the differences between ModelOpt and custom quantization script formats.
693-726: LGTM: Proper initialization of quantized weights and scaling tensors. The weight initialization correctly:
- Uses int8 tensors for quantized weights with valid range [-128, 127]
- Introduces necessary scaling tensors (pre_quant_scale, weight_scale_2, input_scale)
- Sets appropriate data types and devices for all tensors
- Uses reasonable scale ranges for testing purposes
727-742: LGTM: Weight dictionary correctly populated using LUT. The weight dictionary setup properly uses the lookup table to ensure correct attribute names for each weight loading mode, maintaining compatibility with both VANILLA and CUSTOM_W4A8 formats.
752-753: LGTM: Model configuration correctly passes weight loading mode. The addition of the weight_loading_mode parameter to the CutlassFusedMoE constructor ensures the model uses the appropriate weight loading and processing logic for the selected mode.
770-860: LGTM: Comprehensive reference implementation supporting both weight loading modes. The reference implementation correctly handles both modes:
- unpack_weights properly unpacks tensors based on the packing strategy
- process_layer encapsulates the quantization pipeline with conditional scaling
- Pre-quantization and weight_scale_2 are correctly applied only in VANILLA mode
- The forward pass logic accurately reflects the expected behavior for both modes
Force-pushed from 0872046 to d8244f8
/bot run
PR_Github #12386 [ run ] triggered by Bot
Force-pushed from d8244f8 to e1cc336
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/unittest/_torch/modules/test_fused_moe.py (1)
774-899: Comprehensive reference implementation with room for improvement. The reference implementation correctly handles both weight loading modes with appropriate unpacking strategies and conditional scale applications. The enhanced validation logic provides good debugging capabilities.
Minor suggestions for improvement:
- print(f"ref_output: {ref_output}")
- print(f"output: {output}")
+ # Debug prints can be enabled when needed
+ # print(f"ref_output: {ref_output}")
+ # print(f"output: {output}")
Consider simplifying the assertion logic:
- success = torch.allclose(output, ref_output, rtol=1e-2, atol=0.1)
- # assert that result does not contain NaN
- ref_has_nan = torch.isnan(ref_output).any()
- out_has_nan = torch.isnan(output).any()
- assert not ref_has_nan, "ref_output contains NaN"
- assert not out_has_nan, "output contains NaN"
- assert torch.nonzero(output).numel() != 0 and torch.nonzero(
- ref_output).numel() != 0
- assert success
+ # Comprehensive validation
+ assert not torch.isnan(ref_output).any(), "ref_output contains NaN"
+ assert not torch.isnan(output).any(), "output contains NaN"
+ assert torch.nonzero(output).numel() > 0, "output is empty"
+ assert torch.nonzero(ref_output).numel() > 0, "ref_output is empty"
+ torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tests/unittest/_torch/modules/test_fused_moe.py (7 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (6)
tests/unittest/_torch/modules/test_fused_moe.py (6)
21-23: LGTM! Import addition is correct. The import of MoEWeightLoadingMode is properly added and necessary for the new test parameter functionality.
645-648: Good test coverage enhancement. The parameterization of weight_loading_mode with both VANILLA and CUSTOM_W4A8 modes ensures comprehensive testing coverage for the new weight loading functionality.
666-678: Clean approach to handle different weight attribute naming. The lookup table pattern effectively abstracts the differences in weight attribute names between the two loading modes, making the code more maintainable.
682-693: Correct implementation of different weight packing strategies. The weight shape handling correctly accounts for the different packing approaches:
- VANILLA mode: Output dimension packing (ModelOpt W4A8)
- CUSTOM_W4A8 mode: Input dimension packing (custom quantization script)
The clear comments explain the rationale behind each approach.
694-746: Proper weight initialization with lookup table usage. The weight initialization correctly:
- Uses the lookup table for consistent attribute naming across modes
- Includes all necessary tensors (pre_quant_scale, weight_scale_2) for comprehensive testing
- Maintains proper tensor shapes and device placement
756-757: Correct parameter propagation to model. The weight_loading_mode parameter is properly passed to the CutlassFusedMoE constructor, ensuring the test configuration matches the intended behavior.
PR_Github #12386 [ run ] completed with state
/bot run
PR_Github #12395 [ run ] triggered by Bot
Force-pushed from ebbbb6f to e02c34c
Actionable comments posted: 0
🧹 Nitpick comments (3)
tests/unittest/_torch/modules/test_fused_moe.py (1)
891-899: Consider removing debug prints for production. The defensive checks for NaN values and empty tensors are excellent additions that prevent silent failures. However, consider removing the debug print statements (lines 891-892) before merging to avoid cluttering the test output.
- print(f"ref_output: {ref_output}")
- print(f"output: {output}")
tensorrt_llm/_torch/modules/fused_moe/quantization.py (2)
778-828: Address line length and verify scale computation logic. The fc31 activation scale computation logic correctly handles the differences between CUSTOM_W4A8 and VANILLA modes:
- CUSTOM_W4A8: Fused scales, simpler computation
- VANILLA: Separate scales requiring more complex handling
However, there are formatting issues to address.
Apply this diff to fix line length issues:
- # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
+ # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and
+ # per-channel pre_quant_scale are separately stored
871-903: Address line length and verify fc2 scale logic. The fc2 activation scale computation follows the same pattern as fc31, correctly handling the mode differences. The logic appears sound but has formatting issues.
Apply this diff to fix line length issues:
- # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
+ # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and
+ # per-channel pre_quant_scale are separately stored
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (2 hunks)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (1 hunks)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (9 hunks)
- tensorrt_llm/_torch/modules/linear.py (4 hunks)
- tests/unittest/_torch/modules/test_fused_moe.py (7 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tensorrt_llm/_torch/modules/fused_moe/interface.py
- tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/quantization.py
786-786: Line too long (132 > 120)
(E501)
880-880: Line too long (132 > 120)
(E501)
🔇 Additional comments (18)
tensorrt_llm/_torch/modules/linear.py (4)
964-966: LGTM! Improved device handling for multi-GPU setups. The dynamic device assignment using module.weight.device instead of a hardcoded CUDA device is a good improvement. The comments clearly explain the rationale for this change.
1146-1148: LGTM! Consistent device handling improvement. The dynamic device assignment is consistent with the W4A16 implementation and properly handles multi-GPU scenarios.
1211-1213: LGTM! Consistent dynamic device usage. The device assignment follows the same pattern as other methods, and line 1220 correctly uses the dynamic device variable instead of a hardcoded CUDA device.
Also applies to: 1220-1220
1255-1257: LGTM! Completes the consistent device handling pattern. The final set of changes maintains consistency with the other methods in the class, properly using dynamic device detection throughout the W4A8_AWQ_LinearMethod class.
Also applies to: 1264-1264
tests/unittest/_torch/modules/test_fused_moe.py (6)
21-26: LGTM! Proper import addition. The addition of the MoEWeightLoadingMode import is necessary for the new test functionality, and the isort comments indicate proper import organization.
648-651: LGTM! Comprehensive test coverage for both weight loading modes. The parametrization properly tests both VANILLA and CUSTOM_W4A8 modes, ensuring comprehensive coverage of the new functionality.
669-681: LGTM! Clean lookup table implementation. The lookup table provides a maintainable way to handle different weight attribute names between the two modes, making the test code more organized and readable.
684-696: LGTM! Correct weight shape handling for different modes. The conditional logic properly handles the different weight packing strategies:
- VANILLA mode packs pairs of 4-bit weights in the output dimension
- CUSTOM_W4A8 mode packs in the input dimension
The comments clearly explain the rationale for each approach.
697-749: LGTM! Proper weight initialization for W4A8 quantization. The weight initialization correctly:
- Uses torch.randint for quantized int8 weights
- Adds all necessary scale tensors (pre_quant_scale, input_scale, weight_scale_2)
- Uses the lookup table for consistent key naming
- Provides proper tensor shapes and data types for W4A8 quantization
759-760: LGTM! Comprehensive reference implementation for both modes. The reference implementation correctly:
- Passes weight_loading_mode to the model
- Handles different weight unpacking strategies via unpack_weights
- Conditionally applies pre_quant_scale and weight_scale_2 only for VANILLA mode
- Encapsulates complex quantization logic in the process_layer helper function
The implementation properly simulates the quantization process for both modes, providing a solid reference for comparison.
Also applies to: 777-876
tensorrt_llm/_torch/modules/fused_moe/quantization.py (8)
99-102: LGTM: Properly extends weight loading mode support. The addition of CUSTOM_W4A8 to the conditional check correctly extends support for the new weight loading mode while maintaining existing functionality.
583-594: LGTM: Appropriate parameter registration for W4A8 quantization. The new activation scale parameters are correctly registered with proper shapes and data types for the W4A8 quantization method. The comments clearly explain their purpose.
614-626: LGTM: Alpha parameters properly configured. The alpha parameters are correctly registered with appropriate shapes and dtype (float32) for storing scaling factors in the W4A8 quantization pipeline.
725-755: Consistent weight preprocessing for w2 weights. The w2 weight preprocessing logic mirrors the w3_w1 implementation with appropriate handling for different SM versions and weight loading modes. The structure is consistent and correct.
759-764: LGTM: Clean conditional weight scale naming. The conditional assignment of weight_scale_name based on the weight loading mode is clean and follows the expected naming convention differences between VANILLA and CUSTOM_W4A8 modes.
918-920: Consistent normalization for fc2 weights. The fc2 weight scale normalization follows the same pattern as fc31, only applied in VANILLA mode. This consistency is good and the logic appears correct.
679-710: Verify CUSTOM_W4A8 weight preprocessing. I couldn't find any documentation or code that prepares weights for the CUSTOM_W4A8 mode. Please confirm that:
- Weights loaded under MoEWeightLoadingMode.CUSTOM_W4A8 are already in the expected packed int4 format before this method.
- If they aren't, add the equivalent pack/unpack preprocessing used for the other modes.
850-852: Division for VANILLA mode is mathematically sound.
- all_w3_w1_weight_scale_2_max is computed (lines 821-826) as the element-wise maximum across all per-expert weight_scale_2 shards.
- In VANILLA mode, dividing each per-group weight scale by this tensor mirrors the pattern in the core Linear modules (weight_scale / weight_scale_2) and the fp4_global_scale utility.
- Unit tests in tests/unittest/_torch/modules/test_fused_moe.py (lines 821-823) explicitly build q3_q1 = torch.max(q3, q1) and verify that weight /= weight_scale_2, confirming correct behavior.
No changes required.
PR_Github #12395 [ run ] completed with state
LGTM. Left some minor comments. Thanks for the contribution!
Force-pushed from e02c34c to 44fed18
PR_Github #14762 [ run ] triggered by Bot
PR_Github #14752 [ run ] completed with state
PR_Github #14762 [ run ] completed with state
Tests passed except for 1 unrelated (probably flaky) test.
/bot run
PR_Github #14802 [ run ] triggered by Bot
PR_Github #14802 [ run ] completed with state
Pinging @NVIDIA/trt-llm-torch-modules and affected file code owner @hlu1 for review
Signed-off-by: Anthony Chang <[email protected]>
fix activation pre_quant_scale sharding in w4a8/w4a16 linear
Signed-off-by: Anthony Chang <[email protected]>
address feedback; tidy
Signed-off-by: Anthony Chang <[email protected]>
Signed-off-by: Anthony Chang <[email protected]>
Signed-off-by: Anthony Chang <[email protected]>
Signed-off-by: Anthony Chang <[email protected]>
Signed-off-by: Anthony Chang <[email protected]>
Force-pushed from 3724212 to 3b2e9bc
/bot run
PR_Github #14977 [ run ] triggered by Bot
PR_Github #14977 [ run ] completed with state
After applying these changes, DeepSeek-R1-W4AFP8 fails to initialize during runtime.
Hi @Nekofish-L, thanks for reporting it. Could you try the fix #7123 and see if it works, while we try to get the PR merged?
…rom #6200 (#7123) Signed-off-by: Anthony Chang <[email protected]>
…rom NVIDIA#6200 (NVIDIA#7123) Signed-off-by: Anthony Chang <[email protected]>
Description
Hopper W4A8 MoE supports ModelOpt ckpt for PyT backend
The existing SM90 W4A8 MoE supports a bespoke checkpoint format exported by examples/quantization/quantize_mixed_precision_moe.py. This PR adds support for taking checkpoints exported by ModelOpt.
This PR depends on #6005 to work on both the W4A8 dense layer and the W4A8 MoE layer.
Test Coverage
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
- --reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
- --disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.
- --skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- --post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
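For example, a typical invocation combining a few of the optional flags above might look like the following; the stage and GPU names are the placeholder examples from the flag descriptions, not a recommendation:
/bot run --disable-fail-fast --stage-list "A10-PyTorch-1" --gpu-type "H100_PCIe" --test-backend "pytorch"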
kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for latest commit on pull request.
--comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
Summary by CodeRabbit