Commit 01815e1: linting
c-w-feldmann committed May 7, 2024
1 parent eb9fa23

Showing 1 changed file with 103 additions and 28 deletions.

molpipeline/estimators/chemprop/abstract.py
@@ -27,6 +27,11 @@ class ABCChemprop(BaseEstimator, abc.ABC):
     """Wrap Chemprop in a sklearn like object."""
 
     model: MPNN
+    batch_size: int
+    n_jobs: int
+    lightning_trainer: pl.Trainer
+    trainer_params: dict[str, Any]
+    model_ckpoint_params: dict[str, Any]
 
     def __init__(
         self,
@@ -63,29 +68,44 @@ def __init__(
         checkpoint_callback = [
             pl.callbacks.ModelCheckpoint(**self.model_ckpoint_params)
         ]
-        self.lightning_trainer = self._set_trainer(
-            self.trainer_params, lightning_trainer, checkpoint_callback
-        )
+        self._set_trainer(self.trainer_params, lightning_trainer, checkpoint_callback)
 
-    def _set_trainer(self, trainer_params, lightning_trainer, checkpoint_callback):
-        if self.trainer_params and lightning_trainer is not None:
+    def _set_trainer(
+        self,
+        trainer_params: dict[str, Any],
+        lightning_trainer: pl.Trainer | None,
+        checkpoint_callback: list[pl.callbacks.Callback],
+    ) -> None:
+        """Set the trainer for the model.
+
+        Parameters
+        ----------
+        trainer_params : dict[str, Any]
+            The parameters for the trainer.
+        lightning_trainer : pl.Trainer | None
+            The lightning trainer.
+        checkpoint_callback : list[pl.callbacks.ModelCheckpoint]
+            The checkpoint callback to use.
+        """
+        if self.trainer_params and lightning_trainer:
             raise ValueError(
                 "You must provide either trainer_params or lightning_trainer."
             )
-        elif not trainer_params and lightning_trainer is None:
-            lightning_trainer = pl.Trainer(
-                logger=False,
-                enable_checkpointing=False,
-                max_epochs=500,
-                enable_model_summary=False,
-                callbacks=checkpoint_callback,
-            )
-        elif trainer_params and lightning_trainer is None:
-            lightning_trainer = pl.Trainer(
-                **trainer_params, callbacks=checkpoint_callback
-            )
-
-        return lightning_trainer
+        if lightning_trainer:
+            self.lightning_trainer = lightning_trainer
+            return
+
+        if not trainer_params:
+            trainer_params = {
+                "logger": False,
+                "enable_checkpointing": False,
+                "max_epochs": 500,
+                "enable_model_summary": False,
+                "callbacks": [],
+            }
+        if checkpoint_callback:
+            trainer_params["callbacks"] = checkpoint_callback
+        self.lightning_trainer = pl.Trainer(**trainer_params)
 
     def fit(
         self,
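
Note: the refactored `_set_trainer` assigns `self.lightning_trainer` as a side effect instead of returning a trainer. A minimal standalone sketch of the same resolution order, assuming the `lightning.pytorch` import path (older setups import `pytorch_lightning`); the function name and error wording here are ours, not the commit's:

    from typing import Any

    import lightning.pytorch as pl

    def resolve_trainer(
        trainer_params: dict[str, Any],
        lightning_trainer: pl.Trainer | None,
        checkpoint_callback: list[pl.callbacks.Callback],
    ) -> pl.Trainer:
        # An explicit trainer and trainer params are mutually exclusive.
        if trainer_params and lightning_trainer:
            raise ValueError("Pass either trainer_params or lightning_trainer, not both.")
        # A caller-supplied trainer is used as-is.
        if lightning_trainer:
            return lightning_trainer
        # Defaults mirror the diff: silent trainer, 500 epochs, no checkpointing.
        if not trainer_params:
            trainer_params = {
                "logger": False,
                "enable_checkpointing": False,
                "max_epochs": 500,
                "enable_model_summary": False,
                "callbacks": [],
            }
        if checkpoint_callback:
            trainer_params["callbacks"] = checkpoint_callback
        return pl.Trainer(**trainer_params)

One behavior visible in the sketch: when a trainer object is supplied, the checkpoint callback is silently dropped in that branch.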
@@ -117,26 +137,64 @@ def fit(
         self.lightning_trainer.fit(self.model, training_data)
         return self
 
-    def _update_trainer(args: Any) -> pl.Trainer:
-        return pl.Trainer(**args)
-    def set_params(self, **params) -> None:
+    def set_params(self, **params: Any) -> Self:
+        """Set the parameters of the model.
+
+        Parameters
+        ----------
+        **params: Any
+            The parameters to set.
+
+        Returns
+        -------
+        Self
+            The model with the new parameters.
+        """
         params, self.trainer_params = self._filter_params_trainer(params)
         params, self.model_ckpoint_params = self._filter_params_callback(params)
         super().set_params(**params)
         return self
 
-    def get_params(self, deep: bool = False) -> None:
+    def get_params(self, deep: bool = False) -> dict[str, Any]:
+        """Get the parameters of the model.
+
+        Parameters
+        ----------
+        deep : bool, optional (default=False)
+            Whether to get the parameters of the model.
+
+        Returns
+        -------
+        dict[str, Any]
+            The parameters of the model.
+        """
         params = super().get_params(deep)
         for name, value in self.trainer_params.items():
             params[f"lightning_trainer__{name}"] = value
         for name, value in self.model_ckpoint_params.items():
             params[f"callback_modelckpt__{name}"] = value
-        params["lightning_trainer"] = (
-            None  # set to none as we either have the trainer params or the non-parametrized trainer object (otherwise recursive from JSON fails as trainer + params are set)
-        )
+        # set to none as the trainer is created from the parameters
+        params["lightning_trainer"] = None
         return params
 
-    def _filter_params_trainer(self, params: dict) -> dict:
+    @staticmethod
+    def _filter_params_trainer(
+        params: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Filter the parameters for the trainer.
+
+        Parameters
+        ----------
+        params : dict[str, Any]
+            The parameters to filter.
+
+        Returns
+        -------
+        dict[str, Any]
+            The filtered parameters for the model.
+        dict[str, Any]
+            The filtered parameters for the trainer.
+        """
         params_trainer = {
             k.split("__")[1]: v
             for k, v in params.items()
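
The reworked `set_params`/`get_params` route trainer settings through sklearn's double-underscore convention, and `get_params` reports the trainer object itself as None so the estimator can be rebuilt from parameters alone (the removed comment notes that keeping both the trainer and its params breaks recursive reconstruction from JSON). A runnable sketch of the flattening, with illustrative key values:

    # Stored trainer_params are re-exposed under "lightning_trainer__" keys.
    trainer_params = {"max_epochs": 50, "logger": False}
    params = {"batch_size": 64}
    for name, value in trainer_params.items():
        params[f"lightning_trainer__{name}"] = value
    # The trainer object is reported as None; it is recreated from the
    # flattened parameters when the estimator is reconstructed.
    params["lightning_trainer"] = None
    print(params)
    # {'batch_size': 64, 'lightning_trainer__max_epochs': 50,
    #  'lightning_trainer__logger': False, 'lightning_trainer': None}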
@@ -147,7 +205,24 @@ def _filter_params_trainer(self, params: dict) -> dict:
         }
         return params, params_trainer
 
-    def _filter_params_callback(self, params: dict) -> dict:
+    @staticmethod
+    def _filter_params_callback(
+        params: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Filter the parameters for the checkpoint callback.
+
+        Parameters
+        ----------
+        params : dict[str, Any]
+            The parameters to filter.
+
+        Returns
+        -------
+        dict[str, Any]
+            The filtered parameters for the model.
+        dict[str, Any]
+            The filtered parameters for the checkpoint callback.
+        """
         params_ckpt = {
             k.split("__")[1]: v
             for k, v in params.items()
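
Both `_filter_params_*` helpers pull prefixed keys out of the incoming parameter dict. The filter condition lies in the collapsed part of the diff, so the `startswith` check in this sketch is an assumption; only the `k.split("__")[1]` idiom is taken from the commit:

    from typing import Any

    def split_prefixed(
        params: dict[str, Any], prefix: str
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        # Keys like "lightning_trainer__max_epochs" are stripped of their
        # prefix and collected; everything else stays in the remaining dict.
        extracted = {
            k.split("__")[1]: v
            for k, v in params.items()
            if k.startswith(f"{prefix}__")
        }
        remaining = {
            k: v for k, v in params.items() if not k.startswith(f"{prefix}__")
        }
        return remaining, extracted

    remaining, trainer_params = split_prefixed(
        {"lightning_trainer__max_epochs": 50, "batch_size": 64}, "lightning_trainer"
    )
    print(trainer_params)  # {'max_epochs': 50}
    print(remaining)       # {'batch_size': 64}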
