Commit

added documentation
Coerulatus committed May 8, 2024
1 parent fe0e33f commit 81a098a
Showing 4 changed files with 237 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,5 +1,5 @@
default_language_version:
-  python: python3.10
+  python: python3.11

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
121 changes: 116 additions & 5 deletions topobenchmarkx/models/encoders/perceiver.py
@@ -64,13 +64,38 @@ def cached_fn(*args, _cache=True, **kwargs):


class PreNorm(nn.Module):
    r"""Wrap together a LayerNorm and a specified function.

    Parameters
    ----------
    dim: int
        Size of the dimension to normalize.
    fn: torch.nn.Module
        Function applied after the LayerNorm.
    context_dim: int, optional
        Size of the context to normalize, if a context is used.
    """

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        r"""Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor.
        kwargs: dict
            Dictionary of keyword arguments.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        x = self.norm(x)

        if exists(self.norm_context):
@@ -82,23 +107,59 @@ def forward(self, x, **kwargs):


class GEGLU(nn.Module):
    r"""GEGLU activation function."""

    def forward(self, x):
        r"""Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Gated output, with half the input size along the last dimension.
        """
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)
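
For reference, a minimal sketch of the gating above (shapes are illustrative; assumes the module path matches the file path of this diff):

import torch
from topobenchmarkx.models.encoders.perceiver import GEGLU

x = torch.randn(2, 10, 128)  # (batch, tokens, 2 * hidden)
out = GEGLU()(x)             # splits into two 64-dim halves, gates one with GELU
assert out.shape == (2, 10, 64)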


class FeedForward(nn.Module):
    r"""Feedforward network.

    Parameters
    ----------
    dim: int
        Size of the input dimension.
    mult: int
        Multiplier for the hidden dimension.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2), GEGLU(), nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        r"""Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        return self.net(x)
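
A hedged usage sketch combining PreNorm and FeedForward as defined above (assumes the truncated PreNorm.forward applies fn after the norm, as documented):

import torch

block = PreNorm(64, FeedForward(64, mult=4))
x = torch.randn(2, 10, 64)
out = block(x)  # LayerNorm -> Linear(64, 512) -> GEGLU -> Linear(256, 64)
assert out.shape == x.shape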


class Attention(nn.Module):
    r"""Attention function.

    Parameters
    ----------
    query_dim: int
        Size of the query dimension.
    context_dim: int, optional
        Size of the context dimension.
    heads: int
        Number of heads.
    dim_head: int
        Size of each head.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
@@ -111,6 +172,22 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, mask=None):
        r"""Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor.
        context: torch.Tensor, optional
            Context tensor.
        mask: torch.Tensor, optional
            Mask for the attention computation.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        h = self.heads

        q = self.to_q(x)
@@ -139,12 +216,35 @@ def forward(self, x, context=None, mask=None):


class Perceiver(nn.Module):
    r"""Perceiver model.

    Parameters
    ----------
    depth: int
        Number of layers to add to the model.
    dim: int
        Size of the input dimension.
    num_latents: int
        Number of latent vectors.
    cross_heads: int
        Number of heads for cross attention.
    latent_heads: int
        Number of heads for latent attention.
    cross_dim_head: int
        Size of the cross attention head.
    latent_dim_head: int
        Size of the latent attention head.
    weight_tie_layers: bool
        Whether to tie the weights of the layers.
    decoder_ff: bool
        Whether to use a feedforward network in the decoder.
    """

    def __init__(
        self,
        *,
        depth,
        dim,
-        logits_dim=None,
+        # logits_dim=None,
        num_latents=1,
        cross_heads=1,
        latent_heads=8,
@@ -203,11 +303,22 @@ def __init__(
            PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
        )

-        self.to_logits = (
-            nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity()
-        )
+        # self.to_logits = (
+        #     nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity()
+        # )

    def forward(self, data, mask=None, queries=None):
        r"""Forward pass.

        Parameters
        ----------
        data: torch.Tensor
            Input tensor.
        mask: torch.Tensor, optional
            Mask for the attention computation.
        queries: torch.Tensor, optional
            Queries tensor.
        """
        b, *_, device = *data.shape, data.device

        x = repeat(self.latents, "n d -> b n d", b=b)
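
A hedged usage sketch of the two modules above (argument values are illustrative; the forward bodies are truncated here, so the shape comments follow from the visible projections only):

import torch

# Cross attention from 64-dim queries to a 32-dim context.
attn = Attention(query_dim=64, context_dim=32, heads=4, dim_head=16)
out = attn(torch.randn(8, 10, 64), context=torch.randn(8, 50, 32))
# to_out projects back to query_dim, so out should be (8, 10, 64).

# Encoder, using only keyword arguments visible in the truncated signature.
encoder = Perceiver(depth=2, dim=64, num_latents=4)
latents = encoder(torch.randn(8, 100, 64))  # (batch, tokens, dim) in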
25 changes: 25 additions & 0 deletions topobenchmarkx/models/readouts/default_readouts.py
@@ -5,6 +5,19 @@


class GNNBatchReadOut(AbstractReadOut):
    r"""Readout layer for GNNs that operates on the batch level.

    Parameters
    ----------
    in_channels: int
        Input dimension.
    out_channels: int
        Output dimension.
    task_level: str
        Task level, either "graph" or "node". If "graph", the readout layer
        pools the node embeddings into a single graph embedding for each
        batched graph. If "node", it returns the node embeddings unchanged.
    pooling_type: str
        Pooling type, either "max", "sum", or "mean". Specifies the pooling
        operation used for the graph-level embedding.
    """

    def __init__(
        self,
        in_channels: int,
@@ -22,6 +35,18 @@ def __init__(
        self.pooling_type = pooling_type

    def forward(self, model_out: dict):
        r"""Forward pass.

        Parameters
        ----------
        model_out: dict
            Dictionary containing the model output.

        Returns
        -------
        dict
            Dictionary containing the updated model output. The result is
            stored under the key "logits".
        """
        x = model_out["x_0"]
        batch = model_out["batch"]
        if self.task_level == "graph":
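
A hedged sketch of the readout on a toy batch of two graphs (assumes the truncated __init__ builds a linear map from in_channels to out_channels, as the docstring suggests):

import torch

readout = GNNBatchReadOut(
    in_channels=16, out_channels=3, task_level="graph", pooling_type="mean"
)
model_out = {
    "x_0": torch.randn(5, 16),               # node embeddings
    "batch": torch.tensor([0, 0, 0, 1, 1]),  # graph index per node
}
model_out = readout(model_out)
print(model_out["logits"].shape)  # expected: (2, 3), one row per graph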
95 changes: 95 additions & 0 deletions topobenchmarkx/utils/config_resolvers.py
@@ -1,4 +1,23 @@
def get_default_transform(data_domain, model):
    r"""Get the default transform for a given data domain and model.

    Parameters
    ----------
    data_domain: str
        Data domain.
    model: str
        Model name, in the format "model_domain/name".

    Returns
    -------
    str
        Default transform.

    Raises
    ------
    ValueError
        If the combination of data_domain and model is invalid.
    """
    model_domain = model.split("/")[0]
    if data_domain == model_domain:
        return "identity"
@@ -11,6 +30,25 @@ def get_default_transform(data_domain, model):


def get_monitor_metric(task, loss):
    r"""Get the monitor metric for a given task and loss.

    Parameters
    ----------
    task: str
        Task, either "classification" or "regression".
    loss: str
        Name of the loss function.

    Returns
    -------
    str
        Monitor metric.

    Raises
    ------
    ValueError
        If the task is invalid.
    """
    if task == "classification":
        return "val/accuracy"
    elif task == "regression":
@@ -20,6 +58,23 @@ def get_monitor_metric(task, loss):


def get_monitor_mode(task):
    r"""Get the monitor mode for a given task.

    Parameters
    ----------
    task: str
        Task, either "classification" or "regression".

    Returns
    -------
    str
        Monitor mode, either "max" or "min".

    Raises
    ------
    ValueError
        If the task is invalid.
    """
    if task == "classification":
        return "max"
    elif task == "regression":
@@ -29,7 +84,33 @@ def get_monitor_mode(task):


def infer_in_channels(dataset):
    r"""Infer the number of input channels for a given dataset.

    Parameters
    ----------
    dataset: torch_geometric.data.Dataset
        Input dataset.

    Returns
    -------
    list
        List with the dimensions of the input channels.
    """

    def find_complex_lifting(dataset):
        r"""Find whether there is a complex lifting in the dataset.

        Parameters
        ----------
        dataset: torch_geometric.data.Dataset
            Input dataset.

        Returns
        -------
        bool
            True if there is a complex lifting, False otherwise.
        str
            Name of the complex lifting, if it exists.
        """
        if "transforms" not in dataset:
            return False, None
        complex_transforms = [
@@ -43,6 +124,20 @@ def find_complex_lifting(dataset):
        return False, None

    def check_for_type_feature_lifting(dataset, lifting):
        r"""Check the type of feature lifting in the dataset.

        Parameters
        ----------
        dataset: torch_geometric.data.Dataset
            Input dataset.
        lifting: str
            Name of the complex lifting.

        Returns
        -------
        str
            Type of feature lifting.
        """
        lifting_params_keys = dataset.transforms[lifting].keys()
        if "feature_lifting" in lifting_params_keys:
            feature_lifting = dataset.transforms[lifting]["feature_lifting"]
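
A hedged sketch of the resolvers above, exercising only the branches visible in this diff (the model name "graph/gcn" and loss "cross_entropy" are illustrative):

from topobenchmarkx.utils.config_resolvers import (
    get_default_transform,
    get_monitor_metric,
    get_monitor_mode,
)

# Matching data and model domains fall back to the identity transform.
assert get_default_transform("graph", "graph/gcn") == "identity"

# Classification monitors validation accuracy and maximizes it.
assert get_monitor_metric("classification", "cross_entropy") == "val/accuracy"
assert get_monitor_mode("classification") == "max"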
