Skip to content

Commit

Permalink
updated demos
Browse files Browse the repository at this point in the history
  • Loading branch information
levtelyatnikov committed May 7, 2024
1 parent 1683cf2 commit e1fadcf
Show file tree
Hide file tree
Showing 20 changed files with 65 additions and 49 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
- id: check-added-large-files
args:
- --maxkb=2048
- id: trailing-whitespace
# - id: trailing-whitespace
- id: requirements-txt-fixer

# - repo: https://github.com/astral-sh/ruff-pre-commit
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
_target_: topobenchmarkx.transforms.data_transform.DataTransform
transform_name: "ConcatentionLifting"
transform_name: "ProjectionSum"
transform_type: null
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
_target_: topobenchmarkx.transforms.data_transform.DataTransform
transform_name: "ProjectionLifting"
transform_name: "ConcatentionLifting"
transform_type: null
3 changes: 2 additions & 1 deletion configs/dataset/us_county_demos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ _target_: topobenchmarkx.io.load.loaders.GraphLoader

defaults:
- transforms: ${get_default_transform:graph,${model}}


# Data definition
parameters:
Expand All @@ -18,6 +18,7 @@ parameters:
num_classes: 1
task: regression
task_variable: 'Election' # options: ['Election', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate']
force_reload: True
loss_type: mse
monitor_metric: mse
task_level: node
Expand Down
9 changes: 0 additions & 9 deletions configs/model/cell/can.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,12 @@ loss:
task: ${dataset.parameters.task}
loss_type: ${dataset.parameters.loss_type}

# evaluator:
# _target_: topobenchmarkx.evaluators.evaluator.Evaluator
# metrics: [acc, rocauc]

readout:
_target_: topobenchmarkx.models.readouts.default_readouts.GNNBatchReadOut
task_level: ${dataset.parameters.task_level}
in_channels: ${parameter_multiplication:${model.backbone.out_channels},${model.backbone.heads}}
out_channels: ${dataset.parameters.num_classes}

# readout_workaround:
# _target_: topobenchmarkx.models.readout_workaround.ReadOutWorkaround
# backbone_outputs: ["x_0"]


backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.CANWrapper
_partial_: true
Expand Down
2 changes: 1 addition & 1 deletion configs/model/cell/cwn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ backbone:
in_channels_1: ${model.feature_encoder.out_channels}
in_channels_2: ${model.feature_encoder.out_channels}
hid_channels: ${model.feature_encoder.out_channels}
n_layers: 4
n_layers: 2

loss:
_target_: topobenchmarkx.models.losses.loss.DefaultLoss
Expand Down
2 changes: 1 addition & 1 deletion configs/model/cell/cwn_dcm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ feature_encoder:
backbone:
_target_: custom_models.cell.cwn_dcm.CWNDCM
in_channels: ${model.feature_encoder.out_channels}
n_layers: 4
n_layers: 1
dropout: 0.0

loss:
Expand Down
2 changes: 1 addition & 1 deletion configs/model/graph/gcn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ backbone:
_target_: torch_geometric.nn.models.GCN
in_channels: ${model.feature_encoder.out_channels}
hidden_channels: ${model.feature_encoder.out_channels}
num_layers: 3
num_layers: 1
dropout: 0.0
act: relu

Expand Down
4 changes: 2 additions & 2 deletions configs/model/simplicial/sccn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ _target_: topobenchmarkx.models.network_module.NetworkModule
feature_encoder:
_target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
in_channels: ${infer_in_channels:${dataset}} # ${dataset.parameters.num_features}
out_channels: 128
out_channels: 32

backbone:
_target_: topomodelx.nn.simplicial.sccn.SCCN
channels: ${model.feature_encoder.out_channels}
max_rank: 1
n_layers: 1
n_layers: 2
update_func: "sigmoid"

loss:
Expand Down
2 changes: 1 addition & 1 deletion configs/model/simplicial/scn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ backbone:
in_channels_0: ${model.feature_encoder.out_channels}
in_channels_1: ${model.feature_encoder.out_channels}
in_channels_2: ${model.feature_encoder.out_channels}
n_layers: 1
n_layers: 3

loss:
_target_: topobenchmarkx.models.losses.loss.DefaultLoss
Expand Down
4 changes: 2 additions & 2 deletions configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# order of defaults determines the order in which configs override each other
defaults:
- _self_
- dataset: ZINC
- model: cell/can #hypergraph/unignn2 #
- dataset: us_county_demos
- model: simplicial/scn #hypergraph/unignn2 #
- evaluator: default
- callbacks: default
- logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
Expand Down
12 changes: 6 additions & 6 deletions topobenchmarkx/data/dataloader_fullbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ def collate_fn(batch):
+ running_idx[f"cell_running_idx_number_{cell_dim}"]
).long()

running_idx[f"cell_running_idx_number_{cell_dim}"] += (
current_number_of_cells # current_number_of_nodes
)
running_idx[
f"cell_running_idx_number_{cell_dim}"
] += current_number_of_cells # current_number_of_nodes

elif x_key == "x_hyperedges":
cell_dim = x_key.split("_")[1]
Expand All @@ -112,9 +112,9 @@ def collate_fn(batch):
+ running_idx[f"cell_running_idx_number_{cell_dim}"]
).long()

running_idx[f"cell_running_idx_number_{cell_dim}"] += (
current_number_of_hyperedges
)
running_idx[
f"cell_running_idx_number_{cell_dim}"
] += current_number_of_hyperedges
else:
# Function Batch.from_data_list creates a running index automatically
pass
Expand Down
14 changes: 12 additions & 2 deletions topobenchmarkx/data/us_county_demos_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.io import fs

from topobenchmarkx.io.load.cornel_dataset import load_us_county_demos
from topobenchmarkx.io.load.us_county_demos import load_us_county_demos
from topobenchmarkx.io.load.download_utils import download_file_from_drive
from topobenchmarkx.io.load.split_utils import random_splitting

Expand Down Expand Up @@ -92,6 +92,14 @@ def __init__(
data.val_mask = torch.from_numpy(splits["valid"])
data.test_mask = torch.from_numpy(splits["test"])

# Standardize the node features respecting train mask
data.x = (data.x - data.x[data.train_mask].mean(0)) / data.x[
data.train_mask
].std(0)
data.y = (data.y - data.y[data.train_mask].mean(0)) / data.y[
data.train_mask
].std(0)

# Assign data object to self.data, to make it be prodessed by Dataset class
self.data, self.slices = self.collate([data])

Expand Down Expand Up @@ -154,7 +162,9 @@ def process(self) -> None:
Returns:
None
"""
data = load_us_county_demos(self.raw_dir, year=self.parameters.year)
data = load_us_county_demos(
self.raw_dir, year=self.parameters.year, y_col=self.parameters.task_variable
)

data = data if self.pre_transform is None else self.pre_transform(data)
self.save([data], self.processed_paths[0])
Expand Down
5 changes: 4 additions & 1 deletion topobenchmarkx/io/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,10 @@ def load(self) -> CustomDataset:
)

if self.transforms_config is not None:
dataset = Preprocessor(data_dir, dataset, self.transforms_config)
# force_reload=True because in this datasets many variables can be trated as y
dataset = Preprocessor(
data_dir, dataset, self.transforms_config, force_reload=True
)

# We need to map original dataset into custom one to make batching work
dataset = CustomDataset([dataset[0]])
Expand Down
12 changes: 10 additions & 2 deletions topobenchmarkx/io/load/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,22 @@ class Preprocessor(torch_geometric.data.InMemoryDataset):
Additional arguments.
"""

def __init__(self, data_dir, data_list, transforms_config, **kwargs):
def __init__(
self, data_dir, data_list, transforms_config, force_reload=False, **kwargs
):
if isinstance(data_list, torch_geometric.data.Dataset):
data_list = [data_list.get(idx) for idx in range(len(data_list))]
elif isinstance(data_list, torch_geometric.data.Data):
data_list = [data_list]
self.data_list = data_list
pre_transform = self.instantiate_pre_transform(data_dir, transforms_config)
super().__init__(self.processed_data_dir, None, pre_transform, **kwargs)
super().__init__(
self.processed_data_dir,
None,
pre_transform,
force_reload=force_reload,
**kwargs,
)
self.save_transform_parameters()
self.load(self.processed_paths[0])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch_geometric


def load_us_county_demos(path, year=2012):
def load_us_county_demos(path, year=2012, y_col="Election"):
edges_df = pd.read_csv(f"{path}/county_graph.csv")
stat = pd.read_csv(f"{path}/county_stats_{year}.csv", encoding="ISO-8859-1")

Expand Down Expand Up @@ -83,7 +83,7 @@ def load_us_county_demos(path, year=2012):
stat = stat.drop(columns=["DEM", "GOP", "FIPS"])

# Prediction col
y_col = "Election" # TODO: Define through config file

x_col = list(set(stat.columns).difference(set([y_col])))

stat["MedianIncome"] = (
Expand Down
20 changes: 12 additions & 8 deletions topobenchmarkx/models/wrappers/default_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,18 @@ def normalize_matrix(matrix):
)

# Propagate signal down
model_out["x_2"] = x_2
model_out["x_1"] = self.norm_1(
x_1 + self.agg_conv_1(model_out["x_2"], batch.incidence_2),
batch.batch_1,
)
model_out["x_0"] = self.norm_2(
x_0 + self.agg_conv_2(model_out["x_1"], batch.incidence_1), batch.batch
)
# model_out["x_2"] = x_2
# model_out["x_1"] = x_1
# model_out["x_0"] = x_0

# model_out["x_2"] = x_2
# model_out["x_1"] = self.norm_1(
# x_1 + self.agg_conv_1(model_out["x_2"], batch.incidence_2),
# batch.batch_1,
# )
# model_out["x_0"] = self.norm_2(
# x_0 + self.agg_conv_2(model_out["x_1"], batch.incidence_1), batch.batch
# )

return model_out

Expand Down
5 changes: 2 additions & 3 deletions topobenchmarkx/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from topobenchmarkx.transforms.feature_liftings.feature_liftings import (
ConcatentionLifting,
ProjectionLifting,
ProjectionSum,
SetLifting,
)
from topobenchmarkx.transforms.liftings.graph2cell import CellCyclesLifting
Expand All @@ -30,7 +30,6 @@
)

TRANSFORMS = {
###
# Graph -> Hypergraph
"HypergraphKHopLifting": HypergraphKHopLifting,
"HypergraphKNearestNeighborsLifting": HypergraphKNearestNeighborsLifting,
Expand All @@ -40,7 +39,7 @@
# Graph -> Cell Complex
"CellCyclesLifting": CellCyclesLifting,
# Feature Liftings
"ProjectionLifting": ProjectionLifting,
"ProjectionSum": ProjectionSum,
"ConcatentionLifting": ConcatentionLifting,
"SetLifting": SetLifting,
# Data Manipulations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch_geometric


class ProjectionLifting(torch_geometric.transforms.BaseTransform):
class ProjectionSum(torch_geometric.transforms.BaseTransform):
r"""Lifts r-cell features to r+1-cells by projection.
Parameters
Expand Down Expand Up @@ -33,7 +33,7 @@ def lift_features(
if f"x_{elem}" not in data:
idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1
data["x_" + elem] = torch.matmul(
data["incidence_" + elem].t(),
abs(data["incidence_" + elem].t()),
data[f"x_{idx_to_project}"],
)
return data
Expand Down
4 changes: 2 additions & 2 deletions topobenchmarkx/transforms/liftings/graph_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

from topobenchmarkx.transforms.feature_liftings.feature_liftings import (
ConcatentionLifting,
ProjectionLifting,
ProjectionSum,
SetLifting,
)

# Implemented Feature Liftings
FEATURE_LIFTINGS = {
"projection": ProjectionLifting,
"projection": ProjectionSum,
"concatenation": ConcatentionLifting,
"set": SetLifting,
}
Expand Down

0 comments on commit e1fadcf

Please sign in to comment.