diff --git a/src/contrastors/config.py b/src/contrastors/config.py
index bdebbe0..e6ff6de 100644
--- a/src/contrastors/config.py
+++ b/src/contrastors/config.py
@@ -13,6 +13,7 @@ class TrainArgs(BaseModel):
     warmup_steps: Optional[int] = None
     warmup_pct: Optional[float] = None
     cooldown_steps: Optional[int] = None
+    cooldown_pct: Optional[float] = None
     checkpoint: Optional[str] = None
     wandb: bool
     wandb_project_name: str
diff --git a/src/contrastors/models/encoder/modeling_nomic_bert.py b/src/contrastors/models/encoder/modeling_nomic_bert.py
index 434c13f..0336f72 100644
--- a/src/contrastors/models/encoder/modeling_nomic_bert.py
+++ b/src/contrastors/models/encoder/modeling_nomic_bert.py
@@ -190,7 +190,11 @@ def forward(
 
         residual = None
         batch, seqlen = hidden_states.shape[:2]
-        hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(hidden_states, attention_mask)
+        try:
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(hidden_states, attention_mask)
+        except ValueError:
+            # newer flash-attention versions also return the number of tokens used in the batch, which we don't need, so drop it
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch, _ = unpad_input(hidden_states, attention_mask)
 
         for i, layer in enumerate(self.layers):
             if self.gradient_checkpointing and self.training:
diff --git a/src/contrastors/trainers/base.py b/src/contrastors/trainers/base.py
index d853df3..6ab77e8 100644
--- a/src/contrastors/trainers/base.py
+++ b/src/contrastors/trainers/base.py
@@ -211,20 +211,27 @@ def get_optimizer(self, config, ds_config=None):
         return optimizer
 
     def get_scheduler(self, config, optimizer, ds_config):
+        total_num_steps = self.total_num_steps * config.num_epochs
+
         if hasattr(config, "warmup_steps") and getattr(config, "warmup_steps") is not None:
-            total_num_steps = self.total_num_steps * config.num_epochs
             warmup_steps = config.warmup_steps
-
         elif hasattr(config, "warmup_pct") and getattr(config, "warmup_pct") is not None:
-            total_num_steps = self.total_num_steps * config.num_epochs
             warmup_steps = int(total_num_steps * config.warmup_pct)
-
         else:
             warmup_steps = 0
+
+        if hasattr(config, "cooldown_steps") and getattr(config, "cooldown_steps") is not None:
+            cooldown_steps = config.cooldown_steps
+        elif hasattr(config, "cooldown_pct") and getattr(config, "cooldown_pct") is not None:
+            cooldown_steps = int(total_num_steps * config.cooldown_pct)
+        else:
+            cooldown_steps = 0
+
 
         self.print("*" * 50 + " SCHEDULER " + "*" * 50)
         self.print(f"Using {config.schedule_type} learning rate schedule")
         self.print(f"Warmup steps: {warmup_steps}")
+        self.print(f"Cooldown steps: {cooldown_steps}")
         self.print(f"Total num steps: {total_num_steps}")
 
         if ds_config:
@@ -240,11 +247,20 @@ def get_scheduler(self, config, optimizer, ds_config):
             scheduler["params"]["total_num_steps"] = total_num_steps
             return None
 
+        if config.schedule_type == "warmup_stable_decay":
+            scheduler_specific_kwargs = {
+                "num_stable_steps": total_num_steps - warmup_steps - cooldown_steps,
+                "num_decay_steps": cooldown_steps
+            }
+        else:
+            scheduler_specific_kwargs = {}
+
         scheduler = get_scheduler(
             name=config.schedule_type,
             optimizer=optimizer,
             num_warmup_steps=warmup_steps,
             num_training_steps=(total_num_steps if config.schedule_type != "inverse_sqrt" else None),
+            scheduler_specific_kwargs=scheduler_specific_kwargs
         )
 
         return scheduler
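
Note on usage: below is a hypothetical train_args excerpt that exercises the new cooldown_pct option together with the warmup_stable_decay schedule. The field names mirror TrainArgs and the trainer code in this patch, but the surrounding config layout and the concrete numbers are made up for illustration, not taken from a real config.

# Hypothetical train_args excerpt; illustrative values only.
train_args = {
    "num_epochs": 1,
    "schedule_type": "warmup_stable_decay",
    "warmup_pct": 0.05,    # warm up over the first 5% of total steps
    "cooldown_pct": 0.20,  # decay over the final 20% of total steps
}

# With total_num_steps = 10_000, get_scheduler above would derive:
#   warmup_steps     = int(10_000 * 0.05)   = 500
#   cooldown_steps   = int(10_000 * 0.20)   = 2_000
#   num_stable_steps = 10_000 - 500 - 2_000 = 7_500
#   num_decay_steps  = cooldown_steps       = 2_000

Explicit warmup_steps / cooldown_steps still take precedence over the percentage variants, per the order of the hasattr checks.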
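
On the unpad_input change: a purely illustrative alternative is to absorb the extra return value from newer flash-attention releases by slicing the returned tuple instead of catching the unpacking error. unpad_input_compat is a hypothetical helper sketched here, not part of this patch or of flash-attention.

# Illustrative sketch: take only the first four outputs of unpad_input, so both
# the older 4-tuple and the newer 5-tuple (which appends the used-token count)
# are handled without a try/except.
from flash_attn.bert_padding import unpad_input

def unpad_input_compat(hidden_states, attention_mask):
    outputs = unpad_input(hidden_states, attention_mask)
    hidden_states, indices, cu_seqlens, max_seqlen_in_batch = outputs[:4]
    return hidden_states, indices, cu_seqlens, max_seqlen_in_batch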