
Conversation

rosenrodt
Collaborator

@rosenrodt rosenrodt commented Jul 20, 2025

Description

Hopper W4A8 MoE supports ModelOpt ckpt for PyT backend

The existing SM90 W4A8 MoE supports a bespoke checkpoint format exported by examples/quantization/quantize_mixed_precision_moe.py.

This PR adds support for loading checkpoints exported by ModelOpt.

This PR depends on #6005 to work for both the W4A8 dense layer and the W4A8 MoE layer.

Test Coverage

pytest tests/unittest/_torch/modules/test_fused_moe.py -k w4afp8 -s

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only [pytorch, cpp, tensorrt, triton] are supported. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since a lack of user care and validation can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since a lack of user care and validation can break the top of tree.

Summary by CodeRabbit

  • New Features
    • Added support for a new weight loading mode ("CUSTOM_W4A8") for quantized models, enhancing flexibility in model deployment.
  • Bug Fixes
    • Improved handling and validation of quantized weights and scales for different hardware and loading modes.
    • Ensured correct device placement for loaded tensors to support multi-GPU and distributed environments.
  • Tests
    • Expanded and refined test coverage to ensure correct behavior for both "VANILLA" and "CUSTOM_W4A8" weight loading modes.

@rosenrodt rosenrodt requested a review from a team as a code owner July 20, 2025 07:46
Contributor

coderabbitai bot commented Jul 20, 2025

"""

Walkthrough

Support for a new weight loading mode, CUSTOM_W4A8, was added across the MoE quantization pipeline. This includes extending the MoEWeightLoadingMode enum, updating weight loading and scale computation logic for SM90 hardware, and modifying tests to handle both VANILLA and CUSTOM_W4A8 modes with appropriate weight shapes, keys, and reference computations. The Deepseekv3MoE constructor now conditionally sets this mode based on quantization configuration. Additionally, device assignment for loaded tensors was made dynamic based on module weight device.
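
As a rough illustration of the conditional described above, the mode selection might look like the following sketch (the predicate and attribute names are placeholders, not the actual modeling_deepseekv3.py code):

```python
from tensorrt_llm._torch.modules.fused_moe.interface import MoEWeightLoadingMode


def select_weight_loading_mode(quant_config) -> MoEWeightLoadingMode:
    # Placeholder predicate: the real constructor keys off the W4A8 MoE
    # quantization configuration; the exact attribute checked is an assumption.
    if getattr(quant_config, "is_custom_w4a8_checkpoint", False):
        return MoEWeightLoadingMode.CUSTOM_W4A8
    return MoEWeightLoadingMode.VANILLA
```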

Changes

File(s) and change summary:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py: Extended import; updated Deepseekv3MoE constructor to pass weight_loading_mode based on quantization mode.
  • tensorrt_llm/_torch/modules/fused_moe/interface.py: Added CUSTOM_W4A8 member to MoEWeightLoadingMode enum; added descriptive comments to enum members.
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py: Added handling for CUSTOM_W4A8 mode in weight loading and scale computation; registered new scale/alpha params; extended SM90 support.
  • tests/unittest/_torch/modules/test_fused_moe.py: Updated test to accept weight_loading_mode parameter; adjusted weight shapes, scale keys, and reference logic for both modes; enhanced output validation.
  • tensorrt_llm/_torch/modules/linear.py: Changed hardcoded CUDA device assignments to dynamic device retrieval from module weight tensor for loading scales and weights.

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test Runner
    participant Deepseek as Deepseekv3MoE
    participant MoE as create_moe
    participant Quant as WInt4AFP8FusedMoEMethod

    Test->>Deepseek: Initialize with quantization config
    Deepseek->>MoE: create_moe(weight_loading_mode)
    MoE->>Quant: load_expert_weights_to_dst(..., weight_loading_mode)
    Quant->>Quant: load_expert_w3_w1_weight(..., weight_loading_mode)
    Quant->>Quant: load_expert_w2_weight(..., weight_loading_mode)
    Quant->>Quant: load_quant_scales(..., weight_loading_mode)
    Test->>Test: Validate outputs for selected weight_loading_mode

Suggested labels

Community want to contribute

Suggested reviewers

  • tomeras91
  • Naveassaf

Poem

In the warren of code, a new mode appears,
CUSTOM_W4A8 hops in, greeted with cheers!
Weights now load with a clever new plan,
Scales and shapes handled—oh, what a span!
Tests leap with joy, both old and new,
Quantized carrots for the whole bunny crew!
🥕
"""


🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai resolve to resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@rosenrodt
Collaborator Author

/bot run

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/linear.py (1)

1024-1257: LGTM! Well-implemented W4A8 AWQ quantization method.

The implementation correctly follows the ModelOpt W4A8 AWQ quantization flow:

  1. Multiplies pre_quant_scale to input
  2. Quantizes input to FP8 using input_scale
  3. Unpacks weights and multiplies by weight_scales (int4 → fp16)
  4. Divides by weight_scale_2 (fp16 → fp8 for GEMM)
  5. Applies GEMM in FP8
  6. Rescales using alpha (input_scale * weight_scale_2)

The weight scale handling and alpha computation are correct.
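
A minimal PyTorch sketch of that flow, intended only as a mental model (tensor names, shapes, and the already-unpacked int4 representation are assumptions, not the actual linear.py implementation):

```python
import torch

def w4a8_awq_reference(x, w_int4, weight_scale, weight_scale_2,
                       pre_quant_scale, input_scale, group_size=128):
    """x: [m, k] activations; w_int4: [n, k] unpacked int4 values in [-8, 7];
    weight_scale: [n, k // group_size]; the remaining scales are per-tensor."""
    # 1. Multiply pre_quant_scale into the activation.
    x = x * pre_quant_scale
    # 2. Quantize the activation to FP8 using the per-tensor input_scale.
    x_fp8 = (x / input_scale).to(torch.float8_e4m3fn)
    # 3. Dequantize int4 weights with per-group weight_scale (int4 -> fp16).
    w = w_int4.to(torch.float16) * weight_scale.repeat_interleave(group_size, dim=1)
    # 4. Divide by per-tensor weight_scale_2 so the weights fit the FP8 range.
    w_fp8 = (w / weight_scale_2).to(torch.float8_e4m3fn)
    # 5. GEMM in FP8 (emulated here in fp32; the real path uses an FP8 kernel).
    y = x_fp8.to(torch.float32) @ w_fp8.to(torch.float32).t()
    # 6. Rescale the output using alpha = input_scale * weight_scale_2.
    return y * (input_scale * weight_scale_2)
```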

Consider extracting common logic between W4A16_AWQ_LinearMethod and W4A8_AWQ_LinearMethod into a base class to reduce code duplication in future refactoring efforts.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 98428f3 and e080b5c.

📒 Files selected for processing (12)
  • cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp (3 hunks)
  • cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h (1 hunks)
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (3 hunks)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (2 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/interface.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py (9 hunks)
  • tensorrt_llm/_torch/modules/linear.py (10 hunks)
  • tests/unittest/_torch/modules/test_fused_moe.py (7 hunks)
  • tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py (1 hunks)
  • tests/unittest/_torch/thop/test_w4a16_gemm.py (0 hunks)
  • tests/unittest/_torch/thop/test_w4a16_linear.py (3 hunks)
  • tests/unittest/_torch/thop/test_w4a8_linear.py (1 hunks)
💤 Files with no reviewable changes (1)
  • tests/unittest/_torch/thop/test_w4a16_gemm.py
🧰 Additional context used
🧬 Code Graph Analysis (3)
tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py (2)
tests/unittest/utils/util.py (2)
  • woq_assert_near_eq (382-392)
  • woq_groupwise_gt_matmul (395-399)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)
  • FinegrainedMixedDtypeGemm (678-717)
  • finegrained_mixed_dtype_gemm (722-764)
tests/unittest/_torch/thop/test_w4a16_linear.py (1)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)
  • FinegrainedMixedDtypeGemm (678-717)
  • finegrained_mixed_dtype_gemm (722-764)
tensorrt_llm/_torch/modules/linear.py (2)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (7)
  • finegrained_mixed_dtype_gemm (722-764)
  • _ (216-255)
  • _ (334-342)
  • _ (423-433)
  • _ (605-632)
  • _ (665-675)
  • _ (884-971)
tensorrt_llm/quantization/mode.py (2)
  • is_int4_weight_only_per_group (129-130)
  • QuantAlgo (23-44)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/linear.py

120-120: Line too long (135 > 120)

(E501)


1105-1105: Line too long (132 > 120)

(E501)


1204-1204: Line too long (143 > 120)

(E501)


1205-1205: Line too long (165 > 120)

(E501)


1251-1251: Line too long (142 > 120)

(E501)

tensorrt_llm/_torch/modules/fused_moe/quantization.py

786-786: Line too long (132 > 120)

(E501)


880-880: Line too long (132 > 120)

(E501)

🔇 Additional comments (33)
tensorrt_llm/_torch/modules/fused_moe/interface.py (1)

12-18: LGTM: Clean enum extension with proper documentation.

The addition of CUSTOM_W4A8 to the MoEWeightLoadingMode enum is well-implemented with:

  • Clear descriptive comments for all enum values
  • Sequential value assignment (0, 1, 2)
  • Specific reference to the quantization script that generates these weights
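
A hypothetical sketch of the extended enum described above; only VANILLA, CUSTOM_W4A8, and the sequential 0/1/2 values come from this review, and the middle member's name and the comments are assumptions:

```python
from enum import IntEnum

class MoEWeightLoadingMode(IntEnum):
    VANILLA = 0              # standard per-expert w1/w2/w3 checkpoints (e.g. ModelOpt)
    FUSED_GATE_UP_PROJ = 1   # assumed name: checkpoints with fused gate/up projections
    CUSTOM_W4A8 = 2          # exported by examples/quantization/quantize_mixed_precision_moe.py
```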
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

57-58: LGTM: Proper import addition.

The import of MoEWeightLoadingMode is correctly placed and follows the existing import structure.

tests/unittest/_torch/thop/test_w4a16_linear.py (3)

6-7: LGTM: Import updated correctly.

The import change from the old W4A16GemmRunner to FinegrainedMixedDtypeGemm is appropriate for the refactored API.


20-23: LGTM: SM version check and skip message updated appropriately.

The changes correctly use the new class constant and update the skip message to reflect the broader W4A16/W4A8 support.


77-85: LGTM: Operator call correctly updated for new API.

The parameter mappings are correct:

  • input=x.contiguous() (clear and explicit)
  • scales=weight_scale.type(x.dtype) (was weight_scale)
  • group_size=GROUP_SIZE (was GROUP_SIZE)
  • has_zero_point=has_zero (was has_zero)
  • output_dtype=x.dtype (new required parameter)
  • zeros=None (unchanged)

The use of keyword arguments improves clarity and maintainability.

tests/unittest/_torch/thop/test_finegrained_mixed_dtype_gemm.py (4)

11-47: Comprehensive test parameterization covers key scenarios.

The test parameters appropriately cover:

  • Various matrix dimensions (small to large)
  • Different group sizes (64, 128)
  • Multiple activation dtypes (bfloat16, float16)
  • All quantization flags combinations
  • Both W4A16 and W4A8 AWQ modes

78-82: Conditional activation type logic is correct.

The logic correctly sets activation_type to torch.float8_e4m3fn for W4A8 AWQ mode and uses the original activation_dtype for W4A16 mode.


102-114: Operator call correctly implements new API.

The finegrained_mixed_dtype_gemm call properly uses:

  • Keyword arguments for clarity
  • Conditional FP8 input conversion for W4A8 AWQ
  • Explicit output_dtype parameter
  • Optional alpha parameter for W4A8 AWQ
  • Comprehensive parameter coverage

116-122: Reference computation and validation logic is sound.

The test correctly:

  • Applies FP8 alpha scaling for W4A8 AWQ mode
  • Uses groupwise ground truth matrix multiplication
  • Validates with appropriate tolerance (2)
tests/unittest/_torch/thop/test_w4a8_linear.py (4)

41-46: W4A8-specific parameter setup is correct.

The test correctly includes W4A8-specific parameters:

  • weight_scale uses torch.float16 (appropriate for W4A8)
  • weight_scale_2 and input_scale for FP8 scaling
  • Parameter types and device placement are consistent

49-66: Linear layer configuration and weight loading is proper.

The test correctly:

  • Uses QuantAlgo.W4A8_AWQ quantization algorithm
  • Loads all required W4A8 parameters including the new scaling factors
  • Clones weight tensor to avoid modification issues

70-75: Weight preprocessing uses correct activation type.

The preprocessing correctly uses torch.float8_e4m3fn as the activation type for W4A8, and the weight comparison validates the preprocessing worked correctly.


84-98: Reference computation correctly implements W4A8 flow.

The reference implementation properly:

  • Applies pre-quantization scaling
  • Uses static quantization to FP8 E4M3
  • Computes combined alpha scaling factor
  • Calls the new GEMM operator with correct parameters
  • Adjusts weight scales by dividing by weight_scale_2
cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.h (1)

27-44: LGTM! Well-structured API extension.

The class renaming and API extensions are appropriate:

  • The new name finegrainedMixedDtypeGemmRunner better reflects the mixed-dtype capabilities
  • Adding outputDtype parameter provides flexibility for different output precisions
  • The alpha parameter with default value maintains backward compatibility
tests/unittest/_torch/modules/test_fused_moe.py (4)

648-649: Good test coverage for both weight loading modes.

The parameterization properly tests both the ModelOpt checkpoint format (VANILLA) and the custom W4A8 format.


667-678: Appropriate handling of weight scale key differences.

The lookup table correctly maps weight_scale_inv for CUSTOM_W4A8 mode, reflecting the different weight scale representations between the two formats.


770-776: Correct unpacking logic for different packing strategies.

The unpacking function properly handles the different packing dimensions:

  • VANILLA: transposed unpacking for output dimension packing
  • CUSTOM_W4A8: direct unpacking and transpose for input dimension packing
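
For intuition, a hedged sketch of the two unpacking orders (helper names, nibble order, and packed layouts are assumptions based on this description, not the exact test code):

```python
import torch

def _unpack_int4_pairs(packed: torch.Tensor) -> torch.Tensor:
    """Split each int8 byte into two signed 4-bit values along the last dim."""
    low = packed & 0x0F
    low = torch.where(low >= 8, low - 16, low)   # sign-extend the low nibble
    high = packed >> 4                           # arithmetic shift keeps the sign
    return torch.stack((low, high), dim=-1).flatten(-2)

def unpack_weights(w_packed: torch.Tensor, mode: str) -> torch.Tensor:
    if mode == "VANILLA":                        # pairs packed along the output dim
        return _unpack_int4_pairs(w_packed.t())  # transpose, then unpack
    return _unpack_int4_pairs(w_packed).t()      # CUSTOM_W4A8: unpack, then transpose
```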

878-883: Comprehensive validation checks.

Good additions:

  • NaN checks ensure numerical stability
  • Non-zero output validation confirms the computation produces meaningful results
  • Debug prints help troubleshooting
cpp/tensorrt_llm/thop/finegrained_mixed_dtype_gemm_thop.cpp (4)

47-143: Well-implemented constructor with proper dtype validation.

The constructor correctly:

  • Validates activation/output dtype combinations
  • Supports Float8_e4m3fn activation with Half/BFloat16 outputs
  • Enforces matching dtypes for Half/BFloat16 activations
  • Provides clear error messages for unsupported configurations

216-228: Correct output dtype handling.

The output tensor creation properly uses mOutputDtype instead of the activation dtype, enabling mixed-precision GEMM operations.


260-261: Appropriate alpha parameter handling.

The alpha value is correctly converted from double to float and passed to the underlying GEMM runner.


274-280: Correct Torch library registration.

The registration properly exposes the renamed class with its updated constructor signature and methods.

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (2)

678-717: Well-implemented Python wrapper for the C++ runner.

The class correctly:

  • Uses appropriate instance key for caching runners
  • Wraps the renamed C++ class
  • Handles optional alpha parameter with proper default value

720-764: Clean operator implementation with proper validation.

The new operator:

  • Has a more descriptive name reflecting its mixed-dtype capabilities
  • Properly validates zeros tensor presence when required
  • Maintains AutoTuner integration for performance optimization
  • Correctly handles all parameters including the new alpha value
tensorrt_llm/_torch/modules/linear.py (5)

50-55: LGTM! Clean implementation of tensor parallel mode toggling.

The flip method correctly switches between ROW and COLUMN parallel modes, which is necessary for sharding activation scales along the dimension orthogonal to weight sharding.


119-126: LGTM! Correct handling of both W4A16 and W4A8 quantization methods.

The code properly selects the activation dtype based on the quantization method (FP16 for W4A16, FP8 for W4A8) and applies the necessary preprocessing for mixed-precision GEMM.


902-903: LGTM! Correct weight scale tensor layout and improved GEMM operator usage.

The changes properly transpose weight scale tensors to match the expected layout and use the new finegrained_mixed_dtype_gemm operator with explicit named parameters. The use of TensorParallelMode.flip for pre_quant_scale sharding is appropriate since it applies to activations.

Also applies to: 918-931, 962-982, 998-1022


1276-1279: LGTM! Correct extension of quantization method selection.

The function properly returns W4A8_AWQ_LinearMethod when the quantization mode is int4 weight-only per-group and the algorithm is W4A8_AWQ.


1403-1408: LGTM! Consistent property implementation.

The has_w4a8_awq property correctly identifies W4A8 AWQ quantization configuration, maintaining consistency with the existing has_w4a16_awq property.

tensorrt_llm/_torch/modules/fused_moe/quantization.py (4)

99-102: LGTM! Correct handling of CUSTOM_W4A8 weight loading mode.

The CUSTOM_W4A8 mode appropriately uses the same weight key structure as VANILLA mode with individual w1, w2, and w3 weights.


583-584: LGTM! Improved documentation of scaling factors.

The updated comments clearly explain the purpose of fc31_act_scale (reciprocal of per-channel pre_quant_scale * per-tensor input_scale) and alpha (weight_scale_2 * input_scale for rescaling GEMM output).

Also applies to: 614-615


679-710: LGTM! Clean implementation of architecture-specific weight loading.

The code correctly handles different weight formats:

  • SM89: Uses existing preprocessing logic
  • SM90 VANILLA: Unpacks int4 weights, transposes, and repacks for the expected layout
  • SM90 CUSTOM_W4A8: No-op, indicating weights are pre-formatted correctly

The extraction of packer/unpacker operations improves code clarity.

Also applies to: 725-755


759-929: LGTM! Comprehensive handling of different checkpoint formats.

The implementation correctly handles two checkpoint formats:

Custom W4A8 format:

  • Uses "weight_scale_inv" parameter name
  • Fuses per-tensor input_scale with per-channel pre_quant_scale
  • Fuses weight_scale_2 into alpha

Vanilla (ModelOpt) format:

  • Uses "weight_scale" parameter name
  • Keeps scales separate and combines them during loading
  • Divides weight scales by weight_scale_2_max before storage

The scale computations are mathematically correct for both formats.

@rosenrodt rosenrodt force-pushed the w4a8-moe-hopper-pyt branch from e080b5c to 636b31e on July 20, 2025 14:03
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (3)

583-583: Improve comment clarity for activation scale parameter.

The comment could be more precise about the mathematical operation being performed.

-        # Multiply act with reciprocal of per-channel pre_quant_scale * per-tensor input_scale
+        # Scale activation by reciprocal of (per-channel pre_quant_scale * per-tensor input_scale)

614-614: Improve comment clarity for weight scale parameter.

Similar to the previous comment, this could be more mathematically precise.

-        # Multiply W@X with per-tensor weight_scale_2 * per-tensor input_scale.
+        # Scale matrix multiplication result by (per-tensor weight_scale_2 * per-tensor input_scale)

786-786: Fix line length violations.

Two lines exceed the 120-character limit as flagged by static analysis.

-            # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
+            # In vanilla ckpt (ModelOpt format), per-tensor input_scale and 
+            # per-channel pre_quant_scale are separately stored
-            # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
+            # In vanilla ckpt (ModelOpt format), per-tensor input_scale and 
+            # per-channel pre_quant_scale are separately stored

Also applies to: 880-880

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e080b5c and 636b31e.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (2 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/interface.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py (9 hunks)
  • tests/unittest/_torch/modules/test_fused_moe.py (7 hunks)
✅ Files skipped from review due to trivial changes (1)
  • tensorrt_llm/_torch/modules/fused_moe/interface.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/quantization.py

786-786: Line too long (132 > 120)

(E501)


880-880: Line too long (132 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (18)
tests/unittest/_torch/modules/test_fused_moe.py (9)

24-24: LGTM: Import addition is appropriate.

The MoEWeightLoadingMode import aligns with the new parameterized test functionality.


648-649: LGTM: Test parameterization correctly implemented.

The test function signature update and parameterization with both VANILLA and CUSTOM_W4A8 modes provides comprehensive coverage for the new weight loading functionality.


667-677: LGTM: Dynamic weight attribute lookup table is well-designed.

The lookup table approach elegantly handles the different weight scale attribute names between the two modes, making the test code more maintainable.


681-692: LGTM: Weight shape logic correctly differentiates packing strategies.

The conditional weight shapes properly reflect the different packing strategies:

  • VANILLA: Packs 4-bit weights along output dimension
  • CUSTOM_W4A8: Packs 4-bit weights along input dimension

This matches the implementation differences described in the PR summary.


694-697: LGTM: Weight initialization uses appropriate data type and range.

Using torch.int8 with range [-128, 127] for the quantized weights is correct for 4-bit packed representations.


752-753: LGTM: Model configuration correctly includes weight loading mode.

The weight_loading_mode parameter is properly passed to the CutlassFusedMoE constructor, ensuring the test validates the actual runtime configuration.


770-778: LGTM: Weight unpacking logic correctly handles both modes.

The unpack_weights helper function properly handles the different tensor layouts:

  • VANILLA: Transpose then unpack
  • CUSTOM_W4A8: Unpack then transpose

This correctly inverts the packing strategies used during weight preparation.


810-839: LGTM: Quantization process helper function is well-structured.

The process_layer function encapsulates the complex quantization logic clearly, with proper handling of optional scale parameters that are only used in VANILLA mode.


874-883: LGTM: Enhanced output validation with comprehensive checks.

The additional validation checks are valuable:

  • NaN detection in both outputs
  • Non-empty result verification
  • Debug output for troubleshooting

This significantly improves test reliability and debugging capabilities.

tensorrt_llm/_torch/modules/fused_moe/quantization.py (9)

99-102: LGTM: Conditional check correctly includes new weight loading mode.

The addition of MoEWeightLoadingMode.CUSTOM_W4A8 to the existing conditional maintains backward compatibility while enabling the new functionality.


679-681: LGTM: Pack/unpack operations are properly defined.

The tensor pack/unpack operations are correctly assigned and will be used consistently throughout the weight loading process.


694-709: LGTM: SM90 weight preprocessing logic is well-structured.

The conditional logic properly handles the different weight formats:

  • VANILLA: ModelOpt format requiring transpose/unpack/transpose/pack sequence
  • CUSTOM_W4A8: Custom format requiring no preprocessing

The comments clearly explain the tensor shape transformations.


739-754: LGTM: Consistent weight preprocessing for w2 weights.

The w2 weight preprocessing follows the same pattern as w3_w1 weights, maintaining consistency in the implementation.


759-764: LGTM: Scale naming abstraction is clean.

The dynamic scale name selection based on weight loading mode creates a clean abstraction that reduces code duplication in the scale loading logic.


850-851: LGTM: Scale division for VANILLA mode is properly conditioned.

The weight scale division by weight_scale_2 is correctly applied only in VANILLA mode, maintaining the different scaling semantics between the two checkpoint formats.


871-902: LGTM: FC2 scale computation mirrors FC31 logic.

The fc2 scale computation correctly mirrors the fc31 (w3_w1) scale computation pattern, with appropriate handling of per-channel vs per-tensor scales for both modes.


918-919: LGTM: Consistent scale division pattern for w2 weights.

The w2 scale division follows the same pattern as w3_w1 scales, maintaining consistency in the VANILLA mode scale handling.


778-827: Verified VANILLA mode scale computation is correct

I’ve compared the code in quantization.py against the reference logic in tests/unittest/_torch/modules/test_fused_moe.py. In both implementations you:

  • Compute pre_quant_scale_max = max(a1, a3) and input_scale_max = max(p3, p1).
  • Compute weight_scale_2_max = max(q3, q1).
  • In production code:
    • fc31_act_scale = pre_quant_scale_max * (1 / input_scale_max)
    • fc31_alpha = weight_scale_2_max * input_scale_max
  • In the test’s process_layer, this corresponds exactly to multiplying activations by pre_quant_scale then dividing by input_scale, and later multiplying outputs by input_scale and weight_scale_2.

No discrepancies found—no changes required here.

@rosenrodt rosenrodt force-pushed the w4a8-moe-hopper-pyt branch from 636b31e to 42da122 on July 20, 2025 14:13
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
tests/unittest/_torch/modules/test_fused_moe.py (2)

810-839: Consider cleaning up debug prints.

The process_layer function correctly handles the quantization logic with conditional scaling. However, consider removing the commented debug print statements for production code.

-                # print(f"DEBUG: {layer_name}.max() {abs(act.max())}")
-                # print(f"DEBUG: {layer_name}_pre_quant.max() {abs(act.max())}")
-                # print(f"DEBUG: {layer_name}_quantized.max() {abs(act.max())}")
-                # print(f"DEBUG: {layer_name}_output_input_scale.max() {abs(output.max())}")
-                # print(f"DEBUG: {layer_name}_output_weight_scale_2.max() {abs(output.max())}")

874-883: Improve validation approach and consider output verbosity.

The enhanced validation with NaN checks and non-empty assertions is valuable, but consider these improvements:

  1. The debug prints will clutter test output - consider conditional printing based on test failure
  2. Consider keeping torch.testing.assert_close as the primary assertion method for better error messages
-    success = torch.allclose(output, ref_output, rtol=1e-2, atol=0.1)
-    print(f"ref_output: {ref_output}")
-    print(f"output: {output}")
-    # assert that result does not contain NaN
     ref_has_nan = torch.isnan(ref_output).any()
     out_has_nan = torch.isnan(output).any()
     assert not ref_has_nan, "ref_output contains NaN"
     assert not out_has_nan, "output contains NaN"
     assert torch.nonzero(output).numel() != 0 and torch.nonzero(ref_output).numel() != 0
-    assert success
+    
+    # Use torch.testing.assert_close for better error messages
+    try:
+        torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
+    except AssertionError:
+        print(f"ref_output: {ref_output}")
+        print(f"output: {output}")
+        raise
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 636b31e and 42da122.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (2 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/interface.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py (9 hunks)
  • tests/unittest/_torch/modules/test_fused_moe.py (7 hunks)
✅ Files skipped from review due to trivial changes (1)
  • tensorrt_llm/_torch/modules/fused_moe/interface.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/quantization.py

786-786: Line too long (132 > 120)

(E501)


880-880: Line too long (132 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (17)
tests/unittest/_torch/modules/test_fused_moe.py (7)

24-24: LGTM: Import addition is correct.

The MoEWeightLoadingMode import is necessary for the new parameterized test functionality.


648-649: LGTM: Parametrization correctly adds support for both weight loading modes.

The test now properly covers both VANILLA and CUSTOM_W4A8 modes as intended for the ModelOpt checkpoint format support.


667-677: LGTM: Lookup table provides clean abstraction for mode-specific keys.

The dynamic key selection based on weight_loading_mode is well-implemented and maintains code readability.


694-743: LGTM: Weight initialization correctly handles both modes.

The weight initialization logic properly uses the lookup table and creates appropriately shaped tensors for each mode. The code is clean and maintainable.


752-753: LGTM: Constructor parameter addition is correct.

The weight_loading_mode parameter is properly passed to the CutlassFusedMoE constructor.


681-692: Weight shape logic consistent with quantization implementation

All inspected shape definitions in tests/unittest/_torch/modules/test_fused_moe.py align with the packing logic in tensorrt_llm/_torch/modules/fused_moe/quantization.py for both VANILLA and CUSTOM_W4A8 modes. No discrepancies found—no further changes needed.


770-776: Unpacking logic verified

The unpack_weights helper in the test correctly mirrors the operator usage in the quantization module:

  • VANILLA: both the implementation (SM90 branch) and test perform a transpose before unpacking.
  • CUSTOM_W4A8: both the SM89 branch and test unpack first and then transpose.

No changes needed here.

tensorrt_llm/_torch/modules/fused_moe/quantization.py (10)

99-102: LGTM: Correctly extends weight loading mode support.

The addition of MoEWeightLoadingMode.CUSTOM_W4A8 to the conditional check properly extends support while maintaining existing functionality.


583-593: LGTM: New activation scale parameters are well-defined.

The fc31_act_scale and fc2_act_scale parameters are properly registered with appropriate shapes and dtypes. The comments clearly explain their purpose.


614-625: LGTM: Alpha parameters are correctly defined.

The fc31_alpha and fc2_alpha parameters are properly registered with appropriate dtypes and clear documentation of their purpose.


738-754: LGTM: Consistent weight loading logic for w2 weights.

The load_expert_w2_weight method correctly implements the same processing pattern as the w3_w1_weight method, ensuring consistency across weight types.


759-764: LGTM: Clean approach to handle different scale key names.

The conditional selection of weight_scale_name based on the loading mode provides a clean abstraction for the different key naming conventions.


871-902: LGTM: Consistent fc2 scale handling.

The fc2 scale loading follows the same pattern as fc31, maintaining consistency in the implementation approach across different layer types.


850-851: LGTM: Conditional weight scale normalization is correct.

The weight scale normalization is correctly applied only in VANILLA mode, consistent with the different scale handling approaches between modes.


918-919: LGTM: Consistent normalization logic for fc2 weights.

The fc2 weight scale normalization follows the same conditional pattern, maintaining consistency with the fc31 handling.


693-709: Verify SM90 weight processing sequence

The SM90 branches for VANILLA (transpose→unpack→transpose→pack) and CUSTOM_W4A8 (pass-through) appear consistent with the ModelOpt quantization strategy, but we currently have no automated validation for SM90. Please:

  • Confirm that, for both w2 and w31 shards in
    tensorrt_llm/_torch/modules/fused_moe/quantization.py (lines ~693–709), the sequence
    transpose → unpack → transpose → pack yields the correct final layout and dtype.
  • Add unit tests in tests/unittest/_torch/modules/test_fused_moe.py covering:
    • SM90 + MoEWeightLoadingMode.VANILLA produces the expected shapes/data.
    • SM90 + MoEWeightLoadingMode.CUSTOM_W4A8 retains the original shard unchanged.

778-827: Scale computation logic confirmed correct

After cross-referencing the fc31 logic with the patterns in linear.py, quantization_utils.py, and the unit tests in test_fused_moe.py, the two branches correctly implement:

  • CUSTOM_W4A8:
    • fc31_act_scale = 1 / input_scale_max
    • fc31_alpha = input_scale_max
  • VANILLA:
    • fc31_act_scale = pre_quant_scale_max / input_scale_max
    • fc31_alpha = weight_scale_2_max * input_scale_max

These match the intended quantization formulas (act_scale = pre_quant_scale·input_scale⁻¹, and alpha = input_scale·weight_scale_2) used elsewhere in the codebase and tests. No changes required.
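
For reference, a minimal sketch of the two branches summarized above (variable names and reduction axes are assumptions drawn from this review, not the exact quantization.py code):

```python
import torch

def fc31_scales(mode, input_scale, pre_quant_scale=None, weight_scale_2=None):
    """Inputs are stacked per-expert scale tensors; returns (act_scale, alpha)."""
    input_scale_max = input_scale.max()
    if mode == "CUSTOM_W4A8":
        # Scales were already fused into the checkpoint by the custom script.
        return 1.0 / input_scale_max, input_scale_max
    # VANILLA (ModelOpt): input_scale and pre_quant_scale are stored separately.
    pre_quant_scale_max = pre_quant_scale.max(dim=0).values   # per-channel max
    weight_scale_2_max = weight_scale_2.max()
    act_scale = pre_quant_scale_max / input_scale_max
    alpha = weight_scale_2_max * input_scale_max
    return act_scale, alpha
```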

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/unittest/_torch/modules/test_fused_moe.py (1)

874-883: Enhanced output validation with comprehensive checks.

The validation improvements are excellent:

  • Manual allclose check allows for additional custom validations
  • NaN checks help detect numerical issues in the quantization pipeline
  • Non-empty assertions ensure meaningful output data
  • Comprehensive error detection for debugging

Consider making the debug prints conditional to avoid cluttering test output:

-    print(f"ref_output: {ref_output}")
-    print(f"output: {output}")
+    if not success:
+        print(f"ref_output: {ref_output}")
+        print(f"output: {output}")
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 42da122 and 0872046.

📒 Files selected for processing (1)
  • tests/unittest/_torch/modules/test_fused_moe.py (7 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (8)
tests/unittest/_torch/modules/test_fused_moe.py (8)

24-24: LGTM: Import addition supports new test functionality.

The MoEWeightLoadingMode import is correctly added to support the new test parameterization for different weight loading modes.


648-649: LGTM: Test parameterization correctly covers both weight loading modes.

The addition of weight_loading_mode parameter ensures comprehensive testing of both VANILLA and CUSTOM_W4A8 modes, aligning with the PR's goal to support ModelOpt checkpoint format.


667-677: LGTM: Clean lookup table implementation for mode-specific weight keys.

The LUT approach elegantly handles the different naming conventions between VANILLA and CUSTOM_W4A8 modes, particularly for the weight scale keys where CUSTOM_W4A8 uses "weight_scale_inv" instead of "weight_scale".


681-692: LGTM: Correct implementation of mode-specific weight packing strategies.

The weight shape logic correctly handles the different packing approaches:

  • VANILLA mode packs 4-bit weight pairs along the output dimension
  • CUSTOM_W4A8 mode packs 4-bit weight pairs along the input dimension

This properly reflects the differences between ModelOpt and custom quantization script formats.


693-726: LGTM: Proper initialization of quantized weights and scaling tensors.

The weight initialization correctly:

  • Uses int8 tensors for quantized weights with valid range [-128, 127]
  • Introduces necessary scaling tensors (pre_quant_scale, weight_scale_2, input_scale)
  • Sets appropriate data types and devices for all tensors
  • Uses reasonable scale ranges for testing purposes

727-742: LGTM: Weight dictionary correctly populated using LUT.

The weight dictionary setup properly uses the lookup table to ensure correct attribute names for each weight loading mode, maintaining compatibility with both VANILLA and CUSTOM_W4A8 formats.


752-753: LGTM: Model configuration correctly passes weight loading mode.

The addition of weight_loading_mode parameter to the CutlassFusedMoE constructor ensures the model uses the appropriate weight loading and processing logic for the selected mode.


770-860: LGTM: Comprehensive reference implementation supporting both weight loading modes.

The reference implementation correctly handles both modes:

  • unpack_weights properly unpacks tensors based on the packing strategy
  • process_layer encapsulates the quantization pipeline with conditional scaling
  • Pre-quantization and weight_scale_2 are correctly applied only in VANILLA mode
  • The forward pass logic accurately reflects the expected behavior for both modes

@rosenrodt rosenrodt force-pushed the w4a8-moe-hopper-pyt branch from 0872046 to d8244f8 on July 20, 2025 14:52
@rosenrodt
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #12386 [ run ] triggered by Bot

@rosenrodt rosenrodt force-pushed the w4a8-moe-hopper-pyt branch from d8244f8 to e1cc336 on July 20, 2025 15:10
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/unittest/_torch/modules/test_fused_moe.py (1)

774-899: Comprehensive reference implementation with room for improvement.

The reference implementation correctly handles both weight loading modes with appropriate unpacking strategies and conditional scale applications. The enhanced validation logic provides good debugging capabilities.

Minor suggestions for improvement:

-    print(f"ref_output: {ref_output}")
-    print(f"output: {output}")
+    # Debug prints can be enabled when needed
+    # print(f"ref_output: {ref_output}")
+    # print(f"output: {output}")

Consider simplifying the assertion logic:

-    success = torch.allclose(output, ref_output, rtol=1e-2, atol=0.1)
-    # assert that result does not contain NaN
-    ref_has_nan = torch.isnan(ref_output).any()
-    out_has_nan = torch.isnan(output).any()
-    assert not ref_has_nan, "ref_output contains NaN"
-    assert not out_has_nan, "output contains NaN"
-    assert torch.nonzero(output).numel() != 0 and torch.nonzero(
-        ref_output).numel() != 0
-    assert success
+    # Comprehensive validation
+    assert not torch.isnan(ref_output).any(), "ref_output contains NaN"
+    assert not torch.isnan(output).any(), "output contains NaN"
+    assert torch.nonzero(output).numel() > 0, "output is empty"
+    assert torch.nonzero(ref_output).numel() > 0, "ref_output is empty"
+    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d8244f8 and e1cc336.

📒 Files selected for processing (1)
  • tests/unittest/_torch/modules/test_fused_moe.py (7 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (6)
tests/unittest/_torch/modules/test_fused_moe.py (6)

21-23: LGTM! Import addition is correct.

The import of MoEWeightLoadingMode is properly added and necessary for the new test parameter functionality.


645-648: Good test coverage enhancement.

The parameterization of weight_loading_mode with both VANILLA and CUSTOM_W4A8 modes ensures comprehensive testing coverage for the new weight loading functionality.


666-678: Clean approach to handle different weight attribute naming.

The lookup table pattern effectively abstracts the differences in weight attribute names between the two loading modes, making the code more maintainable.


682-693: Correct implementation of different weight packing strategies.

The weight shape handling correctly accounts for the different packing approaches:

  • VANILLA mode: Output dimension packing (ModelOpt W4A8)
  • CUSTOM_W4A8 mode: Input dimension packing (custom quantization script)

The clear comments explain the rationale behind each approach.
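
A small illustrative helper for the shape difference (the fused w3/w1 output dimension 2*n and hidden size k are generic placeholders; the test's exact shape arithmetic is not reproduced here):

```python
def packed_w3_w1_shape(mode: str, n: int, k: int) -> tuple[int, int]:
    """Shape of one expert's packed int4 w3/w1 weight: two 4-bit values per byte."""
    if mode == "VANILLA":        # ModelOpt packs 4-bit pairs along the output dim
        return (2 * n // 2, k)
    return (2 * n, k // 2)       # CUSTOM_W4A8 packs pairs along the input dim
```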


694-746: Proper weight initialization with lookup table usage.

The weight initialization correctly:

  • Uses the lookup table for consistent attribute naming across modes
  • Includes all necessary tensors (pre_quant_scale, weight_scale_2) for comprehensive testing
  • Maintains proper tensor shapes and device placement

756-757: Correct parameter propagation to model.

The weight_loading_mode parameter is properly passed to the CutlassFusedMoE constructor, ensuring the test configuration matches the intended behavior.

@tensorrt-cicd
Collaborator

PR_Github #12386 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #9205 completed with status: 'FAILURE'

@rosenrodt
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #12395 [ run ] triggered by Bot

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
tests/unittest/_torch/modules/test_fused_moe.py (1)

891-899: Consider removing debug prints for production.

The defensive checks for NaN values and empty tensors are excellent additions that prevent silent failures. However, consider removing the debug print statements (lines 891-892) before merging to avoid cluttering the test output.

-    print(f"ref_output: {ref_output}")
-    print(f"output: {output}")
tensorrt_llm/_torch/modules/fused_moe/quantization.py (2)

778-828: Address line length and verify scale computation logic

The fc31 activation scale computation logic correctly handles the differences between CUSTOM_W4A8 and VANILLA modes:

  • CUSTOM_W4A8: Fused scales, simpler computation
  • VANILLA: Separate scales requiring more complex handling

However, there are formatting issues to address.

Apply this diff to fix line length issues:

-            # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
+            # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and 
+            # per-channel pre_quant_scale are separately stored

871-903: Address line length and verify fc2 scale logic

The fc2 activation scale computation follows the same pattern as fc31, correctly handling the mode differences. The logic appears sound but has formatting issues.

Apply this diff to fix line length issues:

-            # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
+            # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and 
+            # per-channel pre_quant_scale are separately stored
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ebbbb6f and e02c34c.

📒 Files selected for processing (5)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (2 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/interface.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py (9 hunks)
  • tensorrt_llm/_torch/modules/linear.py (4 hunks)
  • tests/unittest/_torch/modules/test_fused_moe.py (7 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tensorrt_llm/_torch/modules/fused_moe/interface.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/fused_moe/quantization.py

786-786: Line too long (132 > 120)

(E501)


880-880: Line too long (132 > 120)

(E501)

🔇 Additional comments (18)
tensorrt_llm/_torch/modules/linear.py (4)

964-966: LGTM! Improved device handling for multi-GPU setups.

The dynamic device assignment using module.weight.device instead of hardcoded CUDA device is a good improvement. The comments clearly explain the rationale for this change.
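
In essence (a hypothetical helper, not the actual linear.py code), the pattern being praised is:

```python
import torch

def move_to_module_device(module: torch.nn.Module, loaded: torch.Tensor) -> torch.Tensor:
    # Place the loaded scale/weight on whatever device the module's weight
    # already lives on, instead of assuming a hardcoded "cuda" device.
    return loaded.to(module.weight.device)
```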


1146-1148: LGTM! Consistent device handling improvement.

The dynamic device assignment is consistent with the W4A16 implementation and properly handles multi-GPU scenarios.


1211-1213: LGTM! Consistent dynamic device usage.

The device assignment follows the same pattern as other methods, and line 1220 correctly uses the dynamic device variable instead of hardcoded CUDA device.

Also applies to: 1220-1220


1255-1257: LGTM! Completes the consistent device handling pattern.

The final set of changes maintains consistency with the other methods in the class, properly using dynamic device detection throughout the W4A8_AWQ_LinearMethod class.

Also applies to: 1264-1264

tests/unittest/_torch/modules/test_fused_moe.py (6)

21-26: LGTM! Proper import addition.

The addition of MoEWeightLoadingMode import is necessary for the new test functionality, and the isort comments indicate proper import organization.


648-651: LGTM! Comprehensive test coverage for both weight loading modes.

The parametrization properly tests both VANILLA and CUSTOM_W4A8 modes, ensuring comprehensive coverage of the new functionality.


669-681: LGTM! Clean lookup table implementation.

The lookup table provides a maintainable way to handle different weight attribute names between the two modes, making the test code more organized and readable.


684-696: LGTM! Correct weight shape handling for different modes.

The conditional logic properly handles the different weight packing strategies:

  • VANILLA mode packs pairs of 4-bit weights in the output dimension
  • CUSTOM_W4A8 mode packs in the input dimension

The comments clearly explain the rationale for each approach.


697-749: LGTM! Proper weight initialization for W4A8 quantization.

The weight initialization correctly:

  • Uses torch.randint for quantized int8 weights
  • Adds all necessary scale tensors (pre_quant_scale, input_scale, weight_scale_2)
  • Uses the lookup table for consistent key naming
  • Provides proper tensor shapes and data types for W4A8 quantization

759-760: LGTM! Comprehensive reference implementation for both modes.

The reference implementation correctly:

  • Passes weight_loading_mode to the model
  • Handles different weight unpacking strategies via unpack_weights
  • Conditionally applies pre_quant_scale and weight_scale_2 only for VANILLA mode
  • Encapsulates complex quantization logic in the process_layer helper function

The implementation properly simulates the quantization process for both modes, providing a solid reference for comparison.

Also applies to: 777-876

tensorrt_llm/_torch/modules/fused_moe/quantization.py (8)

99-102: LGTM: Properly extends weight loading mode support

The addition of CUSTOM_W4A8 to the conditional check correctly extends support for the new weight loading mode while maintaining existing functionality.


583-594: LGTM: Appropriate parameter registration for W4A8 quantization

The new activation scale parameters are correctly registered with proper shapes and data types for the W4A8 quantization method. The comments clearly explain their purpose.


614-626: LGTM: Alpha parameters properly configured

The alpha parameters are correctly registered with appropriate shapes and dtype (float32) for storing scaling factors in the W4A8 quantization pipeline.


725-755: Consistent weight preprocessing for w2 weights

The w2 weight preprocessing logic mirrors the w3_w1 implementation with appropriate handling for different SM versions and weight loading modes. The structure is consistent and correct.


759-764: LGTM: Clean conditional weight scale naming

The conditional assignment of weight_scale_name based on the weight loading mode is clean and follows the expected naming convention differences between VANILLA and CUSTOM_W4A8 modes.


918-920: Consistent normalization for fc2 weights

The fc2 weight scale normalization follows the same pattern as fc31, only applied in VANILLA mode. This consistency is good and the logic appears correct.


679-710: Verify CUSTOM_W4A8 weight preprocessing

I couldn’t find any documentation or code that prepares weights for the CUSTOM_W4A8 mode. Please confirm that:

  • Weights loaded under MoEWeightLoadingMode.CUSTOM_W4A8 are already in the expected packed‐int4 format before this method.
  • If they aren’t, add the equivalent pack/unpack preprocessing used for the other modes.

850-852: Division for VANILLA mode is mathematically sound

  • all_w3_w1_weight_scale_2_max is computed (lines 821–826) as the element-wise maximum across all per-expert weight_scale_2 shards.
  • In VANILLA mode, dividing each per-group weight scale by this tensor mirrors the pattern in the core Linear modules (weight_scale / weight_scale_2) and the fp4_global_scale utility.
  • Unit tests in tests/unittest/_torch/modules/test_fused_moe.py (lines 821–823) explicitly build q3_q1 = torch.max(q3, q1) and verify that weight /= weight_scale_2, confirming correct behavior.

No changes required.

@tensorrt-cicd
Collaborator

PR_Github #12395 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9214 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

Collaborator

@Barry-Delaney Barry-Delaney left a comment


LGTM. Left some minor comments. Thanks for the contribution!

@rosenrodt rosenrodt force-pushed the w4a8-moe-hopper-pyt branch from e02c34c to 44fed18 on August 7, 2025 05:41
@rosenrodt rosenrodt requested review from a team as code owners August 7, 2025 05:41
@rosenrodt rosenrodt requested a review from hlu1 August 7, 2025 05:41
@tensorrt-cicd
Collaborator

PR_Github #14762 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #14752 [ run ] completed with state ABORTED

@tensorrt-cicd
Collaborator

PR_Github #14762 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11144 completed with status: 'FAILURE'

@rosenrodt
Collaborator Author

rosenrodt commented Aug 11, 2025

Tests passed except for one unrelated (probably flaky) test: Test / A10-PyTorch-1 / A10-PyTorch-1.disaggregated.test_disaggregated.test_disaggregated_diff_max_tokens[TinyLlama-1.1B-Chat-v1.0]

[2025-08-11T08:55:15.450Z] RuntimeError: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaMalloc(ptr, n): out of memory (../tensorrt_llm/runtime/tllmBuffers.h:93)
[2025-08-11T08:55:15.450Z] 1       0x7f22617fb4ab void tensorrt_llm::common::check<cudaError>(cudaError, char const*, char const*, int) + 139
[2025-08-11T08:55:15.450Z] 2       0x7f22619ce8f9 tensorrt_llm::runtime::BufferManager::gpuSync(nvinfer1::Dims64, nvinfer1::DataType) + 473
...
[2025-08-11T08:55:15.451Z] RuntimeError: Executor creation failed due to insufficient GPU memory.

@rosenrodt
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #14802 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #14802 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11175 completed with status: 'SUCCESS'

@rosenrodt
Collaborator Author

Pinging @NVIDIA/trt-llm-torch-modules and affected file code owner @hlu1 for review

Signed-off-by: Anthony Chang <[email protected]>

fix activation pre_quant_scale sharding in w4a8/w4a16 linear

Signed-off-by: Anthony Chang <[email protected]>

address feedback; tidy

Signed-off-by: Anthony Chang <[email protected]>
Signed-off-by: Anthony Chang <[email protected]>
Signed-off-by: Anthony Chang <[email protected]>
Signed-off-by: Anthony Chang <[email protected]>
Signed-off-by: Anthony Chang <[email protected]>
@rosenrodt rosenrodt force-pushed the w4a8-moe-hopper-pyt branch from 3724212 to 3b2e9bc on August 12, 2025 13:36
@rosenrodt
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #14977 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #14977 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11308 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@chzblych chzblych enabled auto-merge (squash) August 13, 2025 02:37
@chzblych chzblych merged commit 2198587 into NVIDIA:main Aug 13, 2025
4 checks passed
@Nekofish-L
Contributor

After applying these changes, DeepSeek-R1-W4AFP8 fails to initialize at runtime.

@rosenrodt
Collaborator Author

Hi @Nekofish-L, thanks for reporting it. Could you try the fix in #7123 and see if it works while we get that PR merged?

