Skip to content

Commit

Permalink
working cell models
Browse files Browse the repository at this point in the history
  • Loading branch information
levtelyatnikov committed May 7, 2024
1 parent e0147f5 commit 1683cf2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion configs/model/cell/can.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ backbone:
heads: 4 # For now we stuck to out_channels//heads
concat: True
skip_connection: True
n_layers: 1
n_layers: 4
att_lift: False

loss:
Expand Down
2 changes: 1 addition & 1 deletion configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
defaults:
- _self_
- dataset: ZINC
- model: cell/cwn #hypergraph/unignn2 #
- model: cell/can #hypergraph/unignn2 #
- evaluator: default
- callbacks: default
- logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
Expand Down
20 changes: 10 additions & 10 deletions topobenchmarkx/models/wrappers/default_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,11 @@ def __call__(self, batch):
"""Define logic for forward pass"""
model_out = {"labels": batch.y, "batch": batch.batch}
x_1 = self.backbone(
batch.x,
batch.x_1,
batch.adjacency_0.coalesce(),
batch.down_laplacian_1.coalesce(),
batch.up_laplacian_1.coalesce(),
x_0=batch.x_0,
x_1=batch.x_1,
adjacency_0=batch.adjacency_0.coalesce(),
down_laplacian_1=batch.down_laplacian_1.coalesce(),
up_laplacian_1=batch.up_laplacian_1.coalesce(),
)
x_0 = torch.sparse.mm(batch.incidence_1, x_1)

Expand Down Expand Up @@ -321,9 +321,9 @@ def __call__(self, batch):
x_0=batch.x_0,
x_1=batch.x_1,
x_2=batch.x_2,
neighborhood_0_to_1=batch.incidence_1.T,
neighborhood_1_to_1=batch.adjacency_1,
neighborhood_2_to_1=batch.incidence_2,
incidence_1_t=batch.incidence_1.T,
adjacency_0=batch.adjacency_1,
incidence_2=batch.incidence_2,
)

model_out["x_0"] = torch.mm(
Expand All @@ -346,8 +346,8 @@ def __call__(self, batch):
x_0, x_1, x_2 = self.backbone(
x_0=batch.x_0,
x_1=batch.x_1,
neighborhood_0_to_0=batch.adjacency_0,
neighborhood_1_to_2=batch.incidence_2.T,
adjacency_0=batch.adjacency_0,
incidence_2_t=batch.incidence_2.T,
)
model_out["x_0"] = x_0
model_out["x_1"] = x_1
Expand Down

0 comments on commit 1683cf2

Please sign in to comment.