From 1683cf295dcfe4440eba67f0e706accca6528ea0 Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Tue, 7 May 2024 21:42:04 +0200 Subject: [PATCH] working cell models --- configs/model/cell/can.yaml | 2 +- configs/train.yaml | 2 +- .../models/wrappers/default_wrapper.py | 20 +++++++++---------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/configs/model/cell/can.yaml b/configs/model/cell/can.yaml index 85ca8d84..7ce1eeae 100755 --- a/configs/model/cell/can.yaml +++ b/configs/model/cell/can.yaml @@ -17,7 +17,7 @@ backbone: heads: 4 # For now we stuck to out_channels//heads concat: True skip_connection: True - n_layers: 1 + n_layers: 4 att_lift: False loss: diff --git a/configs/train.yaml b/configs/train.yaml index 01144a49..cb72c21b 100755 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -5,7 +5,7 @@ defaults: - _self_ - dataset: ZINC - - model: cell/cwn #hypergraph/unignn2 # + - model: cell/can #hypergraph/unignn2 # - evaluator: default - callbacks: default - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`) diff --git a/topobenchmarkx/models/wrappers/default_wrapper.py b/topobenchmarkx/models/wrappers/default_wrapper.py index c05c75eb..eb6db0aa 100755 --- a/topobenchmarkx/models/wrappers/default_wrapper.py +++ b/topobenchmarkx/models/wrappers/default_wrapper.py @@ -274,11 +274,11 @@ def __call__(self, batch): """Define logic for forward pass""" model_out = {"labels": batch.y, "batch": batch.batch} x_1 = self.backbone( - batch.x, - batch.x_1, - batch.adjacency_0.coalesce(), - batch.down_laplacian_1.coalesce(), - batch.up_laplacian_1.coalesce(), + x_0=batch.x_0, + x_1=batch.x_1, + adjacency_0=batch.adjacency_0.coalesce(), + down_laplacian_1=batch.down_laplacian_1.coalesce(), + up_laplacian_1=batch.up_laplacian_1.coalesce(), ) x_0 = torch.sparse.mm(batch.incidence_1, x_1) @@ -321,9 +321,9 @@ def __call__(self, batch): x_0=batch.x_0, x_1=batch.x_1, x_2=batch.x_2, - neighborhood_0_to_1=batch.incidence_1.T, - neighborhood_1_to_1=batch.adjacency_1, - neighborhood_2_to_1=batch.incidence_2, + incidence_1_t=batch.incidence_1.T, + adjacency_0=batch.adjacency_1, + incidence_2=batch.incidence_2, ) model_out["x_0"] = torch.mm( @@ -346,8 +346,8 @@ def __call__(self, batch): x_0, x_1, x_2 = self.backbone( x_0=batch.x_0, x_1=batch.x_1, - neighborhood_0_to_0=batch.adjacency_0, - neighborhood_1_to_2=batch.incidence_2.T, + adjacency_0=batch.adjacency_0, + incidence_2_t=batch.incidence_2.T, ) model_out["x_0"] = x_0 model_out["x_1"] = x_1