Skip to content

Commit

Permalink
seeds
Browse files Browse the repository at this point in the history
  • Loading branch information
levtelyatnikov committed May 8, 2024
1 parent f246f3d commit fe0e33f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ test: True
ckpt_path: null

# seed for random number generators in pytorch, numpy and python.random
seed: null
seed: 42
4 changes: 2 additions & 2 deletions configs/trainer/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ devices: [0]
# precision: 16

# perform a validation loop every N training epochs
check_val_every_n_epoch: 1
check_val_every_n_epoch: 5

# set True to to ensure deterministic results
# makes training slower but gives more reproducibility than just setting seeds
deterministic: False
deterministic: True

# inference mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad`
# during evaluation (``validate``/``test``/``predict``).
Expand Down
38 changes: 19 additions & 19 deletions topobenchmarkx/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
import random
from typing import Any, Optional

import hydra
Expand All @@ -20,13 +19,13 @@
infer_in_channels,
)

# OmegaConf.register_new_resolver("get_default_transform", get_default_transform)
# OmegaConf.register_new_resolver("get_monitor_metric", get_monitor_metric)
# OmegaConf.register_new_resolver("get_monitor_mode", get_monitor_mode)
# OmegaConf.register_new_resolver("infer_in_channels", infer_in_channels)
# OmegaConf.register_new_resolver(
# "parameter_multiplication", lambda x, y: int(int(x) * int(y))
# )
OmegaConf.register_new_resolver("get_default_transform", get_default_transform)
OmegaConf.register_new_resolver("get_monitor_metric", get_monitor_metric)
OmegaConf.register_new_resolver("get_monitor_mode", get_monitor_mode)
OmegaConf.register_new_resolver("infer_in_channels", infer_in_channels)
OmegaConf.register_new_resolver(
"parameter_multiplication", lambda x, y: int(int(x) * int(y))
)
from topobenchmarkx.data.dataloader_fullbatch import DefaultDataModule
from topobenchmarkx.utils import (
RankedLogger,
Expand Down Expand Up @@ -73,8 +72,15 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
"""

# Set seed for random number generators in pytorch, numpy and python.random
if cfg.get("seed"):
L.seed_everything(cfg.seed, workers=True)
#if cfg.get("seed"):
L.seed_everything(cfg.seed, workers=True)
# Seed for torch
torch.manual_seed(cfg.seed)
# Seed for numpy
np.random.seed(cfg.seed)
# Seed for python random
random.seed(cfg.seed)


# Instantiate and load dataset
dataset = hydra.utils.instantiate(cfg.dataset, _recursive_=False)
Expand Down Expand Up @@ -194,12 +200,6 @@ def main(cfg: DictConfig) -> Optional[float]:


if __name__ == "__main__":
OmegaConf.register_new_resolver("get_default_transform", get_default_transform)
OmegaConf.register_new_resolver("get_monitor_metric", get_monitor_metric)
OmegaConf.register_new_resolver("get_monitor_mode", get_monitor_mode)
OmegaConf.register_new_resolver("infer_in_channels", infer_in_channels)
OmegaConf.register_new_resolver(
"parameter_multiplication", lambda x, y: int(int(x) * int(y))
)

main()

0 comments on commit fe0e33f

Please sign in to comment.