diff --git a/rankers/train/loss/__init__.py b/rankers/train/loss/__init__.py index 2837b32..0ef6f3e 100644 --- a/rankers/train/loss/__init__.py +++ b/rankers/train/loss/__init__.py @@ -160,6 +160,14 @@ def step_d(self): else: self.t += 1 self.d_weight = self.d_weight * (self.t / self.T) ** 2 + + def step(self): + if self.t >= self.T: + pass + else: + self.t += 1 + self.q_weight = self.q_weight * (self.t / self.T) ** 2 + self.d_weight = self.d_weight * (self.t / self.T) ** 2 @staticmethod def reg(reg, weight=0):