diff --git a/code/experiments/utils/loggers.py b/code/experiments/utils/loggers.py index 127e5d5..9d07191 100644 --- a/code/experiments/utils/loggers.py +++ b/code/experiments/utils/loggers.py @@ -15,10 +15,20 @@ def get_wandb_logger( project_id: str = "mantra-dev", ): wandb_logger = WandbLogger(project=project_id, save_dir=save_dir) + return wandb_logger + +def update_wandb_logger( + wandb_logger, + task_name: TaskType, + save_dir="./lightning_logs", + model_name: str = None, + node_features: str = None, + run_id: str = None, + project_id: str = "mantra-dev", + ): wandb_logger.experiment.config["task"] = task_name.lower() wandb_logger.experiment.config["run_id"] = run_id wandb_logger.experiment.config["node_features"] = node_features if model_name is not None: - wandb_logger.experiment.config["model_name"] = model_name - return wandb_logger + wandb_logger.experiment.config["model_name"] = model_name \ No newline at end of file diff --git a/code/experiments/utils/run_experiment.py b/code/experiments/utils/run_experiment.py index 721b58c..d70d827 100644 --- a/code/experiments/utils/run_experiment.py +++ b/code/experiments/utils/run_experiment.py @@ -13,7 +13,7 @@ from typing import Dict, Optional, Tuple, List from datasets.simplicial import SimplicialDataModule from models.base import BaseModel -from experiments.utils.loggers import get_wandb_logger +from experiments.utils.loggers import get_wandb_logger, update_wandb_logger import lightning as L import uuid from datasets.transforms import transforms_lookup @@ -88,8 +88,19 @@ def get_setup( log_every_n_steps=config.trainer_config.log_every_n_steps, fast_dev_run=False, default_root_dir=data_dir, + devices=3, ) + if use_logger and trainer.global_rank == 0: + update_wandb_logger( + logger, + task_name=config.task_type.name, + model_name=config.conf_model.type.name, + node_features=config.transforms.name, + run_id=run_id, + project_id=config.logging.wandb_project_id, + ) + return dm, lit_model, trainer, logger diff --git a/code/generate_configs.sh b/code/generate_configs.sh index cd39a9a..970fbca 100755 --- a/code/generate_configs.sh +++ b/code/generate_configs.sh @@ -1,6 +1,6 @@ #!/bin/bash python ./experiments/generate_configs.py \ - --max_epochs 1 \ + --max_epochs 6 \ --lr 0.001 \ --config_dir "/data/configs" \