diff --git a/src/state_sets/configs/training/default.yaml b/src/state_sets/configs/training/default.yaml
index 58d014c8..ac55dbaa 100644
--- a/src/state_sets/configs/training/default.yaml
+++ b/src/state_sets/configs/training/default.yaml
@@ -2,6 +2,15 @@ wandb_track: false
 weight_decay: 0.0005
 batch_size: 64
 lr: 1e-4
+gene_decoder_lr: 1e-5
+scheduler_type: null
+scheduler_step_size: 50
+scheduler_gamma: 0.1
+scheduler_T_max: 100
+scheduler_patience: 10
+scheduler_factor: 0.5
+scheduler_monitor: val_loss
+warmup_epochs: 0
 max_steps: 250000
 train_seed: 42
 val_freq: 5000
diff --git a/src/state_sets/sets/models/base.py b/src/state_sets/sets/models/base.py
index 8c475716..21531785 100644
--- a/src/state_sets/sets/models/base.py
+++ b/src/state_sets/sets/models/base.py
@@ -126,6 +126,7 @@ class PerturbationModel(ABC, LightningModule):
         pert_dim: Dimension of perturbation embeddings
         dropout: Dropout rate
         lr: Learning rate for optimizer
+        gene_decoder_lr: Learning rate for gene decoder
         loss_fn: Loss function ('mse' or custom nn.Module)
         output_space: 'gene' or 'latent'
     """
@@ -138,7 +139,8 @@ def __init__(
         pert_dim: int,
         batch_dim: int = None,
         dropout: float = 0.1,
-        lr: float = 3e-4,
+        lr: float = 1e-4,
+        gene_decoder_lr: float = 1e-5,
         loss_fn: nn.Module = nn.MSELoss(),
         control_pert: str = "non-targeting",
         embed_key: Optional[str] = None,
@@ -147,6 +149,15 @@
         batch_size: int = 64,
         gene_dim: int = 5000,
         hvg_dim: int = 2001,
+        scheduler_type: Optional[str] = None,
+        scheduler_step_size: int = 50,
+        scheduler_gamma: float = 0.1,
+        scheduler_T_max: int = 100,
+        scheduler_patience: int = 10,
+        scheduler_factor: float = 0.5,
+        scheduler_monitor: str = "val_loss",
+        warmup_epochs: int = 0,
+        warmup_start_factor: float = 0.1,
         **kwargs,
     ):
         super().__init__()
@@ -175,8 +186,22 @@
         self.gene_names = gene_names  # store the gene names that this model output for gene expression space
         self.dropout = dropout
         self.lr = lr
+        self.gene_decoder_lr = gene_decoder_lr
         self.loss_fn = get_loss_fn(loss_fn)

+        # Scheduler settings
+        self.scheduler_type = scheduler_type
+        self.scheduler_step_size = scheduler_step_size
+        self.scheduler_gamma = scheduler_gamma
+        self.scheduler_T_max = scheduler_T_max
+        self.scheduler_patience = scheduler_patience
+        self.scheduler_factor = scheduler_factor
+        self.scheduler_monitor = scheduler_monitor
+
+        # Warmup parameters
+        self.warmup_epochs = warmup_epochs
+        self.warmup_start_factor = warmup_start_factor
+
         # this will either decode to hvg space if output space is a gene,
         # or to transcriptome space if output space is all. done this way to maintain
         # backwards compatibility with the old models
@@ -318,8 +343,156 @@ def decode_to_gene_space(self, latent_embeds: torch.Tensor, basal_expr: None) ->

     def configure_optimizers(self):
         """
-        Configure a single optimizer for both the main model and the gene decoder.
+        Configure optimizer and optional scheduler for both the main model and the gene decoder.
+
+        Supports the following scheduler types:
+        - 'step': StepLR - reduces LR by gamma every step_size epochs
+        - 'cosine': CosineAnnealingLR - cosine annealing schedule over T_max epochs
+        - 'plateau': ReduceLROnPlateau - reduces LR when the monitored metric plateaus
+        - 'exponential': ExponentialLR - exponential decay by gamma each epoch
+        - 'linear': LinearLR - linear decay to 0.1 * lr over T_max epochs
+
+        Warmup is supported for all scheduler types except 'plateau'.
+        When warmup_epochs > 0, the learning rate starts at warmup_start_factor * lr,
+        increases linearly to lr over warmup_epochs, and then follows the main schedule.
""" - # Use a single optimizer for all parameters - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) - return optimizer + # Configure optimizer (same as before) + if self.gene_decoder is not None: + # Get gene decoder parameters + gene_decoder_params = list(self.gene_decoder.parameters()) + + # Get all other parameters (main model) + main_model_params = [ + param for name, param in self.named_parameters() + if not name.startswith("gene_decoder.") + ] + + # Create parameter groups with different learning rates + param_groups = [ + {"params": main_model_params, "lr": self.lr}, + {"params": gene_decoder_params, "lr": self.gene_decoder_lr}, + ] + + optimizer = torch.optim.Adam(param_groups) + else: + # Use single learning rate if no gene decoder + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + + # If no scheduler specified, return just the optimizer + if self.scheduler_type is None: + return optimizer + + # Helper function to create warmup + main scheduler + def create_scheduler_with_warmup(main_scheduler): + if self.warmup_epochs > 0: + # Create warmup scheduler + warmup_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=self.warmup_start_factor, + end_factor=1.0, + total_iters=self.warmup_epochs + ) + + # Combine warmup + main scheduler + scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_scheduler, main_scheduler], + milestones=[self.warmup_epochs] + ) + return scheduler + else: + return main_scheduler + + # Configure scheduler based on type + if self.scheduler_type == "step": + main_scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, + step_size=self.scheduler_step_size, + gamma=self.scheduler_gamma + ) + scheduler = create_scheduler_with_warmup(main_scheduler) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "epoch", + } + } + + elif self.scheduler_type == "cosine": + main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=self.scheduler_T_max + ) + scheduler = create_scheduler_with_warmup(main_scheduler) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "epoch", + } + } + + elif self.scheduler_type == "plateau": + # Note: Warmup with plateau scheduler is tricky because plateau + # schedules based on metrics, not epochs. We'll use a custom approach. + if self.warmup_epochs > 0: + logger.warning( + "Warmup with ReduceLROnPlateau is not directly supported. " + "Consider using a different scheduler type for warmup." 
+                )
+
+            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                optimizer,
+                mode="min",
+                factor=self.scheduler_factor,
+                patience=self.scheduler_patience,
+                verbose=True
+            )
+            return {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": scheduler,
+                    "monitor": self.scheduler_monitor,
+                    "interval": "epoch",
+                    "frequency": 1,
+                }
+            }
+
+        elif self.scheduler_type == "exponential":
+            main_scheduler = torch.optim.lr_scheduler.ExponentialLR(
+                optimizer,
+                gamma=self.scheduler_gamma
+            )
+            scheduler = create_scheduler_with_warmup(main_scheduler)
+
+            return {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": scheduler,
+                    "interval": "epoch",
+                }
+            }
+
+        elif self.scheduler_type == "linear":
+            main_scheduler = torch.optim.lr_scheduler.LinearLR(
+                optimizer,
+                start_factor=1.0,
+                end_factor=0.1,
+                total_iters=self.scheduler_T_max
+            )
+            scheduler = create_scheduler_with_warmup(main_scheduler)
+
+            return {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": scheduler,
+                    "interval": "epoch",
+                }
+            }
+
+        else:
+            logger.warning(f"Unknown scheduler type: {self.scheduler_type}. Using no scheduler.")
+            return optimizer
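
For review context, a minimal standalone sketch (not part of the patch) of the warmup composition that configure_optimizers builds when scheduler_type is "cosine" and warmup_epochs > 0: a LinearLR warmup chained into CosineAnnealingLR via SequentialLR. The toy module, learning rate, warmup length, and epoch counts below are illustrative placeholders that mirror the config defaults.

import torch

# Illustrative stand-in for the model; lr mirrors training/default.yaml (1e-4).
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Warmup: LR starts at 0.1 * lr and ramps to lr over 5 epochs
# (warmup_start_factor=0.1, warmup_epochs=5 chosen for illustration).
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=5)
# Main schedule: cosine annealing over scheduler_T_max=100 epochs.
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
# SequentialLR switches from warmup to the main scheduler at the warmup milestone.
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[5])

for epoch in range(10):
    optimizer.step()
    scheduler.step()
    print(epoch, scheduler.get_last_lr())  # ramps up for 5 epochs, then follows the cosine curve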
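
Similarly, a small sketch of the parameter-group split introduced here: the gene decoder trains with gene_decoder_lr while all other parameters use lr. ToyModel is a hypothetical stand-in with a gene_decoder submodule; the filtering logic mirrors the patch, and the two learning rates are the new defaults from training/default.yaml.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Hypothetical stand-in with a `gene_decoder` submodule."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16)
        self.gene_decoder = nn.Linear(16, 32)

model = ToyModel()
lr, gene_decoder_lr = 1e-4, 1e-5  # defaults from training/default.yaml

# Split parameters the same way configure_optimizers does.
gene_decoder_params = list(model.gene_decoder.parameters())
main_model_params = [
    p for name, p in model.named_parameters() if not name.startswith("gene_decoder.")
]
optimizer = torch.optim.Adam([
    {"params": main_model_params, "lr": lr},
    {"params": gene_decoder_params, "lr": gene_decoder_lr},
])
print([group["lr"] for group in optimizer.param_groups])  # [0.0001, 1e-05]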