From 5ea6aec8827eabf2a7569d32780ebf3510ba0f6e Mon Sep 17 00:00:00 2001
From: Akihiro Nitta
Date: Mon, 30 Dec 2024 20:39:27 +0900
Subject: [PATCH 1/4] Don't iterate over samples in the same order every epoch in a distributed training example (#9900)

* Adds a call to `set_epoch` to shuffle samples across epochs.
* Applies average instead of sum reduction across ranks so that the scale of the loss stays consistent regardless of the number of GPUs.
* Calls `optimizer.zero_grad` at the end to release gradients before the evaluation loops.
* Includes minor cosmetic changes, e.g., type annotations.
---
 examples/multi_gpu/distributed_batching.py | 88 ++++++++++++++--------
 1 file changed, 56 insertions(+), 32 deletions(-)

diff --git a/examples/multi_gpu/distributed_batching.py b/examples/multi_gpu/distributed_batching.py
index f5c05a176823..b242499d3e76 100644
--- a/examples/multi_gpu/distributed_batching.py
+++ b/examples/multi_gpu/distributed_batching.py
@@ -1,36 +1,35 @@
 import os
+import os.path as osp
 
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn.functional as F
-from ogb.graphproppred import Evaluator
-from ogb.graphproppred import PygGraphPropPredDataset as Dataset
+from ogb.graphproppred import Evaluator, PygGraphPropPredDataset
 from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
 from torch.nn import BatchNorm1d as BatchNorm
 from torch.nn import Linear, ReLU, Sequential
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data.distributed import DistributedSampler
+from torch_sparse import SparseTensor
 
 import torch_geometric.transforms as T
 from torch_geometric.loader import DataLoader
 from torch_geometric.nn import GINEConv, global_mean_pool
-from torch_geometric.typing import WITH_TORCH_SPARSE
-
-if not WITH_TORCH_SPARSE:
-    quit("This example requires 'torch-sparse'")
 
 
 class GIN(torch.nn.Module):
-    def __init__(self, hidden_channels, out_channels, num_layers=3,
-                 dropout=0.5):
+    def __init__(
+        self,
+        hidden_channels: int,
+        out_channels: int,
+        num_layers: int = 3,
+        dropout: float = 0.5,
+    ) -> None:
         super().__init__()
-
         self.dropout = dropout
-
         self.atom_encoder = AtomEncoder(hidden_channels)
         self.bond_encoder = BondEncoder(hidden_channels)
-
         self.convs = torch.nn.ModuleList()
         for _ in range(num_layers):
             nn = Sequential(
@@ -45,7 +44,12 @@ def __init__(self, hidden_channels, out_channels, num_layers=3,
 
         self.lin = Linear(hidden_channels, out_channels)
 
-    def forward(self, x, adj_t, batch):
+    def forward(
+        self,
+        x: torch.Tensor,
+        adj_t: SparseTensor,
+        batch: torch.Tensor,
+    ) -> torch.Tensor:
         x = self.atom_encoder(x)
         edge_attr = adj_t.coo()[2]
         adj_t = adj_t.set_value(self.bond_encoder(edge_attr), layout='coo')
@@ -59,21 +63,29 @@ def forward(self, x, adj_t, batch):
         return x
 
 
-def run(rank, world_size: int, dataset_name: str, root: str):
+def run(rank: int, world_size: int, dataset_name: str, root: str) -> None:
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'
     dist.init_process_group('nccl', rank=rank, world_size=world_size)
 
-    dataset = Dataset(dataset_name, root,
-                      pre_transform=T.ToSparseTensor(attr='edge_attr'))
+    dataset = PygGraphPropPredDataset(
+        dataset_name,
+        root=root,
+        pre_transform=T.ToSparseTensor(attr='edge_attr'),
+    )
     split_idx = dataset.get_idx_split()
     evaluator = Evaluator(dataset_name)
 
     train_dataset = dataset[split_idx['train']]
-    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size,
-                                       rank=rank)
-    train_loader = DataLoader(train_dataset, batch_size=128,
-                              sampler=train_sampler)
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=128,
+        sampler=DistributedSampler(
+            train_dataset,
+            shuffle=True,
+            drop_last=True,
+        ),
+    )
 
     torch.manual_seed(12345)
     model = GIN(128, dataset.num_tasks, num_layers=3, dropout=0.5).to(rank)
@@ -87,20 +99,22 @@ def run(rank, world_size: int, dataset_name: str, root: str):
 
     for epoch in range(1, 51):
         model.train()
-
-        total_loss = torch.zeros(2).to(rank)
+        train_loader.sampler.set_epoch(epoch)
+        total_loss = torch.zeros(2, device=rank)
         for data in train_loader:
            data = data.to(rank)
-            optimizer.zero_grad()
            logits = model(data.x, data.adj_t, data.batch)
            loss = criterion(logits, data.y.to(torch.float))
            loss.backward()
            optimizer.step()
-            total_loss[0] += float(loss) * logits.size(0)
-            total_loss[1] += data.num_graphs
+            optimizer.zero_grad()
 
-        dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
-        loss = float(total_loss[0] / total_loss[1])
+            with torch.no_grad():
+                total_loss[0] += loss * logits.size(0)
+                total_loss[1] += data.num_graphs
+
+        dist.all_reduce(total_loss, op=dist.ReduceOp.AVG)
+        train_loss = total_loss[0] / total_loss[1]
 
         if rank == 0:  # We evaluate on a single GPU for now.
             model.eval()
@@ -127,8 +141,10 @@ def run(rank, world_size: int, dataset_name: str, root: str):
                 'y_true': torch.cat(y_true, dim=0),
             })['rocauc']
 
-            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
-                  f'Val: {val_rocauc:.4f}, Test: {test_rocauc:.4f}')
+            print(f'Epoch: {epoch:03d}, '
+                  f'Loss: {train_loss:.4f}, '
+                  f'Val: {val_rocauc:.4f}, '
+                  f'Test: {test_rocauc:.4f}')
 
         dist.barrier()
 
@@ -137,11 +153,19 @@ def run(rank, world_size: int, dataset_name: str, root: str):
 
 if __name__ == '__main__':
     dataset_name = 'ogbg-molhiv'
-    root = '../../data/OGB'
-
+    root = osp.join(
+        osp.dirname(__file__),
+        '..',
+        '..',
+        'data',
+        'OGB',
+    )
     # Download and process the dataset on main process.
-    Dataset(dataset_name, root,
-            pre_transform=T.ToSparseTensor(attr='edge_attr'))
+    PygGraphPropPredDataset(
+        dataset_name,
+        root,
+        pre_transform=T.ToSparseTensor(attr='edge_attr'),
+    )
 
     world_size = torch.cuda.device_count()
     print('Let\'s use', world_size, 'GPUs!')

From 076db8403243ad4400da196218f731756293da18 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta
Date: Mon, 30 Dec 2024 20:39:56 +0900
Subject: [PATCH 2/4] Fix broken metrics in a `DistributedDataParallel` example (#9896)

Fixes a few issues:

* Fixes the three `dist.all_reduce(train_acc, op=dist.ReduceOp.SUM)` calls (introduced in #8880) that all reduced the same tensor and therefore reported wrong metrics; see the sketch below.
* Avoids the `int(cuda_tensor)` call in every eval iteration to get rid of the D2H synchronization.
* Calls `optimizer.zero_grad()` at the end of each training step to release the GPU memory held by gradients before the evaluation loops.
* Minor cosmetic changes:
  * Adds type annotations.
  * Adds a progress bar for the training loop.
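For illustration, a minimal, self-contained sketch of the corrected reduction pattern (not part of this patch): each metric tensor gets its own `all_reduce`, averaged over ranks. Everything in it is an assumption for demonstration purposes — the `gloo` backend, the hard-coded port, and the made-up per-rank accuracies — so that it runs on CPU without GPUs; it emulates `ReduceOp.AVG` with a sum divided by `world_size`, since `AVG` requires the NCCL backend that the example itself uses.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def demo(rank: int, world_size: int) -> None:
    # Toy setup: localhost rendezvous on an arbitrary port (assumed free).
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12356'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # Each rank computes a metric over its own data shard; values are made up.
    acc = torch.tensor(0.90 if rank == 0 else 0.80)

    # The fix: every metric tensor gets its own all_reduce call (the bug was
    # reducing `train_acc` three times and never touching val/test metrics).
    # This CPU sketch sums and divides; with NCCL, as in the example, this is
    # a single dist.all_reduce(acc, op=dist.ReduceOp.AVG).
    dist.all_reduce(acc, op=dist.ReduceOp.SUM)
    acc /= world_size

    if rank == 0:
        print(f'Accuracy averaged over {world_size} ranks: {acc.item():.4f}')

    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = 2
    mp.spawn(demo, args=(world_size,), nprocs=world_size, join=True)
```

Running the sketch prints `Accuracy averaged over 2 ranks: 0.8500`.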
--------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- examples/multi_gpu/distributed_sampling.py | 67 +++++++++++++--------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/examples/multi_gpu/distributed_sampling.py b/examples/multi_gpu/distributed_sampling.py index 805e6b305728..d3df91000fd0 100644 --- a/examples/multi_gpu/distributed_sampling.py +++ b/examples/multi_gpu/distributed_sampling.py @@ -1,4 +1,5 @@ import os +import os.path as osp from math import ceil import torch @@ -7,6 +8,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn.parallel import DistributedDataParallel +from tqdm import tqdm from torch_geometric.datasets import Reddit from torch_geometric.loader import NeighborLoader @@ -14,10 +16,14 @@ class SAGE(torch.nn.Module): - def __init__(self, in_channels: int, hidden_channels: int, - out_channels: int, num_layers: int = 2): + def __init__( + self, + in_channels: int, + hidden_channels: int, + out_channels: int, + num_layers: int = 2, + ) -> None: super().__init__() - self.convs = torch.nn.ModuleList() self.convs.append(SAGEConv(in_channels, hidden_channels)) for _ in range(num_layers - 2): @@ -34,20 +40,25 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: @torch.no_grad() -def test(loader, model, rank): +def test( + loader: NeighborLoader, + model: DistributedDataParallel, + rank: int, +) -> Tensor: model.eval() - - total_correct = total_examples = 0 + total_correct = torch.tensor(0, dtype=torch.long, device=rank) + total_examples = 0 for i, batch in enumerate(loader): out = model(batch.x, batch.edge_index.to(rank)) pred = out[:batch.batch_size].argmax(dim=-1) y = batch.y[:batch.batch_size].to(rank) - total_correct += int((pred == y).sum()) + total_correct += (pred == y).sum() total_examples += batch.batch_size - return torch.tensor(total_correct / total_examples, device=rank) + + return total_correct / total_examples -def run(rank, world_size, dataset): +def run(rank: int, world_size: int, dataset: Reddit) -> None: os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', rank=rank, world_size=world_size) @@ -94,17 +105,19 @@ def run(rank, world_size, dataset): for epoch in range(1, 21): model.train() - for batch in train_loader: - optimizer.zero_grad() + for batch in tqdm( + train_loader, + desc=f'Epoch {epoch:02d}', + disable=rank != 0, + ): out = model(batch.x, batch.edge_index.to(rank))[:batch.batch_size] loss = F.cross_entropy(out, batch.y[:batch.batch_size]) loss.backward() optimizer.step() - - dist.barrier() + optimizer.zero_grad() if rank == 0: - print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}') + print(f'Epoch {epoch:02d}: Train loss: {loss:.4f}') if epoch % 5 == 0: train_acc = test(train_loader, model, rank) @@ -112,25 +125,27 @@ def run(rank, world_size, dataset): test_acc = test(test_loader, model, rank) if world_size > 1: - dist.all_reduce(train_acc, op=dist.ReduceOp.SUM) - dist.all_reduce(train_acc, op=dist.ReduceOp.SUM) - dist.all_reduce(train_acc, op=dist.ReduceOp.SUM) - train_acc /= world_size - val_acc /= world_size - test_acc /= world_size + dist.all_reduce(train_acc, op=dist.ReduceOp.AVG) + dist.all_reduce(val_acc, op=dist.ReduceOp.AVG) + dist.all_reduce(test_acc, op=dist.ReduceOp.AVG) if rank == 0: - print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, ' - f'Test: {test_acc:.4f}') - - dist.barrier() + print(f'Train acc: {train_acc:.4f}, ' + f'Val acc: {val_acc:.4f}, ' + f'Test acc: {test_acc:.4f}') 
     dist.destroy_process_group()


 if __name__ == '__main__':
-    dataset = Reddit('../../data/Reddit')
-
+    path = osp.join(
+        osp.dirname(__file__),
+        '..',
+        '..',
+        'data',
+        'Reddit',
+    )
+    dataset = Reddit(path)
     world_size = torch.cuda.device_count()
     print("Let's use", world_size, "GPUs!")
     mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)

From c300f38e4f28f550456c65f7e08e052dc40282cd Mon Sep 17 00:00:00 2001
From: kolmiw <72548086+kolmiw@users.noreply.github.com>
Date: Mon, 30 Dec 2024 20:27:54 +0100
Subject: [PATCH 3/4] Graphgym documentation completion (#9905)

I changed a single line of a comment: the description of the `compute_loss` function was cut off mid-sentence. This change does not affect the code's behavior.
---
 torch_geometric/graphgym/loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_geometric/graphgym/loss.py b/torch_geometric/graphgym/loss.py
index db38ce6e1e42..4e3dfd7029a6 100644
--- a/torch_geometric/graphgym/loss.py
+++ b/torch_geometric/graphgym/loss.py
@@ -10,7 +10,7 @@ def compute_loss(pred, true):
 
     Args:
         pred (torch.tensor): Unnormalized prediction
-        true (torch.tensor): Grou
+        true (torch.tensor): Ground truth labels
 
     Returns:
         Loss, normalized prediction score

From 5d1b8987e52ddead470973c611c5b6b1bf935d99 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta
Date: Sun, 5 Jan 2025 03:28:54 +0900
Subject: [PATCH 4/4] Update type annotations (#9917)

Unblocks merging to master.
---
 .github/workflows/linting.yml                    | 2 +-
 torch_geometric/datasets/git_mol_dataset.py      | 4 ++--
 torch_geometric/datasets/molecule_gpt_dataset.py | 9 ++++++---
 torch_geometric/utils/smiles.py                  | 4 ++--
 4 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml
index 3653abdea195..f1f16a5e9a9c 100644
--- a/.github/workflows/linting.yml
+++ b/.github/workflows/linting.yml
@@ -25,7 +25,7 @@ jobs:
 
       # Skip workflow if only certain files have been changed.
- name: Get changed files id: changed-files-specific - uses: tj-actions/changed-files@v41 + uses: tj-actions/changed-files@v45 with: files: | benchmark/** diff --git a/torch_geometric/datasets/git_mol_dataset.py b/torch_geometric/datasets/git_mol_dataset.py index 4b7cfa78117c..872b0d1f5d2d 100644 --- a/torch_geometric/datasets/git_mol_dataset.py +++ b/torch_geometric/datasets/git_mol_dataset.py @@ -187,7 +187,7 @@ def process(self) -> None: img = self.img_transform(img).unsqueeze(0) # graph atom_features_list = [] - for atom in mol.GetAtoms(): # type: ignore + for atom in mol.GetAtoms(): atom_feature = [ safe_index( allowable_features['possible_atomic_num_list'], @@ -219,7 +219,7 @@ def process(self) -> None: edges_list = [] edge_features_list = [] - for bond in mol.GetBonds(): # type: ignore + for bond in mol.GetBonds(): i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() edge_feature = [ safe_index( diff --git a/torch_geometric/datasets/molecule_gpt_dataset.py b/torch_geometric/datasets/molecule_gpt_dataset.py index b1da09f38570..fed2fe503600 100644 --- a/torch_geometric/datasets/molecule_gpt_dataset.py +++ b/torch_geometric/datasets/molecule_gpt_dataset.py @@ -122,7 +122,10 @@ def clean_up_description(description: str) -> str: return first_sentence -def extract_name(name_raw: str, description: str) -> Tuple[str, str, str]: +def extract_name( + name_raw: str, + description: str, +) -> Tuple[Optional[str], str, str]: first_sentence = clean_up_description(description) splitter = ' -- -- ' @@ -446,12 +449,12 @@ def extract_one_SDF_file(block_id: int) -> None: x: torch.Tensor = torch.tensor([ types[atom.GetSymbol()] if atom.GetSymbol() in types else 5 - for atom in m.GetAtoms() # type: ignore + for atom in m.GetAtoms() ]) x = one_hot(x, num_classes=len(types), dtype=torch.float) rows, cols, edge_types = [], [], [] - for bond in m.GetBonds(): # type: ignore + for bond in m.GetBonds(): i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() edge_types += [bonds[bond.GetBondType()]] * 2 rows += [i, j] diff --git a/torch_geometric/utils/smiles.py b/torch_geometric/utils/smiles.py index 608618fc44d9..6547be968582 100644 --- a/torch_geometric/utils/smiles.py +++ b/torch_geometric/utils/smiles.py @@ -91,7 +91,7 @@ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data': assert isinstance(mol, Chem.Mol) xs: List[List[int]] = [] - for atom in mol.GetAtoms(): # type: ignore + for atom in mol.GetAtoms(): row: List[int] = [] row.append(x_map['atomic_num'].index(atom.GetAtomicNum())) row.append(x_map['chirality'].index(str(atom.GetChiralTag()))) @@ -108,7 +108,7 @@ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data': x = torch.tensor(xs, dtype=torch.long).view(-1, 9) edge_indices, edge_attrs = [], [] - for bond in mol.GetBonds(): # type: ignore + for bond in mol.GetBonds(): i = bond.GetBeginAtomIdx() j = bond.GetEndAtomIdx()