9 changes: 9 additions & 0 deletions src/state_sets/configs/training/default.yaml
@@ -2,6 +2,15 @@ wandb_track: false
weight_decay: 0.0005
batch_size: 64
lr: 1e-4
gene_decoder_lr: 1e-5
scheduler_type: None
scheduler_step_size: 50
scheduler_gamma: 0.1
scheduler_T_max: 100
scheduler_patience: 10
scheduler_factor: 0.5
scheduler_monitor: val_loss
warmup_epochs: 0
max_steps: 250000
train_seed: 42
val_freq: 5000
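A hedged aside on these defaults, using plain PyYAML for illustration (the project's actual config loader may resolve values differently): a bare None in YAML is typically read as the string "None" rather than a true null, in which case the scheduler_type is None early return in the new configure_optimizers below would not fire and the code would fall through to the unknown-scheduler warning instead. Writing null (or ~) is the unambiguous YAML spelling of "no scheduler".

import yaml  # illustration only; the project's config loading may behave differently

cfg = yaml.safe_load("""
scheduler_type: None
warmup_epochs: 0
""")

# PyYAML maps only null / ~ / empty to Python None; "None" stays a string.
print(repr(cfg["scheduler_type"]))                                     # 'None'
print(repr(yaml.safe_load("scheduler_type: null")["scheduler_type"]))  # None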
183 changes: 178 additions & 5 deletions src/state_sets/sets/models/base.py
@@ -126,6 +126,7 @@ class PerturbationModel(ABC, LightningModule):
pert_dim: Dimension of perturbation embeddings
dropout: Dropout rate
lr: Learning rate for optimizer
gene_decoder_lr: Learning rate for gene decoder
loss_fn: Loss function ('mse' or custom nn.Module)
output_space: 'gene' or 'latent'
"""
@@ -138,7 +139,8 @@ def __init__(
pert_dim: int,
batch_dim: int = None,
dropout: float = 0.1,
lr: float = 3e-4,
lr: float = 1e-4,
gene_decoder_lr: float = 1e-5,
loss_fn: nn.Module = nn.MSELoss(),
control_pert: str = "non-targeting",
embed_key: Optional[str] = None,
@@ -147,6 +149,15 @@
batch_size: int = 64,
gene_dim: int = 5000,
hvg_dim: int = 2001,
scheduler_type: Optional[str] = None,
scheduler_step_size: int = 50,
scheduler_gamma: float = 0.1,
scheduler_T_max: int = 100,
scheduler_patience: int = 10,
scheduler_factor: float = 0.5,
scheduler_monitor: str = "val_loss",
warmup_epochs: int = 0,
warmup_start_factor: float = 0.1,
**kwargs,
):
super().__init__()
@@ -175,8 +186,22 @@ def __init__(
self.gene_names = gene_names # store the gene names this model outputs in gene expression space
self.dropout = dropout
self.lr = lr
self.gene_decoder_lr = gene_decoder_lr
self.loss_fn = get_loss_fn(loss_fn)

# Scheduler settings
self.scheduler_type = scheduler_type
self.scheduler_step_size = scheduler_step_size
self.scheduler_gamma = scheduler_gamma
self.scheduler_T_max = scheduler_T_max
self.scheduler_patience = scheduler_patience
self.scheduler_factor = scheduler_factor
self.scheduler_monitor = scheduler_monitor

# Warmup parameters
self.warmup_epochs = warmup_epochs
self.warmup_start_factor = warmup_start_factor

# this will decode to HVG space if the output space is 'gene',
# or to transcriptome space if it is 'all'. Done this way to maintain
# backwards compatibility with the old models.
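As a minimal standalone sketch of what the lr / gene_decoder_lr pair stored here feeds into (the toy module, names, and sizes are illustrative stand-ins, not the project's classes): torch.optim.Adam accepts parameter groups, so the gene decoder can train at its own, smaller learning rate than the rest of the model.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # stand-in module; only the "gene_decoder" attribute name mirrors the diff
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16)
        self.gene_decoder = nn.Linear(16, 32)

model = ToyModel()
lr, gene_decoder_lr = 1e-4, 1e-5

decoder_params = list(model.gene_decoder.parameters())
main_params = [
    p for n, p in model.named_parameters() if not n.startswith("gene_decoder.")
]

optimizer = torch.optim.Adam([
    {"params": main_params, "lr": lr},
    {"params": decoder_params, "lr": gene_decoder_lr},
])
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 1e-05]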
@@ -318,8 +343,156 @@ def decode_to_gene_space(self, latent_embeds: torch.Tensor, basal_expr: None) ->

def configure_optimizers(self):
"""
Configure a single optimizer for both the main model and the gene decoder.
Configure optimizer and optional scheduler for both the main model and the gene decoder.

Supports the following scheduler types:
- 'step': StepLR - reduces LR by gamma every step_size epochs
- 'cosine': CosineAnnealingLR - cosine annealing schedule
- 'plateau': ReduceLROnPlateau - reduces LR when metric plateaus
- 'exponential': ExponentialLR - exponential decay
- 'linear': LinearLR - linear decay over T_max epochs

Warmup is supported for all scheduler types except 'plateau'.
When warmup_epochs > 0, the learning rate starts at warmup_start_factor * lr
and increases linearly to lr over warmup_epochs, then follows the main schedule.
"""
# Use a single optimizer for all parameters
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
# Configure the optimizer (now with a separate parameter group for the gene decoder)
if self.gene_decoder is not None:
# Get gene decoder parameters
gene_decoder_params = list(self.gene_decoder.parameters())

# Get all other parameters (main model)
main_model_params = [
param for name, param in self.named_parameters()
if not name.startswith("gene_decoder.")
]

# Create parameter groups with different learning rates
param_groups = [
{"params": main_model_params, "lr": self.lr},
{"params": gene_decoder_params, "lr": self.gene_decoder_lr},
]

optimizer = torch.optim.Adam(param_groups)
else:
# Use single learning rate if no gene decoder
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

# If no scheduler specified, return just the optimizer
if self.scheduler_type is None:
return optimizer

# Helper function to create warmup + main scheduler
def create_scheduler_with_warmup(main_scheduler):
if self.warmup_epochs > 0:
# Create warmup scheduler
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=self.warmup_start_factor,
end_factor=1.0,
total_iters=self.warmup_epochs
)

# Combine warmup + main scheduler
scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer,
schedulers=[warmup_scheduler, main_scheduler],
milestones=[self.warmup_epochs]
)
return scheduler
else:
return main_scheduler

# Configure scheduler based on type
if self.scheduler_type == "step":
main_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=self.scheduler_step_size,
gamma=self.scheduler_gamma
)
scheduler = create_scheduler_with_warmup(main_scheduler)

return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "epoch",
}
}

elif self.scheduler_type == "cosine":
main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=self.scheduler_T_max
)
scheduler = create_scheduler_with_warmup(main_scheduler)

return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "epoch",
}
}

elif self.scheduler_type == "plateau":
# Note: warmup with ReduceLROnPlateau is tricky because it steps on a
# monitored metric, not on epochs, so no warmup is applied for this type.
if self.warmup_epochs > 0:
logger.warning(
"Warmup with ReduceLROnPlateau is not directly supported. "
"Consider using a different scheduler type for warmup."
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="min",
factor=self.scheduler_factor,
patience=self.scheduler_patience,
verbose=True
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": self.scheduler_monitor,
"interval": "epoch",
"frequency": 1,
}
}

elif self.scheduler_type == "exponential":
main_scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer,
gamma=self.scheduler_gamma
)
scheduler = create_scheduler_with_warmup(main_scheduler)

return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "epoch",
}
}

elif self.scheduler_type == "linear":
main_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=1.0,
end_factor=0.1,
total_iters=self.scheduler_T_max
)
scheduler = create_scheduler_with_warmup(main_scheduler)

return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "epoch",
}
}

else:
logger.warning(f"Unknown scheduler type: {self.scheduler_type}. Using no scheduler.")
return optimizer
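For readers who want to see the warmup-then-cosine combination in isolation, here is a small self-contained sketch of the same SequentialLR wiring (the dummy parameter and the chosen values are illustrative; in training, Lightning drives the schedule via the "interval": "epoch" entry returned above). With lr = 1e-4, warmup_start_factor = 0.1 and warmup_epochs = 5, the learning rate ramps from 1e-5 to 1e-4 over the first five epochs and then follows the cosine decay.

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-4)

warmup_epochs, T_max = 5, 20
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs
)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs]
)

for epoch in range(10):
    optimizer.step()   # stand-in for one training epoch
    scheduler.step()   # Lightning would call this once per epoch
    print(epoch, scheduler.get_last_lr())

The plateau branch is the one case that also returns a "monitor" key, since ReduceLROnPlateau needs to know which logged metric (scheduler_monitor, val_loss by default) to track.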