fix optimization with efficient checkpoint module
shuishen112 committed Jul 16, 2024
1 parent 93b5526 commit 1f97122
Showing 1 changed file with 2 additions and 17 deletions.
19 changes: 2 additions & 17 deletions mttl/models/expert_model.py
@@ -644,23 +644,8 @@ def __init__(self, expert_model, ref_expert_model, **kwargs):
         self.expert_model = expert_model
         self.ref_expert_model = ref_expert_model
         self.trainable_param_names = kwargs.get("trainable_param_names", None)
-
-    def configure_optimizers(self):
-        params = []
-        # for param_name, param in self.named_parameters():
-        #     param.requires_grad = False
-        #     if self.trainable_param_names and re.fullmatch(
-        #         self.trainable_param_names, param_name
-        #     ):
-        #         param.requires_grad = True
-        #         params.append(param)
-
-        #         logger.info(f"Setting {param_name} to trainable.")
-        optimizer = torch.optim.Adam(
-            filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3
-        )
-
-        return optimizer
+        # log hyperparameters
+        self.save_hyperparameters(kwargs)
 
     def training_step(self, batch, _):
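
In effect, the commit removes the subclass's own configure_optimizers override, so optimizer construction is resolved through the parent expert model (which handles the efficient checkpoint module), and it records the constructor kwargs via Lightning's save_hyperparameters. Below is a minimal, hypothetical sketch of that pattern; the class names and the parent's optimizer body are assumptions for illustration, not code taken from mttl.

import torch
import pytorch_lightning as pl


class BaseExpertModule(pl.LightningModule):
    """Stand-in for the parent class that already knows how to build an
    optimizer compatible with the efficient checkpoint module."""

    def configure_optimizers(self):
        # Hand only the parameters left trainable (e.g. after freezing /
        # checkpoint setup) to the optimizer.
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=1e-3)


class PairedExpertModule(BaseExpertModule):
    """Hypothetical counterpart of the edited class: it no longer overrides
    configure_optimizers, so Lightning falls back to the parent's version."""

    def __init__(self, expert_model, ref_expert_model, **kwargs):
        super().__init__()
        self.expert_model = expert_model
        self.ref_expert_model = ref_expert_model
        self.trainable_param_names = kwargs.get("trainable_param_names", None)
        # log hyperparameters
        self.save_hyperparameters(kwargs)

With the override gone, normal method lookup gives a single place where parameter freezing, checkpointing, and optimizer construction interact, which is presumably what "fix optimization with efficient checkpoint module" refers to.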
