Commit: reformat
Sikan Li committed Jun 28, 2024
1 parent 748b9f4 commit a74ced0
Showing 2 changed files with 20 additions and 21 deletions.
gns/distribute.py: 27 changes (14 additions, 13 deletions)
@@ -7,32 +7,30 @@


 def setup(local_rank: int):
-    """Initializes distributed training.
-    """
+    """Initializes distributed training."""
     # Initialize group, blocks until all processes join.
     torch.distributed.init_process_group(
         backend="nccl",
-        init_method='env://',
+        init_method="env://",
     )
     world_size = dist.get_world_size()
     torch.cuda.set_device(local_rank)
     torch.cuda.manual_seed(0)
-    verbose = (dist.get_rank() == 0)
+    verbose = dist.get_rank() == 0
 
     if verbose:
-        print('Collecting env info...')
+        print("Collecting env info...")
         print(collect_env.get_pretty_env_info())
         print()
 
     for r in range(torch.distributed.get_world_size()):
         if r == torch.distributed.get_rank():
             print(
-                f'Global rank {torch.distributed.get_rank()} initialized: '
-                f'local_rank = {local_rank}, '
-                f'world_size = {torch.distributed.get_world_size()}',
+                f"Global rank {torch.distributed.get_rank()} initialized: "
+                f"local_rank = {local_rank}, "
+                f"world_size = {torch.distributed.get_world_size()}",
             )
     return verbose, world_size
 
 
-
 def cleanup():
@@ -68,14 +66,17 @@ def get_data_distributed_dataloader_by_samples(
         shuffle (bool): Whether to shuffle dataset.
     """
     dataset = data_loader.SamplesDataset(path, input_length_sequence)
-    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=shuffle)
-
-
+    sampler = DistributedSampler(
+        dataset,
+        num_replicas=dist.get_world_size(),
+        rank=dist.get_rank(),
+        shuffle=shuffle,
+    )
 
     return torch.utils.data.DataLoader(
         dataset=dataset,
         sampler=sampler,
         batch_size=batch_size,
         pin_memory=True,
         collate_fn=data_loader.collate_fn,
     )
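To situate these two helpers, here is a minimal sketch of how setup() and get_data_distributed_dataloader_by_samples() might be called from a training entry point launched with torchrun; the dataset path and hyperparameter values are hypothetical, not taken from this repository:

import os

from gns import distribute

# torchrun exports LOCAL_RANK for every process it spawns.
local_rank = int(os.environ["LOCAL_RANK"])
verbose, world_size = distribute.setup(local_rank)

# Hypothetical arguments, for illustration only.
dataloader = distribute.get_data_distributed_dataloader_by_samples(
    path="datasets/example/train.npz",
    input_length_sequence=6,
    batch_size=2,
    shuffle=True,
)

for batch in dataloader:
    ...  # per-rank forward/backward step would go here

distribute.cleanup()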


gns/train.py: 14 changes (6 additions, 8 deletions)
@@ -314,9 +314,7 @@ def train(rank, cfg, world_size, device, verbose):
         cfg.data.noise_std,
         rank,
     )
-    simulator = DDP(
-        serial_simulator.to('cuda'), device_ids=[rank]
-    )
+    simulator = DDP(serial_simulator.to("cuda"), device_ids=[rank])
     optimizer = torch.optim.Adam(
         simulator.parameters(), lr=cfg.training.learning_rate.initial * world_size
     )
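Note that the learning rate handed to Adam is cfg.training.learning_rate.initial * world_size, the usual linear scaling rule for synchronous data parallelism: with world_size ranks each consuming a full batch per step, the effective batch size grows by that factor, so the base rate is scaled to match. Illustrative numbers, not from the config:

# Linear scaling rule, sketched with hypothetical values.
base_lr = 1e-4     # stands in for cfg.training.learning_rate.initial
world_size = 4     # number of GPUs participating in DDP
effective_lr = base_lr * world_size   # 4e-4 passed to torch.optim.Adam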
@@ -559,7 +557,7 @@ def train(rank, cfg, world_size, device, verbose):
         )
         pbar.update(1)
 
-        if verbose and step % cfg.training.save_steps == 0:
+        if verbose and step % cfg.training.save_steps == 0:
             save_model_and_train_state(
                 verbose,
                 device,
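Gating the save on verbose, which distribute.setup() sets to True only on global rank 0, is the standard DDP checkpointing pattern: after every synchronized step all ranks hold identical weights, so a single writer suffices and concurrent writes to the same file are avoided. A generic sketch of the pattern, with placeholder names:

import torch
import torch.distributed as dist

# step, save_steps, and simulator stand in for the training-loop state.
if dist.get_rank() == 0 and step % save_steps == 0:
    # Unwrap DDP via .module so the checkpoint loads without a DDP wrapper.
    torch.save(simulator.module.state_dict(), "checkpoint.pt")
dist.barrier()  # optional: keeps all ranks in lockstep around the save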
@@ -751,9 +749,9 @@ def validation(simulator, example, n_features, cfg, rank, device_id):
 def main(cfg: Config):
     """Train or evaluates the model."""
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    if 'LOCAL_RANK' in os.environ:
-        local_rank = int(os.environ['LOCAL_RANK'])
+    if "LOCAL_RANK" in os.environ:
+        local_rank = int(os.environ["LOCAL_RANK"])
 
     if cfg.mode == "train":
         # If model_path does not exist create new directory.
         if not os.path.exists(cfg.model.path):
@@ -765,7 +763,7 @@ def main(cfg: Config):

     # Train on gpu
     if device == torch.device("cuda"):
-        torch.multiprocessing.set_start_method('spawn')
+        torch.multiprocessing.set_start_method("spawn")
         verbose, world_size = distribute.setup(local_rank)
 
     # Train on cpu
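Because setup() initializes the process group with init_method="env://", the rendezvous variables must already be present in the environment; torchrun exports them automatically. For a single-GPU smoke test without a launcher they can be supplied by hand, with illustrative values:

import os

# Standard env:// rendezvous variables that torchrun would normally export.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("LOCAL_RANK", "0")

from gns import distribute

verbose, world_size = distribute.setup(int(os.environ["LOCAL_RANK"]))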
