File tree: 1 file changed (+6 −9 lines)
Changed file: tests/unittest/_torch/modules

Diff hunk (original vs. new line numbers start at 1328):
@@ -1328,15 +1328,12 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode):
                             dtype=torch.int8).cuda()

     # The pre-quant scale to be multiplied with the input activation.
-    w1_pre_quant_scale = torch.ones(HIDDEN_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
-    w2_pre_quant_scale = torch.ones(INTERMEDIATE_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
-    w3_pre_quant_scale = torch.ones(HIDDEN_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
+    w1_pre_quant_scale = torch.rand(
+        HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
+    w2_pre_quant_scale = torch.rand(
+        INTERMEDIATE_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
+    w3_pre_quant_scale = torch.rand(
+        HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95

     # The weight scale to dequantize int4 weights (by multiplication).
     w1_scale = torch.randn(
You can’t perform that action at this time.
0 commit comments