Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
levtelyatnikov committed May 8, 2024
2 parents 8f732c8 + 2b69b20 commit f246f3d
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 2 deletions.
5 changes: 4 additions & 1 deletion topobenchmarkx/data/dataloader_fullbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@


class MyData(Data):
"""
Data object class that overwrites some methods from torch_geometric.data.Data so that not only sparse matrices with adj in the name can work with the torch_geometric dataloaders.
"""
def is_valid(self, string):
valid_names = ["adj", "incidence", "laplacian"]
for name in valid_names:
Expand Down Expand Up @@ -47,7 +50,7 @@ def collate_fn(batch):
args:
batch - list of (tensor, label)
reutrn:
return:
xs - a tensor of all examples in 'batch' after padding
ys - a LongTensor of all labels in batch
"""
Expand Down
53 changes: 52 additions & 1 deletion topobenchmarkx/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,79 @@


class CustomDataset(torch_geometric.data.Dataset):
r"""Custom dataset to return all the values added to the dataset object.
Parameters
----------
data_lst: list
List of torch_geometric.data.Data objects .
"""
def __init__(self, data_lst):
super().__init__()
self.data_lst = data_lst

def get(self, idx):
r"""Get data object from data list.
Parameters
----------
idx: int
Index of the data object to get.
Returns
-------
tuple
tuple containing a list of all the values for the data and the keys corresponding to the values.
"""
data = self.data_lst[idx]
keys = list(data.keys())
return ([data[key] for key in keys], keys)

def len(self):
r"""Return length of the dataset.
Returns
-------
int
Length of the dataset.
"""
return len(self.data_lst)


class TorchGeometricDataset(torch_geometric.data.Dataset):
r"""Dataset to work with a list of data objects.
Parameters
----------
data_lst: list
List of torch_geometric.data.Data objects .
"""
def __init__(self, data_lst):
super().__init__()
self.data_lst = data_lst

def get(self, idx):
data = self.data_lst[idx]
r"""Get data object from data list.
Parameters
----------
idx: int
Index of the data object to get.
Returns
-------
torch_geometric.data.Data
Data object of corresponding index.
"""
data = self.data_lst[idx]
return data

def len(self):
r"""Return length of the dataset.
Returns
-------
int
Length of the dataset.
"""
return len(self.data_lst)
17 changes: 17 additions & 0 deletions topobenchmarkx/io/load/us_county_demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,23 @@


def load_us_county_demos(path, year=2012, y_col="Election"):
r"""Load US County Demos dataset
Parameters
----------
path: str
Path to the dataset.
year: int
Year to load the features.
y_col: str
Column to use as label.
Returns
-------
torch_geometric.data.Data
Data object of the graph for the US County Demos dataset.
"""

edges_df = pd.read_csv(f"{path}/county_graph.csv")
stat = pd.read_csv(f"{path}/county_stats_{year}.csv", encoding="ISO-8859-1")

Expand Down
78 changes: 78 additions & 0 deletions topobenchmarkx/models/encoders/default_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@


class BaseEncoder(torch.nn.Module):
r"""Encoder class that uses two linear layers with GraphNorm, Relu activation function, and dropout between the two layers.
Parameters
----------
in_channels: int
Dimension of input features.
out_channels: int
Dimensions of output features.
dropout: float
Percentage of channels to discard between the two linear layers.
"""
def __init__(self, in_channels, out_channels, dropout=0):
super().__init__()
self.linear1 = torch.nn.Linear(in_channels, out_channels)
Expand All @@ -15,6 +26,21 @@ def __init__(self, in_channels, out_channels, dropout=0):
self.dropout = torch.nn.Dropout(dropout)

def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
r"""
Forward pass
Parameters
----------
x: torch.Tensor
Input tensor of dimensions [N, in_channels].
batch: torch.Tensor
The batch vector which assigns each element to a specific example.
Returns
-------
torch.Tensor
Output tensor of shape [N, out_channels].
"""
x = self.linear1(x)
x = self.BN(x, batch=batch) if batch.shape[0] > 0 else self.BN(x)
x = self.dropout(self.relu(x))
Expand All @@ -23,6 +49,19 @@ def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:


class BaseFeatureEncoder(AbstractInitFeaturesEncoder):
r"""Encoder class to apply BaseEncoder to the features of higher order structures.
Parameters
----------
in_channels: list(int)
Input dimensions for the features.
out_channels: list(int)
Output dimensions for the features.
proj_dropout: float
Dropout for the BaseEncoders.
selected_dimensions: list(int)
List of indexes to apply the BaseEncoders to.
"""
def __init__(
self, in_channels, out_channels, proj_dropout=0, selected_dimensions=None
):
Expand All @@ -44,6 +83,19 @@ def __init__(
)

def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:
r"""
Forward pass
Parameters
----------
data: torch_geometric.data.Data
Input data object which should contain x_{i} features for each i in the selected_dimensions.
Returns
-------
torch_geometric.data.Data
Output data object.
"""
if not hasattr(data, "x_0"):
data.x_0 = data.x

Expand All @@ -58,6 +110,19 @@ def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:


class SetFeatureEncoder(AbstractInitFeaturesEncoder):
r"""Encoder class to apply BaseEncoder to the node features and Perceiver to the features of higher order structures.
Parameters
----------
in_channels: list(int)
Input dimensions for the features.
out_channels: list(int)
Output dimensions for the features.
proj_dropout: float
Dropout for the BaseEncoders.
selected_dimensions: list(int)
List of indexes to apply the BaseEncoders to.
"""
def __init__(
self, in_channels, out_channels, proj_dropout=0, selected_dimensions=None
):
Expand Down Expand Up @@ -92,6 +157,19 @@ def __init__(
)

def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:
r"""
Forward pass
Parameters
----------
data: torch_geometric.data.Data
Input data object which should contain x_{i} features for each i in the selected_dimensions.
Returns
-------
torch_geometric.data.Data
Output data object.
"""
if not hasattr(data, "x_0"):
data.x_0 = data.x

Expand Down

0 comments on commit f246f3d

Please sign in to comment.