Commit 8ed265d

fixed sccn

levtelyatnikov committed May 13, 2024 · 1 parent a9cc2ad
Showing 67 changed files with 1,406 additions and 1,023 deletions.
.pre-commit-config.yaml (25 changes: 10 additions & 15 deletions)

@@ -15,21 +15,16 @@ repos:
       - id: check-added-large-files
         args:
           - --maxkb=2048
-      # - id: trailing-whitespace
       - id: requirements-txt-fixer

-  - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.4.4
-    hooks:
-      - id: ruff
-        #types_or: [ python, pyi, jupyter ]
-        #types_or: [ python, pyi ]
-        args: [ --fix ]
-      - id: ruff-format
-        #types_or: [ python, pyi, jupyter ]
-        #types_or: [ python, pyi ]
+  # - repo: https://github.com/astral-sh/ruff-pre-commit
+  #   rev: v0.4.4
+  #   hooks:
+  #     - id: ruff
+  #       args: [ --fix ]
+  #     - id: ruff-format

-  - repo: https://github.com/numpy/numpydoc
-    rev: v1.6.0
-    hooks:
-      - id: numpydoc-validation
+  # - repo: https://github.com/numpy/numpydoc
+  #   rev: v1.6.0
+  #   hooks:
+  #     - id: numpydoc-validation
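With the ruff and numpydoc hooks commented out, only the basic checks (the large-file guard and requirements-txt-fixer) still run; they can be exercised locally with `pre-commit run --all-files`, and the disabled hooks can be restored later by uncommenting the blocks above.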
configs/model/simplicial/san.yaml (1 change: 0 additions & 1 deletion)

@@ -36,7 +36,6 @@ head_model:
   out_channels: ${dataset.parameters.num_classes}
   pooling_type: sum

-
 loss:
   _target_: topobenchmarkx.models.losses.loss.DefaultLoss
   task: ${dataset.parameters.task}
configs/model/simplicial/sccn.yaml (4 changes: 2 additions & 2 deletions)

@@ -8,12 +8,12 @@ feature_encoder:
 backbone:
   _target_: topomodelx.nn.simplicial.sccn.SCCN
   channels: ${model.feature_encoder.out_channels}
-  max_rank: 1
+  max_rank: 2
   n_layers: 1
   update_func: "sigmoid"

 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNNWrapper
+  _target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNWrapper
   _partial_: true
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
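This is the fix named in the commit message: the SCCN backbone is now paired with SCCNWrapper rather than the wrapper for the (different) SCCNN model, and max_rank: 2 lets the network carry features on simplices up to rank 2. A minimal sketch of what the config instantiates, assuming Hydra passes the keys above as keyword arguments to the _target_ class; channels=64 is an illustrative value, not from the config:

# Hedged sketch of the backbone built by configs/model/simplicial/sccn.yaml.
# Assumes topomodelx.nn.simplicial.sccn.SCCN accepts exactly these keywords;
# channels=64 stands in for ${model.feature_encoder.out_channels}.
from topomodelx.nn.simplicial.sccn import SCCN

backbone = SCCN(
    channels=64,            # feature width shared across ranks
    max_rank=2,             # ranks 0..2: nodes, edges, triangles
    n_layers=1,
    update_func="sigmoid",
)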
configs/train.yaml (2 changes: 1 addition & 1 deletion)

@@ -5,7 +5,7 @@
 defaults:
   - _self_
   - dataset: PROTEINS_TU #us_country_demos
-  - model: hypergraph/allsettransformer #hypergraph/unignn2 #allsettransformer
+  - model: simplicial/sccn #hypergraph/unignn2 #allsettransformer
   - evaluator: default
   - callbacks: default
   - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
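Because these are Hydra defaults, the same model switch works per run without editing the file, e.g. `python train.py model=simplicial/sccn dataset=PROTEINS_TU`, following the override pattern the logger comment above already shows.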
custom_models/cell/cin.py (40 changes: 22 additions & 18 deletions)

@@ -1,10 +1,10 @@
 """CWN class."""

 import torch
-import torch.nn.functional as F
-from topomodelx.nn.cell.cwn_layer import CWNLayer
 import torch.nn as nn
+import torch.nn.functional as F
 from topomodelx.base.conv import Conv
+from topomodelx.nn.cell.cwn_layer import CWNLayer
 from torch_geometric.nn.models import MLP


@@ -65,7 +65,8 @@ def forward(
         neighborhood_2_to_1,
         neighborhood_0_to_1,
     ):
-        """Forward computation through projection, convolutions, linear layers and average pooling.
+        """Forward computation through projection, convolutions, linear layers
+        and average pooling.

         Parameters
         ----------
@@ -192,15 +193,21 @@ def __init__(
         self.conv_1_to_1 = (
             conv_1_to_1
             if conv_1_to_1 is not None
-            else _CWNDefaultFirstConv(in_channels_1, in_channels_2, out_channels)
+            else _CWNDefaultFirstConv(
+                in_channels_1, in_channels_2, out_channels
+            )
         )
         self.conv_0_to_1 = (
             conv_0_to_1
             if conv_0_to_1 is not None
-            else _CWNDefaultSecondConv(in_channels_0, in_channels_1, out_channels)
+            else _CWNDefaultSecondConv(
+                in_channels_0, in_channels_1, out_channels
+            )
         )
         self.aggregate_fn = (
-            aggregate_fn if aggregate_fn is not None else _CWNDefaultAggregate()
+            aggregate_fn
+            if aggregate_fn is not None
+            else _CWNDefaultAggregate()
         )
         self.update_fn = (
             update_fn
@@ -325,11 +332,10 @@ def forward(


 class _CWNDefaultFirstConv(nn.Module):
-    r"""
-    Default implementation of the first convolutional step in CWNLayer.
+    r"""Default implementation of the first convolutional step in CWNLayer.

-    The self.forward method of this module must be treated as
-    a protocol for the first convolutional step in CWN layer.
+    The self.forward method of this module must be treated as a protocol for
+    the first convolutional step in CWN layer.
     """

     def __init__(
@@ -383,11 +389,10 @@ def forward(self, x_1, x_2, neighborhood_1_to_1, neighborhood_2_to_1):


 class _CWNDefaultSecondConv(nn.Module):
-    r"""
-    Default implementation of the second convolutional step in CWNLayer.
+    r"""Default implementation of the second convolutional step in CWNLayer.

-    The self.forward method of this module must be treated as
-    a protocol for the second convolutional step in CWN layer.
+    The self.forward method of this module must be treated as a protocol for
+    the second convolutional step in CWN layer.
     """

     def __init__(self, in_channels_0, out_channels) -> None:
@@ -417,11 +422,10 @@ def forward(self, x_0, neighborhood_0_to_1):


 class _CWNDefaultAggregate(nn.Module):
-    r"""
-    Default implementation of an aggregation step in CWNLayer.
+    r"""Default implementation of an aggregation step in CWNLayer.

-    The self.forward method of this module must be treated as
-    a protocol for the aggregation step in CWN layer.
+    The self.forward method of this module must be treated as a protocol for
+    the aggregation step in CWN layer.
     """

     def __init__(self) -> None:
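The reworded docstrings above describe self.forward as a protocol: the layer's __init__ accepts any module with the same call signature in place of the defaults (conv_1_to_1, conv_0_to_1, aggregate_fn). A hedged sketch of a custom drop-in for the first convolutional step; the class name and body are illustrative, not the library's implementation:

# Hypothetical replacement for _CWNDefaultFirstConv. Only the forward
# signature matters for the protocol; this body uses one linear map per
# rank followed by sparse neighborhood aggregation as a plausible choice.
import torch
import torch.nn as nn


class MyFirstConv(nn.Module):
    def __init__(self, in_channels_1, in_channels_2, out_channels):
        super().__init__()
        self.lin_1 = nn.Linear(in_channels_1, out_channels)
        self.lin_2 = nn.Linear(in_channels_2, out_channels)

    def forward(self, x_1, x_2, neighborhood_1_to_1, neighborhood_2_to_1):
        # Messages from edges via the 1-to-1 neighborhood matrix and from
        # 2-cells via the 2-to-1 incidence matrix, summed per edge.
        m_1 = torch.sparse.mm(neighborhood_1_to_1, self.lin_1(x_1))
        m_2 = torch.sparse.mm(neighborhood_2_to_1, self.lin_2(x_2))
        return m_1 + m_2

Passed as conv_1_to_1=MyFirstConv(...) when the layer is constructed, it would replace the default without any other change.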
custom_models/hypergraph/edgnn.py (39 changes: 28 additions & 11 deletions)

@@ -46,7 +46,9 @@ def __init__(
                 self.lins.append(nn.Linear(in_channels, hidden_channels))
                 self.normalizations.append(nn.BatchNorm1d(hidden_channels))
                 for _ in range(num_layers - 2):
-                    self.lins.append(nn.Linear(hidden_channels, hidden_channels))
+                    self.lins.append(
+                        nn.Linear(hidden_channels, hidden_channels)
+                    )
                     self.normalizations.append(nn.BatchNorm1d(hidden_channels))
                 self.lins.append(nn.Linear(hidden_channels, out_channels))
         elif Normalization == "ln":
@@ -65,7 +67,9 @@ def __init__(
                 self.lins.append(nn.Linear(in_channels, hidden_channels))
                 self.normalizations.append(nn.LayerNorm(hidden_channels))
                 for _ in range(num_layers - 2):
-                    self.lins.append(nn.Linear(hidden_channels, hidden_channels))
+                    self.lins.append(
+                        nn.Linear(hidden_channels, hidden_channels)
+                    )
                     self.normalizations.append(nn.LayerNorm(hidden_channels))
                 self.lins.append(nn.Linear(hidden_channels, out_channels))
         else:
@@ -78,7 +82,9 @@ def __init__(
                 self.lins.append(nn.Linear(in_channels, hidden_channels))
                 self.normalizations.append(nn.Identity())
                 for _ in range(num_layers - 2):
-                    self.lins.append(nn.Linear(hidden_channels, hidden_channels))
+                    self.lins.append(
+                        nn.Linear(hidden_channels, hidden_channels)
+                    )
                     self.normalizations.append(nn.Identity())
                 self.lins.append(nn.Linear(hidden_channels, out_channels))

@@ -88,7 +94,7 @@ def reset_parameters(self):
         for lin in self.lins:
             lin.reset_parameters()
         for normalization in self.normalizations:
-            if not (normalization.__class__.__name__ == "Identity"):
+            if normalization.__class__.__name__ != "Identity":
                 normalization.reset_parameters()

     def forward(self, x):
@@ -245,7 +251,9 @@ def forward(self, X, vertex, edges, X0):


 class JumpLinkConv(nn.Module):
-    def __init__(self, in_features, out_features, mlp_layers=2, aggr="add", alpha=0.5):
+    def __init__(
+        self, in_features, out_features, mlp_layers=2, aggr="add", alpha=0.5
+    ):
         super().__init__()
         self.W = MLP(
             in_features,
@@ -339,7 +347,10 @@ def forward(self, X, vertex, edges, X0):
         ) # [E, C], reduce is 'mean' here as default

         deg_e = torch_scatter.scatter(
-            torch.ones(Xve.shape[0], device=Xve.device), edges, dim=-2, reduce="sum"
+            torch.ones(Xve.shape[0], device=Xve.device),
+            edges,
+            dim=-2,
+            reduce="sum",
         )
         Xe = torch.cat([Xe, torch.log(deg_e)[..., None]], -1)

@@ -350,7 +361,10 @@ def forward(self, X, vertex, edges, X0):
         ) # [N, C]

         deg_v = torch_scatter.scatter(
-            torch.ones(Xev.shape[0], device=Xev.device), vertex, dim=-2, reduce="sum"
+            torch.ones(Xev.shape[0], device=Xev.device),
+            vertex,
+            dim=-2,
+            reduce="sum",
         )
         X = self.W3(torch.cat([Xv, X, X0, torch.log(deg_v)[..., None]], -1))

@@ -374,7 +388,7 @@ def __init__(
         normalization="None",
         AllSet_input_norm=False,
     ):
-        """EDGNN
+        """EDGNN.

         Args:
             num_features (int): number of input features
@@ -390,7 +404,6 @@ def __init__(
             aggregate (str, optional): aggregation method. Defaults to 'add'.
             normalization (str, optional): normalization method. Defaults to 'None'.
             AllSet_input_norm (bool, optional): whether to normalize input features. Defaults to False.
-
         """
         super().__init__()
         act = {"Id": nn.Identity(), "relu": nn.ReLU(), "prelu": nn.PReLU()}
@@ -402,8 +415,12 @@ def __init__(
         self.hidden_channels = self.in_channels

         self.mlp1_layers = MLP_num_layers
-        self.mlp2_layers = MLP_num_layers if MLP2_num_layers < 0 else MLP2_num_layers
-        self.mlp3_layers = MLP_num_layers if MLP3_num_layers < 0 else MLP3_num_layers
+        self.mlp2_layers = (
+            MLP_num_layers if MLP2_num_layers < 0 else MLP2_num_layers
+        )
+        self.mlp3_layers = (
+            MLP_num_layers if MLP3_num_layers < 0 else MLP3_num_layers
+        )
         self.nlayer = All_num_layers
         self.edconv_type = edconv_type
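One behavioral detail preserved through the reformat above: MLP2_num_layers and MLP3_num_layers fall back to the shared MLP_num_layers when negative. A minimal illustration of that rule; resolve_depth is a hypothetical helper, not part of edgnn.py:

# Hypothetical helper mirroring the self.mlp2_layers / self.mlp3_layers
# logic: a negative per-block depth means "reuse the shared MLP depth".
def resolve_depth(shared_depth: int, specific_depth: int) -> int:
    return shared_depth if specific_depth < 0 else specific_depth


assert resolve_depth(2, -1) == 2  # negative -> fall back to shared depth
assert resolve_depth(2, 3) == 3   # non-negative -> use as given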