From 27919d1fa03bf60ebd05d01e65d05a552c718ce1 Mon Sep 17 00:00:00 2001 From: Krishna Kumar Date: Sat, 29 Jun 2024 10:05:52 -0600 Subject: [PATCH] Fix validation dataloader --- gns/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gns/train.py b/gns/train.py index 79dbf18..2581658 100644 --- a/gns/train.py +++ b/gns/train.py @@ -406,7 +406,7 @@ def train(rank, cfg, world_size, device, verbose): cfg: configuration dictionary world_size: total number of ranks device: torch device type - verbose: gloabl rank 0 or cpu + verbose: global rank 0 or cpu """ device_id = rank if device == torch.device("cuda") else device @@ -476,6 +476,7 @@ def train(rank, cfg, world_size, device, verbose): # Determine if we're using distributed training is_distributed = device == torch.device("cuda") and world_size > 1 + # Load datasets train_dl, valid_dl, n_features = load_datasets(cfg, is_distributed) print(f"rank = {rank}, cuda = {torch.cuda.is_available()}") @@ -627,7 +628,7 @@ def train(rank, cfg, world_size, device, verbose): train_loss_hist.append((epoch, avg_loss.item())) if cfg.training.validation_interval is not None: - sampled_valid_example = next(iter(dl_valid)) + sampled_valid_example = next(iter(valid_dl)) epoch_valid_loss = validation( simulator, sampled_valid_example, n_features, cfg, rank, device_id )