
Commit e0147f5
Started CI
levtelyatnikov committed May 6, 2024
1 parent 8621343 commit e0147f5
Showing 48 changed files with 4,778 additions and 1,773 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/lint.yml
@@ -0,0 +1,14 @@
+# name: Linting
+
+# on:
+#   push:
+#     branches: [ main,github-actions-test ]
+#   pull_request:
+#     branches: [ main ]
+
+# jobs:
+#   ruff:
+#     runs-on: ubuntu-latest
+#     steps:
+#       - uses: actions/checkout@v3
+#       - uses: chartboost/ruff-action@v1
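Note that the workflow is committed fully commented out, so the Ruff lint job stays inert until these lines are uncommented.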
38 changes: 15 additions & 23 deletions .pre-commit-config.yaml
@@ -1,10 +1,11 @@
-default_language_version :
-  python : python3
+default_language_version:
+  python: python3.10

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.0.1
    hooks:
-      - id: check-byte-order-marker
+      - id: fix-byte-order-marker
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: check-yaml
@@ -17,25 +18,16 @@ repos:
      - id: trailing-whitespace
      - id: requirements-txt-fixer

-  - repo: https://github.com/psf/black
-    rev: 23.3.0
-    hooks:
-      - id: black-jupyter
-
-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
-    hooks:
-      - id : isort
-        args : ["--profile=black", "--filter-files"]
-
-  - repo: https://github.com/asottile/blacken-docs
-    rev: 1.13.0
-    hooks:
-      - id: blacken-docs
-        additional_dependencies: [black==23.3.0]
+  # - repo: https://github.com/astral-sh/ruff-pre-commit
+  #   rev: v0.1.14
+  #   hooks:
+  #     - id: ruff
+  #       types_or: [ python, pyi, jupyter ]
+  #       args: [ --fix ]
+  #     - id: ruff-format
+  #       types_or: [ python, pyi, jupyter ]
+
-  # - repo: https://github.com/pycqa/flake8
-  #   rev: 6.0.0
+  # - repo: https://github.com/numpy/numpydoc
+  #   rev: v1.6.0
  #   hooks:
-  #     - id: flake8
-  #       additional_dependencies: [flake8-docstrings, Flake8-pyproject]
+  #     - id: numpydoc-validation
2 changes: 1 addition & 1 deletion configs/train.yaml
@@ -5,7 +5,7 @@
defaults:
  - _self_
  - dataset: ZINC
-  - model: hypergraph/unignn2 #simplicial/scn #
+  - model: cell/cwn #hypergraph/unignn2 #
  - evaluator: default
  - callbacks: default
  - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
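Because the config is composed by Hydra, the chosen model can also be overridden at launch without editing this file, e.g. `python train.py model=hypergraph/unignn2` (the same pattern as the logger comment above).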
42 changes: 22 additions & 20 deletions custom_models/cell/cin.py
@@ -2,7 +2,6 @@

import torch
import torch.nn.functional as F
-
from topomodelx.nn.cell.cwn_layer import CWNLayer


@@ -104,18 +103,18 @@ def forward(
        )

        return x_0, x_1, x_2


#### LAYERs ####

"""Implementation of CWN layer from Bodnar et al.: Weisfeiler and Lehman Go Cellular: CW Networks."""

import torch.nn as nn
import torch.nn.functional as F

from topomodelx.base.conv import Conv
from torch_geometric.nn.models import MLP


class CWNLayer(nn.Module):
    r"""Layer of a CW Network (CWN).
@@ -188,11 +187,10 @@ def __init__(
        conv_0_to_1=None,
        aggregate_fn=None,
        update_fn=None,
-        eps=0.01
+        eps=0.01,
    ) -> None:
        super().__init__()

        self.conv_1_to_1 = (
            conv_1_to_1
            if conv_1_to_1 is not None
@@ -213,15 +211,15 @@ def __init__(
        )
        self.mlp_arrow = MLP(
            [in_channels_1, in_channels_1, in_channels_1],
-            act='relu',
+            act="relu",
            act_first=False,
            norm=torch.nn.BatchNorm1d(out_channels),
            # norm_kwargs=self.norm_kwargs,
        )

        self.mlp = MLP(
            [in_channels_2 + in_channels_2, out_channels, out_channels],
-            act='relu',
+            act="relu",
            act_first=False,
            norm=torch.nn.BatchNorm1d(out_channels),
            # norm_kwargs=self.norm_kwargs,
@@ -316,11 +314,15 @@ def forward(
        )
        x_convolved_1_to_1 = self.mlp_arrow(x_convolved_1_to_1)

-        #
-        x_convolved_0_to_1 = (1 + self.eps) * x_1 + self.conv_0_to_1(x_0, neighborhood_0_to_1)
+        #
+        x_convolved_0_to_1 = (1 + self.eps) * x_1 + self.conv_0_to_1(
+            x_0, neighborhood_0_to_1
+        )

-        x_aggregated = self.mlp(torch.cat([x_convolved_0_to_1, x_convolved_1_to_1], dim=-1))
-        #x_aggregated = self.aggregate_fn(x_convolved_1_to_1, x_convolved_0_to_1)
+        x_aggregated = self.mlp(
+            torch.cat([x_convolved_0_to_1, x_convolved_1_to_1], dim=-1)
+        )
+        # x_aggregated = self.aggregate_fn(x_convolved_1_to_1, x_convolved_0_to_1)
        return self.update_fn(x_aggregated, x_1)


@@ -332,7 +334,9 @@ class _CWNDefaultFirstConv(nn.Module):
    a protocol for the first convolutional step in CWN layer.
    """

-    def __init__(self, in_channels_1, in_channels_2, out_channels, eps: float = 0.) -> None:
+    def __init__(
+        self, in_channels_1, in_channels_2, out_channels, eps: float = 0.0
+    ) -> None:
        super().__init__()
        self.conv_1_to_1 = Conv(
            in_channels_1, out_channels, aggr_norm=False, update_func=None
@@ -343,14 +347,14 @@ def __init__(self, in_channels_1, in_channels_2, out_channels, eps: float = 0.)

        self.mlp = MLP(
            [in_channels_1 + in_channels_2, out_channels, out_channels],
-            act='relu',
+            act="relu",
            act_first=False,
            norm=torch.nn.BatchNorm1d(out_channels),
            # norm_kwargs=self.norm_kwargs,
        )

        self.eps = torch.nn.Parameter(torch.Tensor([eps]))

    def forward(self, x_1, x_2, neighborhood_1_to_1, neighborhood_2_to_1):
        r"""Forward pass.
@@ -370,12 +374,10 @@ def forward(self, x_1, x_2, neighborhood_1_to_1, neighborhood_2_to_1):
        torch.Tensor, shape = (n_{r}_cells, out_channels)
            Updated representations on the r-cells.
        """
-        #
-        #
        x_up = F.elu(self.conv_1_to_1(x_1, neighborhood_1_to_1))
        x_up = (1 + self.eps) * x_1 + x_up

        x_coboundary = F.elu(self.conv_2_to_1(x_2, neighborhood_2_to_1))
        x_coboundary = (1 + self.eps) * x_1 + x_coboundary

@@ -467,4 +469,4 @@ def forward(self, x, x_prev=None):
        torch.Tensor, shape = (n_{r}_cells, out_channels)
            Updated representations on the r-cells.
        """
        return F.elu(self.transform(x))
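The update implemented in `CWNLayer.forward` above is CIN-style: a GIN-like `(1 + eps)` skip term is added to the messages arriving from the 0-cells, the two message blocks are concatenated, and an MLP produces the new 1-cell features. A minimal dense sketch of that aggregation, with hypothetical sizes and stand-in modules in place of the sparse neighborhood matrices and the `Conv`/`MLP` modules from the diff:

import torch

# Hypothetical sizes: 4 nodes (0-cells), 5 edges (1-cells), 8 channels.
n_nodes, n_edges, channels = 4, 5, 8
eps = 0.01

x_0 = torch.randn(n_nodes, channels)                 # 0-cell features
x_1 = torch.randn(n_edges, channels)                 # 1-cell features
neighborhood_0_to_1 = torch.randn(n_edges, n_nodes)  # dense stand-in for the node-to-edge incidence

def conv_0_to_1(x, neighborhood):
    # Stand-in for the layer's Conv: a single (sparse) matmul in the real code.
    return neighborhood @ x

mlp_arrow = torch.nn.Linear(channels, channels)      # stand-in for self.mlp_arrow
mlp = torch.nn.Linear(2 * channels, channels)        # stand-in for self.mlp

x_convolved_1_to_1 = mlp_arrow(x_1)                  # upper-adjacency messages (simplified)
x_convolved_0_to_1 = (1 + eps) * x_1 + conv_0_to_1(x_0, neighborhood_0_to_1)

x_aggregated = mlp(torch.cat([x_convolved_0_to_1, x_convolved_1_to_1], dim=-1))
print(x_aggregated.shape)  # torch.Size([5, 8])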
64 changes: 34 additions & 30 deletions custom_models/simplicial/sccnn.py
@@ -1,6 +1,6 @@
"""SCCNN implementation for complex classification."""
import torch

import torch
from topomodelx.nn.simplicial.sccnn_layer import SCCNNLayer


@@ -93,8 +93,9 @@ def forward(self, x_all, laplacian_all, incidence_all):
            x_all = layer(x_all, laplacian_all, incidence_all)

        return x_all


# Layer
"""Simplicial Complex Convolutional Neural Network Layer."""
import torch
from torch.nn.parameter import Parameter
@@ -223,7 +224,7 @@ def __init__(
            torch.Tensor(
                self.in_channels_1,
                self.out_channels_1,
-                6*conv_order + 3,
+                6 * conv_order + 3,
            )
        )

@@ -236,16 +237,17 @@
                torch.Tensor(
                    self.in_channels_2,
                    self.out_channels_2,
-                    4*conv_order + 2, # in the future for arbitrary sc_order we should have this 6*conv_order + 3,
+                    4 * conv_order
+                    + 2,  # in the future for arbitrary sc_order we should have this 6*conv_order + 3,
                )
            )

        elif sc_order == 2:
            self.weight_2 = Parameter(
                torch.Tensor(
                    self.in_channels_2,
                    self.out_channels_2,
-                    4*conv_order + 2,
+                    4 * conv_order + 2,
                )
            )

@@ -410,33 +412,35 @@ def forward(self, x_all, laplacian_all, incidence_all):
"""
Convolution in the node space
"""
#-----------Logic to obtain update for 0-cells --------
# -----------Logic to obtain update for 0-cells --------
# x_identity_0 = torch.unsqueeze(identity_0 @ x_0, 2)
# x_0_to_0 = self.chebyshev_conv(laplacian_0, self.conv_order, x_0)
# x_0_to_0 = torch.cat((x_identity_0, x_0_to_0), 2)

x_0_laplacian = self.chebyshev_conv(laplacian_0, self.conv_order, x_0)
x_0_to_0 = torch.cat([x_0.unsqueeze(2), x_0_laplacian], dim=2)
#-------------------
# -------------------

# x_1_to_0 = torch.mm(b1, x_1)
# x_1_to_0_identity = torch.unsqueeze(identity_0 @ x_1_to_0, 2)
# x_1_to_0 = self.chebyshev_conv(laplacian_0, self.conv_order, x_1_to_0)
# x_1_to_0 = torch.cat((x_1_to_0_identity, x_1_to_0), 2)

x_1_to_0_upper = torch.mm(b1, x_1)
x_1_to_0_laplacian = self.chebyshev_conv(laplacian_0, self.conv_order, x_1_to_0_upper)
x_1_to_0_laplacian = self.chebyshev_conv(
laplacian_0, self.conv_order, x_1_to_0_upper
)
x_1_to_0 = torch.cat([x_1_to_0_upper.unsqueeze(2), x_1_to_0_laplacian], dim=2)
#-------------------
# -------------------

x_0_all = torch.cat((x_0_to_0, x_1_to_0), 2)

#-------------------
# -------------------
"""
Convolution in the edge space
"""

#-----------Logic to obtain update for 1-cells --------
# -----------Logic to obtain update for 1-cells --------
# x_identity_1 = torch.unsqueeze(identity_1 @ x_1, 2)
# x_1_down = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_1)
# x_1_up = self.chebyshev_conv(laplacian_up_1, self.conv_order, x_1)
@@ -446,25 +450,25 @@ def forward(self, x_all, laplacian_all, incidence_all):
        x_1_up = self.chebyshev_conv(laplacian_up_1, self.conv_order, x_1)
        x_1_to_1 = torch.cat((x_1.unsqueeze(2), x_1_down, x_1_up), 2)

-        #-------------------
+        # -------------------

        # x_0_to_1 = torch.mm(b1.T, x_0)
        # x_0_to_1_identity = torch.unsqueeze(identity_1 @ x_0_to_1, 2)
        # x_0_to_1 = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_0_to_1)
        # x_0_to_1 = torch.cat((x_0_to_1_identity, x_0_to_1), 2)

        # Lower projection
        x_0_1_lower = torch.mm(b1.T, x_0)

        # Calculate lower chebyshev_conv
        x_0_1_down = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_0_1_lower)

        # Calculate upper chebyshev_conv (Note: in case of signed incidence should be always zero)
        x_0_1_up = self.chebyshev_conv(laplacian_up_1, self.conv_order, x_0_1_lower)

        # Concatenate output of filters
-        x_0_to_1 = torch.cat([x_0_1_lower.unsqueeze(2),x_0_1_down, x_0_1_up], dim=2)
+        x_0_to_1 = torch.cat([x_0_1_lower.unsqueeze(2), x_0_1_down, x_0_1_up], dim=2)
-        #-------------------
+        # -------------------

        # x_2_to_1 = torch.mm(b2, x_2)
        # x_2_to_1_identity = torch.unsqueeze(identity_1 @ x_2_to_1, 2)
@@ -475,20 +479,20 @@

        # Calculate lower chebyshev_conv (Note: In case of signed incidence should be always zero)
        x_2_1_down = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_2_1_upper)

        # Calculate upper chebyshev_conv
        x_2_1_up = self.chebyshev_conv(laplacian_up_1, self.conv_order, x_2_1_upper)

        x_2_to_1 = torch.cat([x_2_1_upper.unsqueeze(2), x_2_1_down, x_2_1_up], dim=2)

-        #-------------------
+        # -------------------
        x_1_all = torch.cat((x_0_to_1, x_1_to_1, x_2_to_1), 2)

        """
        Convolution in the face (triangle) space; depending on the SC order,
        the exact form may be a little different.
        """
-        #-------------------Logic to obtain update for 2-cells --------
+        # -------------------Logic to obtain update for 2-cells --------
        # x_identity_2 = torch.unsqueeze(identity_2 @ x_2, 2)

        # if self.sc_order == 2:
@@ -502,7 +506,7 @@ def forward(self, x_all, laplacian_all, incidence_all):
        x_2_up = self.chebyshev_conv(laplacian_up_2, self.conv_order, x_2)
        x_2_to_2 = torch.cat((x_2.unsqueeze(2), x_2_down, x_2_up), 2)

-        #-------------------
+        # -------------------

        # x_1_to_2 = torch.mm(b2.T, x_1)
        # x_1_to_2_identity = torch.unsqueeze(identity_2 @ x_1_to_2, 2)
@@ -525,13 +529,13 @@ def forward(self, x_all, laplacian_all, incidence_all):

        # x_3_to_2 = torch.cat([x_3_2_upper.unsqueeze(2), x_3_2_down, x_3_2_up], dim=2)

-        #-------------------
+        # -------------------

        x_2_all = torch.cat([x_1_to_2, x_2_to_2], dim=2)
        # The final version of this model should have the following line
        # x_2_all = torch.cat([x_1_to_2, x_2_to_2, x_3_to_2], dim=2)

-        #-------------------
+        # -------------------

        # Need to check that these einsums are correct
        y_0 = torch.einsum("nik,iok->no", x_0_all, self.weight_0)
@@ -541,4 +545,4 @@
        if self.update_func is None:
            return y_0, y_1, y_2

        return self.update(y_0), self.update(y_1), self.update(y_2)
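For reference, the recurring `chebyshev_conv(laplacian, conv_order, x)` helper builds the stack L @ x, L^2 @ x, ..., L^K @ x along a new last dimension, so each concatenation with `x.unsqueeze(2)` contributes an identity tap plus `conv_order` taps per Laplacian; three message sources per rank is what produces the `6 * conv_order + 3` (and, with two sources, `4 * conv_order + 2`) weight shapes above. A self-contained sketch of that filter bank and the final einsum readout, assuming simple power-polynomial filters (hypothetical sizes; the actual SCCNNLayer may normalize differently):

import torch

def chebyshev_conv(laplacian, conv_order, x):
    """Stack L @ x, L^2 @ x, ..., L^K @ x along a new last dim -> (n, channels, K)."""
    terms, x_k = [], x
    for _ in range(conv_order):
        x_k = laplacian @ x_k
        terms.append(x_k)
    return torch.stack(terms, dim=2)

n_cells, in_ch, out_ch, conv_order = 6, 3, 4, 2
laplacian_down = torch.randn(n_cells, n_cells)  # dense stand-ins for sparse Laplacians
laplacian_up = torch.randn(n_cells, n_cells)
x = torch.randn(n_cells, in_ch)

x_down = chebyshev_conv(laplacian_down, conv_order, x)
x_up = chebyshev_conv(laplacian_up, conv_order, x)
x_all = torch.cat([x.unsqueeze(2), x_down, x_up], dim=2)  # (n, in_ch, 2 * conv_order + 1)

weight = torch.randn(in_ch, out_ch, 2 * conv_order + 1)
# "nik,iok->no": for every cell n, contract input channels i and filter taps k.
y = torch.einsum("nik,iok->no", x_all, weight)
print(y.shape)  # torch.Size([6, 4])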
