Skip to content

Commit

Permalink
enable model logging
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill Shmilovich committed Mar 3, 2023
1 parent 3a21793 commit 7f63181
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions molgen/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import CSVLogger
from typing import Union
from molgen.modules import (
SimpleGenerator,
Expand Down Expand Up @@ -191,6 +192,7 @@ def fit(
condition_data: Union[torch.Tensor, list],
batch_size: int = 1000,
max_epochs: int = 100,
log: Union[str, bool] = False,
**kwargs,
):
"""
Expand All @@ -212,9 +214,16 @@ def fit(
max_epochs : int, default = 100
maximum number of epochs to train for
log : str or bool, default = False
if the results of the training should be logged. If True logs are by default saved in CSV format
to the directory `./molgen_logs/version_x/`, where `x` increments based on what has been
logged already. If a string is passed the saving directory is created based on the provided name
`./molgen_logs/{log}/`
**kwargs:
additional keyword arguments to be passed to the the Lightning `Trainer`
"""
kwargs.get("enable_checkpointing", False)
datamodule = GANDataModule(
feature_data=feature_data,
condition_data=condition_data,
Expand All @@ -235,8 +244,13 @@ def fit(
devices=1,
accelerator="gpu" if torch.cuda.is_available() else "cpu",
max_epochs=max_epochs,
logger=False,
enable_checkpointing=False,
logger=False
if log is False
else CSVLogger(
save_dir="./",
name="molgen_logs",
version=None if not isinstance(log, str) else log,
),
**kwargs,
)
self.trainer_.fit(self, datamodule)
Expand Down Expand Up @@ -404,6 +418,7 @@ def step_ema(self):

def training_step(self, batch, batch_idx):
loss = self.model(batch[0])
self.log("loss", loss)
return loss

def optimizer_step(self, *args, **kwargs):
Expand All @@ -417,6 +432,7 @@ def fit(
condition_data: Union[torch.Tensor, list],
batch_size: int = 1000,
max_epochs: int = 100,
log: Union[str, bool] = False,
**kwargs,
):
"""
Expand All @@ -438,9 +454,16 @@ def fit(
max_epochs : int, default = 100
maximum number of epochs to train for
log : str or bool, default = False
if the results of the training should be logged. If True logs are by default saved in CSV format
to the directory `./molgen_logs/version_x/`, where `x` increments based on what has been
logged already. If a string is passed the saving directory is created based on the provided name
`./molgen_logs/{log}/`
**kwargs:
additional keyword arguments to be passed to the the Lightning `Trainer`
"""
kwargs.get("enable_checkpointing", False)
datamodule = DDPMDataModule(
feature_data=feature_data,
condition_data=condition_data,
Expand All @@ -461,8 +484,13 @@ def fit(
devices=1,
accelerator="gpu" if torch.cuda.is_available() else "cpu",
max_epochs=max_epochs,
logger=False,
enable_checkpointing=False,
logger=False
if log is False
else CSVLogger(
save_dir="./",
name="molgen_logs",
version=None if not isinstance(log, str) else log,
),
**kwargs,
)
self.trainer_.fit(self, datamodule)
Expand Down

0 comments on commit 7f63181

Please sign in to comment.