diff --git a/gns/distribute.py b/gns/distribute.py
index 00c6e23..9fb7869 100644
--- a/gns/distribute.py
+++ b/gns/distribute.py
@@ -7,32 +7,30 @@ def setup(local_rank: int):
-    """Initializes distributed training.
-    """
+    """Initializes distributed training."""
     # Initialize group, blocks until all processes join.
     torch.distributed.init_process_group(
         backend="nccl",
-        init_method='env://',
+        init_method="env://",
     )
 
     world_size = dist.get_world_size()
 
     torch.cuda.set_device(local_rank)
     torch.cuda.manual_seed(0)
 
-    verbose = (dist.get_rank() == 0)
+    verbose = dist.get_rank() == 0
     if verbose:
-        print('Collecting env info...')
+        print("Collecting env info...")
         print(collect_env.get_pretty_env_info())
         print()
 
     for r in range(torch.distributed.get_world_size()):
         if r == torch.distributed.get_rank():
             print(
-                f'Global rank {torch.distributed.get_rank()} initialized: '
-                f'local_rank = {local_rank}, '
-                f'world_size = {torch.distributed.get_world_size()}',
+                f"Global rank {torch.distributed.get_rank()} initialized: "
+                f"local_rank = {local_rank}, "
+                f"world_size = {torch.distributed.get_world_size()}",
             )
 
     return verbose, world_size
 
 
-
 def cleanup():
@@ -68,8 +66,13 @@ def get_data_distributed_dataloader_by_samples(
         shuffle (bool): Whether to shuffle dataset.
     """
     dataset = data_loader.SamplesDataset(path, input_length_sequence)
-    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=shuffle)
-
+    sampler = DistributedSampler(
+        dataset,
+        num_replicas=dist.get_world_size(),
+        rank=dist.get_rank(),
+        shuffle=shuffle,
+    )
+
     return torch.utils.data.DataLoader(
         dataset=dataset,
         sampler=sampler,
@@ -77,5 +80,3 @@ def get_data_distributed_dataloader_by_samples(
         pin_memory=True,
         collate_fn=data_loader.collate_fn,
     )
-
-
diff --git a/gns/train.py b/gns/train.py
index d92d5d8..f605398 100644
--- a/gns/train.py
+++ b/gns/train.py
@@ -314,9 +314,7 @@ def train(rank, cfg, world_size, device, verbose):
         cfg.data.noise_std,
         rank,
     )
-    simulator = DDP(
-        serial_simulator.to('cuda'), device_ids=[rank]
-    )
+    simulator = DDP(serial_simulator.to("cuda"), device_ids=[rank])
     optimizer = torch.optim.Adam(
         simulator.parameters(), lr=cfg.training.learning_rate.initial * world_size
     )
@@ -559,7 +557,7 @@ def train(rank, cfg, world_size, device, verbose):
                    )
                    pbar.update(1)

-                if verbose and step % cfg.training.save_steps == 0:
+                if verbose and step % cfg.training.save_steps == 0:
                    save_model_and_train_state(
                        verbose,
                        device,
@@ -751,9 +749,9 @@ def validation(simulator, example, n_features, cfg, rank, device_id):
 def main(cfg: Config):
     """Train or evaluates the model."""
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    if 'LOCAL_RANK' in os.environ:
-        local_rank = int(os.environ['LOCAL_RANK'])
-
+    if "LOCAL_RANK" in os.environ:
+        local_rank = int(os.environ["LOCAL_RANK"])
+
     if cfg.mode == "train":
         # If model_path does not exist create new directory.
         if not os.path.exists(cfg.model.path):
@@ -765,7 +763,7 @@ def main(cfg: Config):
 
         # Train on gpu
         if device == torch.device("cuda"):
-            torch.multiprocessing.set_start_method('spawn')
+            torch.multiprocessing.set_start_method("spawn")
             verbose, world_size = distribute.setup(local_rank)
 
         # Train on cpu
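
For reference, the reformatted DistributedSampler call in get_data_distributed_dataloader_by_samples follows the standard per-rank sharding pattern: each rank gets a disjoint slice of the dataset and wraps it in a DataLoader. Below is a minimal sketch of that pattern only; the helper name get_distributed_dataloader and the TensorDataset placeholder are illustrative, not part of this patch, and the real code uses data_loader.SamplesDataset plus a custom collate_fn.

    # Sketch of the per-rank dataloader pattern, assuming the script is launched
    # with torchrun so RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set.
    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    def get_distributed_dataloader(dataset, batch_size, shuffle=True):
        # Each rank sees a disjoint shard of the dataset.
        sampler = DistributedSampler(
            dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            shuffle=shuffle,
        )
        return DataLoader(
            dataset=dataset,
            sampler=sampler,
            batch_size=batch_size,
            pin_memory=True,
        )

    if __name__ == "__main__":
        # "gloo" keeps the sketch runnable on CPU-only machines; the patch uses "nccl" on GPUs.
        dist.init_process_group(backend="gloo", init_method="env://")
        loader = get_distributed_dataloader(TensorDataset(torch.randn(64, 3)), batch_size=8)
        for epoch in range(2):
            loader.sampler.set_epoch(epoch)  # re-seeds the shuffle each epoch
            for (batch,) in loader:
                pass
        dist.destroy_process_group()

Under those assumptions the sketch can be run with, for example, torchrun --nproc_per_node=2 sketch.py; each process then iterates over roughly 1/world_size of the samples per epoch.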