Commit 8ac2332

Update unit test of fused moe w4afp8
Signed-off-by: Min Yu <[email protected]>
1 parent ac73e31 commit 8ac2332

File tree

1 file changed: 8 additions, 9 deletions

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 8 additions & 9 deletions
@@ -1484,15 +1484,14 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode):
                            dtype=torch.int8).cuda()
 
     # The pre-quant scale to be multiplied with the input activation.
-    w1_pre_quant_scale = torch.ones(HIDDEN_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
-    w2_pre_quant_scale = torch.ones(INTERMEDIATE_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
-    w3_pre_quant_scale = torch.ones(HIDDEN_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
+    # Use random pre-quant scales in [0.95, 1.05] instead of a fixed 1.0 to ensure
+    # the kernel handles non-uniform pre-quant scaling factors correctly.
+    w1_pre_quant_scale = torch.rand(
+        HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
+    w2_pre_quant_scale = torch.rand(
+        INTERMEDIATE_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
+    w3_pre_quant_scale = torch.rand(
+        HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
 
     # The weight scale to dequantize int4 weights (by multiplication).
     w1_scale = torch.randn(
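As a sanity check on the range arithmetic in the new code: torch.rand draws uniformly from [0, 1), so multiplying by 0.1 and adding 0.95 places the scales in the half-open interval [0.95, 1.05). A minimal standalone sketch of this (the size here is an illustrative placeholder, and dtype/device are omitted so it runs on CPU; the test uses its own constants):

import torch

HIDDEN_SIZE = 128  # illustrative placeholder; the test defines its own constant

# torch.rand samples uniformly from [0, 1), so this affine map yields
# values uniformly distributed in [0.95, 1.05).
pre_quant_scale = torch.rand(HIDDEN_SIZE) * 0.1 + 0.95

assert pre_quant_scale.min().item() >= 0.95
assert pre_quant_scale.max().item() < 1.05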
