diff --git a/deeplink_ext/interntrain_ops/_mixed_rms_norm_dipu.py b/deeplink_ext/interntrain_ops/_mixed_rms_norm_dipu.py index a99ee29..3782e85 100644 --- a/deeplink_ext/interntrain_ops/_mixed_rms_norm_dipu.py +++ b/deeplink_ext/interntrain_ops/_mixed_rms_norm_dipu.py @@ -95,7 +95,7 @@ def backward(ctx, grad_output): class MixedFusedRMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5): + def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False): # TODO: Further optimization when there are device and dtype available. # factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {} @@ -105,6 +105,8 @@ def __init__(self, normalized_shape, eps=1e-5): self.normalized_shape = torch.Size(normalized_shape) self.weight = torch.nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) self.eps = eps + self.add_unit_offset = add_unit_offset + self.reset_parameters() def forward(self, hidden_states): return _MixedFusedRMSNormFunction.apply( @@ -115,7 +117,10 @@ def forward(self, hidden_states): ) def reset_parameters(self): - init.ones_(self.weight) + if self.add_unit_offset: + init.zeros_(self.weight) + else: + init.ones_(self.weight) def extra_repr(self): return "{normalized_shape}, eps={eps}, ".format(**self.__dict__)