@@ -132,6 +132,7 @@ def grouped_topk(hidden_states: torch.Tensor,
                  renormalize: bool,
                  num_expert_group: int = 0,
                  topk_group: int = 0,
+                 routed_scaling_factor: float = 1.0,
                  scoring_func: str = "sigmoid",
                  e_score_correction_bias: Optional[torch.Tensor] = None):
 
@@ -163,8 +164,8 @@ def grouped_topk(hidden_states: torch.Tensor,
     score_mask = group_mask.unsqueeze(-1).expand(
         num_token, num_expert_group,
         scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(),
-                                    float("-inf"))  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
+    # float("-inf"))  # [n, e]
 
     if e_score_correction_bias is not None:
         topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
@@ -176,9 +177,10 @@ def grouped_topk(hidden_states: torch.Tensor,
                                           dim=-1,
                                           sorted=False)
 
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
+    if topk > 1 and renormalize:
+        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+        topk_weights = topk_weights / denominator
+    topk_weights = topk_weights * routed_scaling_factor  # must multiply the scaling factor
     return topk_ids.to(torch.long), topk_weights.to(torch.float32)
 
 class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
@@ -204,6 +206,7 @@ def __init__(
         self.is_windows = os.name == 'nt'
         self.use_quant = use_quant
         if not self.is_windows and use_quant:
+            print("injecting gate_linear")
             self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
             self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
                                                    gguf_loader, config, self.gate_linear,  # orig_module
@@ -219,14 +222,13 @@ def forward(self, hidden_states) -> torch.Tensor:
         ### compute gating score
         hidden_states = hidden_states.view(-1, h)
         if self.use_quant:
-            logits = self.gate_linear.forward(logits)
+            logits = self.gate_linear.forward(hidden_states)
         else:
             logits = F.linear(
                 hidden_states.type(torch.float32), self.weight.type(torch.float32), None
             )
-
-        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob,
-                            self.n_group, self.topk_group, "sigmoid", self.e_score_correction_bias)
+        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob, self.n_group,
+                            self.topk_group, self.routed_scaling_factor, "sigmoid", self.e_score_correction_bias)
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
         if device is None: device = self.device
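For context, a minimal standalone sketch of the new weight post-processing at the tail of grouped_topk. The tensors and the scaling value are dummy/illustrative only; the real function also applies group masking and e_score_correction_bias before this step.

# Sketch (not the library's code): renormalize the selected expert weights only
# when topk > 1, then scale by routed_scaling_factor as the diff above does.
import torch

topk = 8
routed_scaling_factor = 2.5            # illustrative value; read from the model config in practice
topk_weights = torch.rand(4, topk)     # [num_tokens, topk] dummy gate scores for selected experts

if topk > 1:                           # renormalize only when there is something to normalize
    denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
    topk_weights = topk_weights / denominator
topk_weights = topk_weights * routed_scaling_factor

print(topk_weights.sum(dim=-1))        # per-token weights now sum to ~routed_scaling_factor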