tests/unittest/_torch/modules: 1 file changed, +8 −9 lines changed

@@ -1484,15 +1484,14 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode):
                                dtype=torch.int8).cuda()
 
     # The pre-quant scale to be multiplied with the input activation.
-    w1_pre_quant_scale = torch.ones(HIDDEN_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
-    w2_pre_quant_scale = torch.ones(INTERMEDIATE_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
-    w3_pre_quant_scale = torch.ones(HIDDEN_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
+    # Use random pre-quant scales in [0.95, 1.05) instead of a fixed 1.0 to
+    # ensure the kernel handles non-uniform pre-quant scaling factors correctly.
+    w1_pre_quant_scale = torch.rand(
+        HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
+    w2_pre_quant_scale = torch.rand(
+        INTERMEDIATE_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
+    w3_pre_quant_scale = torch.rand(
+        HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
 
     # The weight scale to dequantize int4 weights (by multiplication).
     w1_scale = torch.randn(
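
As a side note on the new initialization: torch.rand samples uniformly from [0, 1), so the expression lands in [0.95, 1.05). Below is a minimal CPU-only sketch of that range and of how a pre-quant scale is applied to an activation; the size here is a hypothetical stand-in for the test's HIDDEN_SIZE constant, and the element-wise multiply is only an illustration of the scaling pattern, not the fused kernel itself.

import torch

HIDDEN_SIZE = 64  # hypothetical stand-in; the test defines its own constant

# torch.rand samples uniformly from [0, 1), so * 0.1 + 0.95 yields [0.95, 1.05).
pre_quant_scale = torch.rand(HIDDEN_SIZE) * 0.1 + 0.95
assert pre_quant_scale.min().item() >= 0.95
assert pre_quant_scale.max().item() < 1.05

# The pre-quant scale multiplies the input activation element-wise before
# quantization; a non-constant scale exercises the per-channel scaling path
# that an all-ones scale would leave untested.
x = torch.randn(8, HIDDEN_SIZE)
x_scaled = x * pre_quant_scale  # broadcasts over the batch dimension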