Merge branch 'master' into feature/add-gkan-layer
puririshi98 authored Jan 6, 2025
2 parents afeabcf + 5d1b898 commit 2a95bc8
Showing 7 changed files with 109 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linting.yml
@@ -25,7 +25,7 @@ jobs:
# Skip workflow if only certain files have been changed.
- name: Get changed files
id: changed-files-specific
- uses: tj-actions/changed-files@v41
+ uses: tj-actions/changed-files@v45
with:
files: |
benchmark/**
88 changes: 56 additions & 32 deletions examples/multi_gpu/distributed_batching.py
@@ -1,36 +1,35 @@
import os
+ import os.path as osp

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
- from ogb.graphproppred import Evaluator
- from ogb.graphproppred import PygGraphPropPredDataset as Dataset
+ from ogb.graphproppred import Evaluator, PygGraphPropPredDataset
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch.nn import BatchNorm1d as BatchNorm
from torch.nn import Linear, ReLU, Sequential
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
+ from torch_sparse import SparseTensor

import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, global_mean_pool
from torch_geometric.typing import WITH_TORCH_SPARSE

if not WITH_TORCH_SPARSE:
quit("This example requires 'torch-sparse'")


class GIN(torch.nn.Module):
- def __init__(self, hidden_channels, out_channels, num_layers=3,
-              dropout=0.5):
+ def __init__(
+     self,
+     hidden_channels: int,
+     out_channels: int,
+     num_layers: int = 3,
+     dropout: float = 0.5,
+ ) -> None:
super().__init__()

self.dropout = dropout

self.atom_encoder = AtomEncoder(hidden_channels)
self.bond_encoder = BondEncoder(hidden_channels)

self.convs = torch.nn.ModuleList()
for _ in range(num_layers):
nn = Sequential(
@@ -45,7 +44,12 @@ def __init__(self, hidden_channels, out_channels, num_layers=3,

self.lin = Linear(hidden_channels, out_channels)

- def forward(self, x, adj_t, batch):
+ def forward(
+     self,
+     x: torch.Tensor,
+     adj_t: SparseTensor,
+     batch: torch.Tensor,
+ ) -> torch.Tensor:
x = self.atom_encoder(x)
edge_attr = adj_t.coo()[2]
adj_t = adj_t.set_value(self.bond_encoder(edge_attr), layout='coo')
@@ -59,21 +63,29 @@ def forward(self, x, adj_t, batch):
return x


- def run(rank, world_size: int, dataset_name: str, root: str):
+ def run(rank: int, world_size: int, dataset_name: str, root: str) -> None:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group('nccl', rank=rank, world_size=world_size)

- dataset = Dataset(dataset_name, root,
-                   pre_transform=T.ToSparseTensor(attr='edge_attr'))
+ dataset = PygGraphPropPredDataset(
+     dataset_name,
+     root=root,
+     pre_transform=T.ToSparseTensor(attr='edge_attr'),
+ )
split_idx = dataset.get_idx_split()
evaluator = Evaluator(dataset_name)

train_dataset = dataset[split_idx['train']]
- train_sampler = DistributedSampler(train_dataset, num_replicas=world_size,
-                                    rank=rank)
- train_loader = DataLoader(train_dataset, batch_size=128,
-                           sampler=train_sampler)
+ train_loader = DataLoader(
+     train_dataset,
+     batch_size=128,
+     sampler=DistributedSampler(
+         train_dataset,
+         shuffle=True,
+         drop_last=True,
+     ),
+ )

torch.manual_seed(12345)
model = GIN(128, dataset.num_tasks, num_layers=3, dropout=0.5).to(rank)
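The reworked loader builds its DistributedSampler inline with shuffle=True and drop_last=True, so every rank draws a distinct, equally sized shard per epoch. A minimal standalone sketch of the pattern, with num_replicas and rank pinned by hand since no process group is initialized here (inside the example they default to the launched world size and rank):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(10).float())  # toy stand-in dataset

    # drop_last=True trims the tail so both shards have the same length,
    # which keeps the number of optimizer steps identical across ranks.
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0,
                                 shuffle=True, drop_last=True)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)

    for epoch in range(3):
        # Without set_epoch (added in the training loop below), the sampler
        # replays the same shuffled order every epoch.
        sampler.set_epoch(epoch)
        for (batch,) in loader:
            pass  # training step goes here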
@@ -87,20 +99,22 @@ def run(rank, world_size: int, dataset_name: str, root: str):

for epoch in range(1, 51):
model.train()

- total_loss = torch.zeros(2).to(rank)
+ train_loader.sampler.set_epoch(epoch)
+ total_loss = torch.zeros(2, device=rank)
for data in train_loader:
data = data.to(rank)
- optimizer.zero_grad()
logits = model(data.x, data.adj_t, data.batch)
loss = criterion(logits, data.y.to(torch.float))
loss.backward()
optimizer.step()
- total_loss[0] += float(loss) * logits.size(0)
- total_loss[1] += data.num_graphs
+ optimizer.zero_grad()

- dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
- loss = float(total_loss[0] / total_loss[1])
+ with torch.no_grad():
+     total_loss[0] += loss * logits.size(0)
+     total_loss[1] += data.num_graphs

+ dist.all_reduce(total_loss, op=dist.ReduceOp.AVG)
+ train_loss = total_loss[0] / total_loss[1]

if rank == 0: # We evaluate on a single GPU for now.
model.eval()
@@ -127,8 +141,10 @@ def run(rank, world_size: int, dataset_name: str, root: str):
'y_true': torch.cat(y_true, dim=0),
})['rocauc']

- print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
-       f'Val: {val_rocauc:.4f}, Test: {test_rocauc:.4f}')
+ print(f'Epoch: {epoch:03d}, '
+       f'Loss: {train_loss:.4f}, '
+       f'Val: {val_rocauc:.4f}, '
+       f'Test: {test_rocauc:.4f}')

dist.barrier()

@@ -137,11 +153,19 @@ def run(rank, world_size: int, dataset_name: str, root: str):

if __name__ == '__main__':
dataset_name = 'ogbg-molhiv'
- root = '../../data/OGB'
+ root = osp.join(
+     osp.dirname(__file__),
+     '..',
+     '..',
+     'data',
+     'OGB',
+ )
# Download and process the dataset on main process.
- Dataset(dataset_name, root,
-         pre_transform=T.ToSparseTensor(attr='edge_attr'))
+ PygGraphPropPredDataset(
+     dataset_name,
+     root,
+     pre_transform=T.ToSparseTensor(attr='edge_attr'),
+ )

world_size = torch.cuda.device_count()
print('Let\'s use', world_size, 'GPUs!')
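The entry point still instantiates the dataset once in the parent process, so download and pre-processing finish before the workers start and each rank only reads the cached copy. Also worth remembering when reading run(): mp.spawn prepends the process index, so run receives (rank, world_size, dataset_name, root) even though args lists only the last three. A tiny runnable sketch of that calling convention (worker and its payload are made up for illustration):

    import torch.multiprocessing as mp

    def worker(rank: int, world_size: int, payload: str) -> None:
        # mp.spawn supplies `rank` automatically; everything in `args` follows.
        print(f'rank {rank}/{world_size} received {payload!r}')

    if __name__ == '__main__':
        mp.spawn(worker, args=(2, 'hello'), nprocs=2, join=True)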
67 changes: 41 additions & 26 deletions examples/multi_gpu/distributed_sampling.py
@@ -1,4 +1,5 @@
import os
+ import os.path as osp
from math import ceil

import torch
@@ -7,17 +8,22 @@
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel
+ from tqdm import tqdm

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv


class SAGE(torch.nn.Module):
- def __init__(self, in_channels: int, hidden_channels: int,
-              out_channels: int, num_layers: int = 2):
+ def __init__(
+     self,
+     in_channels: int,
+     hidden_channels: int,
+     out_channels: int,
+     num_layers: int = 2,
+ ) -> None:
super().__init__()

self.convs = torch.nn.ModuleList()
self.convs.append(SAGEConv(in_channels, hidden_channels))
for _ in range(num_layers - 2):
@@ -34,20 +40,25 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:


@torch.no_grad()
- def test(loader, model, rank):
+ def test(
+     loader: NeighborLoader,
+     model: DistributedDataParallel,
+     rank: int,
+ ) -> Tensor:
model.eval()

- total_correct = total_examples = 0
+ total_correct = torch.tensor(0, dtype=torch.long, device=rank)
+ total_examples = 0
for i, batch in enumerate(loader):
out = model(batch.x, batch.edge_index.to(rank))
pred = out[:batch.batch_size].argmax(dim=-1)
y = batch.y[:batch.batch_size].to(rank)
- total_correct += int((pred == y).sum())
+ total_correct += (pred == y).sum()
total_examples += batch.batch_size
- return torch.tensor(total_correct / total_examples, device=rank)

+ return total_correct / total_examples


- def run(rank, world_size, dataset):
+ def run(rank: int, world_size: int, dataset: Reddit) -> None:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group('nccl', rank=rank, world_size=world_size)
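In test(), total_correct now lives as a long tensor on the local GPU, so the returned accuracy is already a device tensor that dist.all_reduce can consume directly, instead of a Python int re-wrapped at return time. A self-contained sketch of that accumulation (fake predictions, no process group required):

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    total_correct = torch.tensor(0, dtype=torch.long, device=device)
    total_examples = 0

    for _ in range(3):
        pred = torch.randint(0, 5, (8,), device=device)  # fake predictions
        y = torch.randint(0, 5, (8,), device=device)     # fake labels
        total_correct += (pred == y).sum()               # stays integer-exact
        total_examples += 8

    # Division promotes to float once, at the end; the resulting device
    # tensor is what the example later feeds to dist.all_reduce(...).
    acc = total_correct / total_examples
    print(float(acc))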
@@ -94,43 +105,47 @@ def run(rank, world_size, dataset):

for epoch in range(1, 21):
model.train()
- for batch in train_loader:
-     optimizer.zero_grad()
+ for batch in tqdm(
+     train_loader,
+     desc=f'Epoch {epoch:02d}',
+     disable=rank != 0,
+ ):
out = model(batch.x, batch.edge_index.to(rank))[:batch.batch_size]
loss = F.cross_entropy(out, batch.y[:batch.batch_size])
loss.backward()
optimizer.step()

- dist.barrier()
+ optimizer.zero_grad()

if rank == 0:
- print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
+ print(f'Epoch {epoch:02d}: Train loss: {loss:.4f}')

if epoch % 5 == 0:
train_acc = test(train_loader, model, rank)
val_acc = test(val_loader, model, rank)
test_acc = test(test_loader, model, rank)

- if world_size > 1:
-     dist.all_reduce(train_acc, op=dist.ReduceOp.SUM)
-     dist.all_reduce(train_acc, op=dist.ReduceOp.SUM)
-     dist.all_reduce(train_acc, op=dist.ReduceOp.SUM)
-     train_acc /= world_size
-     val_acc /= world_size
-     test_acc /= world_size
+ dist.all_reduce(train_acc, op=dist.ReduceOp.AVG)
+ dist.all_reduce(val_acc, op=dist.ReduceOp.AVG)
+ dist.all_reduce(test_acc, op=dist.ReduceOp.AVG)

if rank == 0:
- print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
-       f'Test: {test_acc:.4f}')
- dist.barrier()
+ print(f'Train acc: {train_acc:.4f}, '
+       f'Val acc: {val_acc:.4f}, '
+       f'Test acc: {test_acc:.4f}')

dist.destroy_process_group()


if __name__ == '__main__':
- dataset = Reddit('../../data/Reddit')
+ path = osp.join(
+     osp.dirname(__file__),
+     '..',
+     '..',
+     'data',
+     'Reddit',
+ )
+ dataset = Reddit(path)
world_size = torch.cuda.device_count()
print("Let's use", world_size, "GPUs!")
mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)
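One small quality-of-life change above: the training loop is now wrapped in tqdm with disable=rank != 0, so only the first rank draws a progress bar instead of every GPU interleaving its own. A minimal sketch (a plain range stands in for the loader):

    from tqdm import tqdm

    rank = 0  # supplied by mp.spawn in the real script
    for batch in tqdm(range(100), desc='Epoch 01', disable=rank != 0):
        pass  # training step goes here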
4 changes: 2 additions & 2 deletions torch_geometric/datasets/git_mol_dataset.py
@@ -187,7 +187,7 @@ def process(self) -> None:
img = self.img_transform(img).unsqueeze(0)
# graph
atom_features_list = []
- for atom in mol.GetAtoms():  # type: ignore
+ for atom in mol.GetAtoms():
atom_feature = [
safe_index(
allowable_features['possible_atomic_num_list'],
@@ -219,7 +219,7 @@ def process(self) -> None:

edges_list = []
edge_features_list = []
- for bond in mol.GetBonds():  # type: ignore
+ for bond in mol.GetBonds():
i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
edge_feature = [
safe_index(
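The removed `# type: ignore` comments on GetAtoms()/GetBonds() were only needed while RDKit lacked usable type stubs; the same cleanup recurs in molecule_gpt_dataset.py and utils/smiles.py below. For reference, a self-contained sketch of the iteration pattern these loops rely on (ethanol as a throwaway molecule):

    import torch
    from rdkit import Chem

    mol = Chem.MolFromSmiles('CCO')  # ethanol, three heavy atoms

    # One feature row per atom (here just the atomic number).
    x = torch.tensor([[atom.GetAtomicNum()] for atom in mol.GetAtoms()])

    # Each bond yields both directions of an undirected edge.
    rows, cols = [], []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        rows += [i, j]
        cols += [j, i]
    edge_index = torch.tensor([rows, cols])

    print(x.shape, edge_index.shape)  # torch.Size([3, 1]), torch.Size([2, 4])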
9 changes: 6 additions & 3 deletions torch_geometric/datasets/molecule_gpt_dataset.py
@@ -122,7 +122,10 @@ def clean_up_description(description: str) -> str:
return first_sentence


- def extract_name(name_raw: str, description: str) -> Tuple[str, str, str]:
+ def extract_name(
+     name_raw: str,
+     description: str,
+ ) -> Tuple[Optional[str], str, str]:
first_sentence = clean_up_description(description)

splitter = ' -- -- '
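Widening the first element of the return annotation to Optional[str] makes explicit that name extraction can fail; callers should branch on None instead of assuming a string. A hedged sketch of the call-site pattern (the stub only mimics the return shape, the real logic lives in extract_name):

    from typing import Optional, Tuple

    def extract_name_stub(raw: str) -> Tuple[Optional[str], str, str]:
        # Stand-in with the same return shape as extract_name.
        return (raw.strip() or None, '', '')

    name, _, _ = extract_name_stub('  ')
    if name is None:
        print('no name extracted; skip or log this record')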
@@ -446,12 +449,12 @@ def extract_one_SDF_file(block_id: int) -> None:

x: torch.Tensor = torch.tensor([
types[atom.GetSymbol()] if atom.GetSymbol() in types else 5
- for atom in m.GetAtoms()  # type: ignore
+ for atom in m.GetAtoms()
])
x = one_hot(x, num_classes=len(types), dtype=torch.float)

rows, cols, edge_types = [], [], []
- for bond in m.GetBonds():  # type: ignore
+ for bond in m.GetBonds():
i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
edge_types += [bonds[bond.GetBondType()]] * 2
rows += [i, j]
2 changes: 1 addition & 1 deletion torch_geometric/graphgym/loss.py
@@ -10,7 +10,7 @@ def compute_loss(pred, true):
Args:
pred (torch.tensor): Unnormalized prediction
- true (torch.tensor): Grou
+ true (torch.tensor): Ground truth labels
Returns: Loss, normalized prediction score
4 changes: 2 additions & 2 deletions torch_geometric/utils/smiles.py
@@ -91,7 +91,7 @@ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data':
assert isinstance(mol, Chem.Mol)

xs: List[List[int]] = []
- for atom in mol.GetAtoms():  # type: ignore
+ for atom in mol.GetAtoms():
row: List[int] = []
row.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
row.append(x_map['chirality'].index(str(atom.GetChiralTag())))
@@ -108,7 +108,7 @@ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data':
x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

edge_indices, edge_attrs = [], []
- for bond in mol.GetBonds():  # type: ignore
+ for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()

