Skip to content

Commit

Permalink
Adding a transform that allows to perform the barycentric subdivision…
Browse files Browse the repository at this point in the history
… on data's simplicial complexes. Then, another transformation can get the updated triangulation, to avoid overhead with previous models.
  • Loading branch information
rballeba committed Aug 29, 2024
1 parent e10d013 commit 6d088cb
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 5 deletions.
20 changes: 19 additions & 1 deletion code/datasets/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from datasets.utils import (
create_signals_on_data_if_needed,
append_signals,
get_complex_connectivity,
get_complex_connectivity, create_or_empty_signals_on_data, get_triangles_from_simplicial_complex,
)
from enum import Enum

from math_utils import recursive_barycentric_subdivision

NAME_TO_CLASS = {
"Klein bottle": 0,
"": 1,
Expand Down Expand Up @@ -58,6 +60,22 @@ def __call__(self, data):
return data


class BarycentricSubdivisionTransform:
def __init__(self, recursive_calls=1):
self.recursive_calls = recursive_calls

def __call__(self, data):
data.sc = recursive_barycentric_subdivision(data.sc, self.recursive_calls)
data = create_or_empty_signals_on_data(data)
return data


class UpdateTriangulationFromSimplicialComplexTransform:
def __call__(self, data):
data.triangulation = get_triangles_from_simplicial_complex(data)
return data


class SimplicialComplexDegreeTransform:
def __call__(self, data):
data = create_signals_on_data_if_needed(data)
Expand Down
21 changes: 17 additions & 4 deletions code/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import numpy as np
import torch
from toponetx import (
compute_bunch_normalized_matrices,
compute_x_laplacian_normalized_matrix,
)


def create_signals_on_data_if_needed(data):
Expand All @@ -12,6 +8,23 @@ def create_signals_on_data_if_needed(data):
return data


def create_or_empty_signals_on_data(data):
data.x = {}
return data


def get_triangles_from_simplicial_complex(data):
try:
sc = data.sc
except AttributeError:
raise AttributeError("Simplicial complex not found in data. Did you apply the SimplicialComplex transform"
"before trying to get triangulations?")
triangles = []
for triangle in sc.skeleton(2):
triangles.append(list(triangle))
return triangles


def generate_zero_sparse_connectivity(m, n):
# Function extracted from TopoBenchmarkX
"""Generate a zero sparse connectivity matrix.
Expand Down
42 changes: 42 additions & 0 deletions code/math_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from toponetx import SimplicialComplex


def barycentric_subdivision(K: SimplicialComplex) -> SimplicialComplex:
# Create a new SimplicialComplex to store the subdivision
Sd_K = SimplicialComplex()

new_simplices = {dim: set() for dim in range(K.dim + 1)}

# Add new vertices to Sd_K. Each simplex of Sd_K is a chain of simplices of K
for simplex in K.simplices:
new_simplices[0].add((simplex,))

# Give now an index to each simplex
simplex_to_index = {simplex[0]: i for i, simplex in enumerate(new_simplices[0])}

# Now, we add simplices from dimension 1 to K.dim
for dim in range(1, K.dim + 1):
# Get all simplices of the previous dimension, and try to add more simplices to the chain
previous_simplices = new_simplices[dim - 1]
for simplex_sub in previous_simplices:
last_simplex = simplex_sub[-1]
for simplex in K.simplices:
# Check if simplex is a face of simplex_sub
if last_simplex < simplex:
new_simplices[dim].add(simplex_sub + (simplex,))
# Now convert the simplices to indexes
all_simplices = []
for dim in range(K.dim + 1):
for simplex in new_simplices[dim]:
all_simplices.append([simplex_to_index[or_simplex] for or_simplex in simplex])
# Add the simplices to the new SimplicialComplex
print()
Sd_K.add_simplices_from(all_simplices)
return Sd_K, simplex_to_index


def recursive_barycentric_subdivision(K: SimplicialComplex, number_of_transformations: int) -> SimplicialComplex:
Sd_K = K
for _ in range(number_of_transformations):
Sd_K, _ = barycentric_subdivision(Sd_K)
return Sd_K

0 comments on commit 6d088cb

Please sign in to comment.