Skip to content

Commit

Permalink
Fix validation dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
kks32 committed Jun 29, 2024
1 parent 50fe084 commit 27919d1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def train(rank, cfg, world_size, device, verbose):
cfg: configuration dictionary
world_size: total number of ranks
device: torch device type
- verbose: gloabl rank 0 or cpu
+ verbose: global rank 0 or cpu
"""
device_id = rank if device == torch.device("cuda") else device

Expand Down Expand Up @@ -476,6 +476,7 @@ def train(rank, cfg, world_size, device, verbose):

# Determine if we're using distributed training
is_distributed = device == torch.device("cuda") and world_size > 1
+ # Load datasets
train_dl, valid_dl, n_features = load_datasets(cfg, is_distributed)

print(f"rank = {rank}, cuda = {torch.cuda.is_available()}")
Expand Down Expand Up @@ -627,7 +628,7 @@ def train(rank, cfg, world_size, device, verbose):
train_loss_hist.append((epoch, avg_loss.item()))

if cfg.training.validation_interval is not None:
- sampled_valid_example = next(iter(dl_valid))
+ sampled_valid_example = next(iter(valid_dl))
epoch_valid_loss = validation(
simulator, sampled_valid_example, n_features, cfg, rank, device_id
)
Expand Down

0 comments on commit 27919d1

Please sign in to comment.