Skip to content

Commit

Permalink
made code compatible with running on the server
Browse files Browse the repository at this point in the history
  • Loading branch information
rballeba committed Aug 31, 2024
1 parent 5743c2a commit af8eb1f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
14 changes: 12 additions & 2 deletions code/experiments/utils/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,20 @@ def get_wandb_logger(
project_id: str = "mantra-dev",
):
wandb_logger = WandbLogger(project=project_id, save_dir=save_dir)
return wandb_logger

def update_wandb_logger(
wandb_logger,
task_name: TaskType,
save_dir="./lightning_logs",
model_name: str = None,
node_features: str = None,
run_id: str = None,
project_id: str = "mantra-dev",
):
wandb_logger.experiment.config["task"] = task_name.lower()
wandb_logger.experiment.config["run_id"] = run_id
wandb_logger.experiment.config["node_features"] = node_features

if model_name is not None:
wandb_logger.experiment.config["model_name"] = model_name
return wandb_logger
wandb_logger.experiment.config["model_name"] = model_name
13 changes: 12 additions & 1 deletion code/experiments/utils/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import Dict, Optional, Tuple, List
from datasets.simplicial import SimplicialDataModule
from models.base import BaseModel
from experiments.utils.loggers import get_wandb_logger
from experiments.utils.loggers import get_wandb_logger, update_wandb_logger
import lightning as L
import uuid
from datasets.transforms import transforms_lookup
Expand Down Expand Up @@ -88,8 +88,19 @@ def get_setup(
log_every_n_steps=config.trainer_config.log_every_n_steps,
fast_dev_run=False,
default_root_dir=data_dir,
devices=3,
)

if use_logger and trainer.global_rank == 0:
update_wandb_logger(
logger,
task_name=config.task_type.name,
model_name=config.conf_model.type.name,
node_features=config.transforms.name,
run_id=run_id,
project_id=config.logging.wandb_project_id,
)

return dm, lit_model, trainer, logger


Expand Down
2 changes: 1 addition & 1 deletion code/generate_configs.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash

python ./experiments/generate_configs.py \
--max_epochs 1 \
--max_epochs 6 \
--lr 0.001 \
--config_dir "/data/configs" \

0 comments on commit af8eb1f

Please sign in to comment.