
Commit 616d5f2 (parent: 894efc1)

Update unit test of fused moe w4afp8

Signed-off-by: Min Yu <[email protected]>

1 file changed: +6 -9 lines

tests/unittest/_torch/modules/test_fused_moe.py (6 additions, 9 deletions)
@@ -1328,15 +1328,12 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode):
                            dtype=torch.int8).cuda()
 
     # The pre-quant scale to be multiplied with the input activation.
-    w1_pre_quant_scale = torch.ones(HIDDEN_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
-    w2_pre_quant_scale = torch.ones(INTERMEDIATE_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
-    w3_pre_quant_scale = torch.ones(HIDDEN_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
+    w1_pre_quant_scale = torch.rand(
+        HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
+    w2_pre_quant_scale = torch.rand(
+        INTERMEDIATE_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
+    w3_pre_quant_scale = torch.rand(
+        HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
 
     # The weight scale to dequantize int4 weights (by multiplication).
     w1_scale = torch.randn(
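
Why the change matters: with torch.ones the pre-quant multiplication was an exact no-op, so the activation pre-scaling path of the w4afp8 fused MoE was never really exercised by the test. torch.rand(...) * 0.1 + 0.95 draws per-channel scales uniformly from [0.95, 1.05), close enough to 1 to keep the reference comparison tight while making the scaling observable. A minimal sketch of the new scale construction (plain PyTorch on CPU with a stand-in HIDDEN_SIZE; the actual test uses its own constants, the parametrized dtype, and CUDA tensors):

    import torch

    HIDDEN_SIZE = 64  # hypothetical size for illustration only

    # torch.rand samples uniformly from [0, 1), so the affine transform
    # below lands in [0.95, 1.05): near 1, but no longer an exact identity.
    pre_quant_scale = torch.rand(HIDDEN_SIZE) * 0.1 + 0.95
    assert pre_quant_scale.min().item() >= 0.95
    assert pre_quant_scale.max().item() < 1.05

    # The pre-quant scale multiplies the input activation per hidden
    # channel before quantization; with all-ones scales this step
    # changed nothing.
    x = torch.randn(2, HIDDEN_SIZE)
    x_scaled = x * pre_quant_scale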
