Skip to content

Commit

Permalink
dipu rms norm init: add_unit_offset param
Browse files Browse the repository at this point in the history
  • Loading branch information
jingguo-st committed Aug 30, 2024
1 parent 5f75320 commit 031ea35
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions deeplink_ext/interntrain_ops/_mixed_rms_norm_dipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def backward(ctx, grad_output):

class MixedFusedRMSNorm(torch.nn.Module):

def __init__(self, normalized_shape, eps=1e-5):
def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False):
# TODO: Further optimization when there are device and dtype available.
# factory_kwargs = {"device": device, "dtype": dtype}
factory_kwargs = {}
Expand All @@ -105,6 +105,8 @@ def __init__(self, normalized_shape, eps=1e-5):
self.normalized_shape = torch.Size(normalized_shape)
self.weight = torch.nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
self.eps = eps
self.add_unit_offset = add_unit_offset
self.reset_parameters()

def forward(self, hidden_states):
return _MixedFusedRMSNormFunction.apply(
Expand All @@ -115,7 +117,10 @@ def forward(self, hidden_states):
)

def reset_parameters(self):
    """Re-initialize ``self.weight`` in place.

    The weight is filled with zeros when ``self.add_unit_offset`` is truthy
    and with ones otherwise.
    """
    # NOTE(review): presumably the fused kernel adds 1 to the weight when
    # add_unit_offset is set, so a zero weight is the identity scale --
    # that code is not visible here; confirm against the forward pass.
    if self.add_unit_offset:
        init.zeros_(self.weight)
    else:
        init.ones_(self.weight)

def extra_repr(self):
    """Return the extra fields shown in this module's printed representation.

    Returns:
        A string of the form
        ``"<normalized_shape>, eps=<eps>, add_unit_offset=<flag>"``.
    """
    # Fall back to False so instances that never had the attribute set
    # (e.g. created before the flag existed) can still be printed.
    unit_offset = getattr(self, "add_unit_offset", False)
    # No trailing ", ": the previous format string left a dangling separator.
    return (
        f"{self.normalized_shape}, eps={self.eps}, "
        f"add_unit_offset={unit_offset}"
    )

0 comments on commit 031ea35

Please sign in to comment.