diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f376abc9..fa73c2ad 100755 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: python3.10 + python: python3.11 repos: - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/topobenchmarkx/models/encoders/perceiver.py b/topobenchmarkx/models/encoders/perceiver.py index 458081d4..ff39fc4e 100644 --- a/topobenchmarkx/models/encoders/perceiver.py +++ b/topobenchmarkx/models/encoders/perceiver.py @@ -64,6 +64,17 @@ def cached_fn(*args, _cache=True, **kwargs): class PreNorm(nn.Module): + r"""Class to wrap together LayerNorm and a specified function. + + Parameters + ---------- + dim: int + Size of the dimension to normalize. + fn: torch.nn.Module + Function after LayerNorm. + context_dim: int + Size of the context to normalize. + """ def __init__(self, dim, fn, context_dim=None): super().__init__() self.fn = fn @@ -71,6 +82,20 @@ def __init__(self, dim, fn, context_dim=None): self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None def forward(self, x, **kwargs): + r"""Forward pass. + + Parameters + ---------- + x: torch.Tensor + Input tensor. + kwargs: dict + Dictionary of keyword arguments. + + Returns + ------- + torch.Tensor + Output tensor. + """ x = self.norm(x) if exists(self.norm_context): @@ -82,12 +107,28 @@ def forward(self, x, **kwargs): class GEGLU(nn.Module): + r"""GEGLU activation function.""" def forward(self, x): + r"""Forward pass. + + Parameters + ---------- + x: torch.Tensor + Input tensor. + """ x, gates = x.chunk(2, dim=-1) return x * F.gelu(gates) - class FeedForward(nn.Module): + r"""Feedforward network. + + Parameters + ---------- + dim: int + Size of the input dimension. + mult: int + Multiplier for the hidden dimension. + """ def __init__(self, dim, mult=4): super().__init__() self.net = nn.Sequential( @@ -95,10 +136,30 @@ def __init__(self, dim, mult=4): ) def forward(self, x): + r"""Forward pass. + + Parameters + ---------- + x: torch.Tensor + Input tensor. + """ return self.net(x) class Attention(nn.Module): + r"""Attention function. + + Parameters + ---------- + query_dim: int + Size of the query dimension. + context_dim: int + Size of the context dimension. + heads: int + Number of heads. + dim_head: int + Size for each head. + """ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64): super().__init__() inner_dim = dim_head * heads @@ -111,6 +172,22 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64): self.to_out = nn.Linear(inner_dim, query_dim) def forward(self, x, context=None, mask=None): + r"""Forward pass. + + Parameters + ---------- + x: torch.Tensor + Input tensor. + context: torch.Tensor + Context tensor. + mask: torch.Tensor + Mask for attention calculation purposes. + + Returns + ------- + torch.Tensor + Output tensor. + """ h = self.heads q = self.to_q(x) @@ -139,12 +216,35 @@ def forward(self, x, context=None, mask=None): class Perceiver(nn.Module): + r"""Perceiver model. + + Parameters + ---------- + depth: int + Number of layers to add to the model. + dim: int + Size of the input dimension. + num_latents: int + Number of latent vectors. + cross_heads: int + Number of heads for cross attention. + latent_heads: int + Number of heads for latent attention. + cross_dim_head: int + Size of the cross attention head. + latent_dim_head: int + Size of the latent attention head. + weight_tie_layers: bool + Whether to tie the weights of the layers. + decoder_ff: bool + Whether to use a feedforward network in the decoder. + """ def __init__( self, *, depth, dim, - logits_dim=None, + # logits_dim=None, num_latents=1, cross_heads=1, latent_heads=8, @@ -203,11 +303,22 @@ def __init__( PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None ) - self.to_logits = ( - nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity() - ) + # self.to_logits = ( + # nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity() + # ) def forward(self, data, mask=None, queries=None): + r"""Forward pass. + + Parameters + ---------- + data: torch.Tensor + Input tensor. + mask: torch.Tensor + Mask for attention calculation purposes. + queries: torch.Tensor + Queries tensor. + """ b, *_, device = *data.shape, data.device x = repeat(self.latents, "n d -> b n d", b=b) diff --git a/topobenchmarkx/models/readouts/default_readouts.py b/topobenchmarkx/models/readouts/default_readouts.py index 71a19a0a..9db97518 100755 --- a/topobenchmarkx/models/readouts/default_readouts.py +++ b/topobenchmarkx/models/readouts/default_readouts.py @@ -5,6 +5,19 @@ class GNNBatchReadOut(AbstractReadOut): + r"""Readout layer for GNNs that operates on the batch level. + + Parameters + ---------- + in_channels: int + Input dimension. + out_channels: int + Output dimension. + task_level: str + Task level, either "graph" or "node". If "graph", the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph. If "node", the readout layer will return the node embeddings. + pooling_type: str + Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding. + """ def __init__( self, in_channels: int, @@ -22,6 +35,18 @@ def __init__( self.pooling_type = pooling_type def forward(self, model_out: dict): + r"""Forward pass. + + Parameters + ---------- + model_out: dict + Dictionary containing the model output. + + Returns + ------- + dict + Dictionary containing the updated model output. Resulting key is "logits". + """ x = model_out["x_0"] batch = model_out["batch"] if self.task_level == "graph": diff --git a/topobenchmarkx/utils/config_resolvers.py b/topobenchmarkx/utils/config_resolvers.py index 37e5072b..e7d13a0a 100644 --- a/topobenchmarkx/utils/config_resolvers.py +++ b/topobenchmarkx/utils/config_resolvers.py @@ -1,4 +1,23 @@ def get_default_transform(data_domain, model): + r"""Get default transform for a given data domain and model. + + Parameters + ---------- + data_domain: str + Data domain. + model: str + Model name. Should be in the format "model_domain/name". + + Returns + ------- + str + Default transform. + + Raises + ------ + ValueError + If the combination of data_domain and model is invalid. + """ model_domain = model.split("/")[0] if data_domain == model_domain: return "identity" @@ -11,6 +30,25 @@ def get_default_transform(data_domain, model): def get_monitor_metric(task, loss): + r"""Get monitor metric for a given task and loss. + + Parameters + ---------- + task: str + Task, either "classification" or "regression". + loss: str + Name of the loss function. + + Returns + ------- + str + Monitor metric. + + Raises + ------ + ValueError + If the task is invalid. + """ if task == "classification": return "val/accuracy" elif task == "regression": @@ -20,6 +58,23 @@ def get_monitor_metric(task, loss): def get_monitor_mode(task): + r"""Get monitor mode for a given task. + + Parameters + ---------- + task: str + Task, either "classification" or "regression". + + Returns + ------- + str + Monitor mode, either "max" or "min". + + Raises + ------ + ValueError + If the task is invalid. + """ if task == "classification": return "max" elif task == "regression": @@ -29,7 +84,33 @@ def get_monitor_mode(task): def infer_in_channels(dataset): + r"""Infer the number of input channels for a given dataset. + + Parameters + ---------- + dataset: torch_geometric.data.Dataset + Input dataset. + + Returns + ------- + list + List with dimensions of the input channels. + """ def find_complex_lifting(dataset): + r"""Find if there is a complex lifting in the dataset. + + Parameters + ---------- + dataset: torch_geometric.data.Dataset + Input dataset. + + Returns + ------- + bool + True if there is a complex lifting, False otherwise. + str + Name of the complex lifting, if it exists. + """ if "transforms" not in dataset: return False, None complex_transforms = [ @@ -43,6 +124,20 @@ def find_complex_lifting(dataset): return False, None def check_for_type_feature_lifting(dataset, lifting): + r"""Check the type of feature lifting in the dataset. + + Parameters + ---------- + dataset: torch_geometric.data.Dataset + Input dataset. + lifting: str + Name of the complex lifting. + + Returns + ------- + str + Type of feature lifting. + """ lifting_params_keys = dataset.transforms[lifting].keys() if "feature_lifting" in lifting_params_keys: feature_lifting = dataset.transforms[lifting]["feature_lifting"]