diff --git a/code/datasets/utils.py b/code/datasets/utils.py index 5f47bec..dea78e5 100644 --- a/code/datasets/utils.py +++ b/code/datasets/utils.py @@ -84,18 +84,18 @@ def get_complex_connectivity(complex, max_rank, signed=False): ) except ValueError: if connectivity_info == "incidence": - connectivity[ - f"{connectivity_info}_{rank_idx}" - ] = generate_zero_sparse_connectivity( - m=practical_shape[rank_idx - 1], - n=practical_shape[rank_idx], + connectivity[f"{connectivity_info}_{rank_idx}"] = ( + generate_zero_sparse_connectivity( + m=practical_shape[rank_idx - 1], + n=practical_shape[rank_idx], + ) ) else: - connectivity[ - f"{connectivity_info}_{rank_idx}" - ] = generate_zero_sparse_connectivity( - m=practical_shape[rank_idx], - n=practical_shape[rank_idx], + connectivity[f"{connectivity_info}_{rank_idx}"] = ( + generate_zero_sparse_connectivity( + m=practical_shape[rank_idx], + n=practical_shape[rank_idx], + ) ) """ Not needed right now according to TopoBenchmarkX diff --git a/code/experiments/utils/loggers.py b/code/experiments/utils/loggers.py index 9d07191..6393fd5 100644 --- a/code/experiments/utils/loggers.py +++ b/code/experiments/utils/loggers.py @@ -17,18 +17,19 @@ def get_wandb_logger( wandb_logger = WandbLogger(project=project_id, save_dir=save_dir) return wandb_logger + def update_wandb_logger( - wandb_logger, - task_name: TaskType, - save_dir="./lightning_logs", - model_name: str = None, - node_features: str = None, - run_id: str = None, - project_id: str = "mantra-dev", - ): + wandb_logger, + task_name: TaskType, + save_dir="./lightning_logs", + model_name: str = None, + node_features: str = None, + run_id: str = None, + project_id: str = "mantra-dev", +): wandb_logger.experiment.config["task"] = task_name.lower() wandb_logger.experiment.config["run_id"] = run_id wandb_logger.experiment.config["node_features"] = node_features if model_name is not None: - wandb_logger.experiment.config["model_name"] = model_name \ No newline at end of file + wandb_logger.experiment.config["model_name"] = model_name diff --git a/code/experiments/vis/us_cmap.py b/code/experiments/vis/us_cmap.py index de5071f..ecdb2f8 100644 --- a/code/experiments/vis/us_cmap.py +++ b/code/experiments/vis/us_cmap.py @@ -39,9 +39,9 @@ def scale_white_amount(rgb, percent): def register_name(new_name, color): try: # just a relabeling of the name - matplotlib.colors.ColorConverter.colors[ - new_name - ] = matplotlib.colors.ColorConverter.colors[color] + matplotlib.colors.ColorConverter.colors[new_name] = ( + matplotlib.colors.ColorConverter.colors[color] + ) except KeyError: if type(color) is str: color = get_rgba(color) @@ -66,9 +66,9 @@ def activate(): list_cmap.append(rgba) matplotlib.colors.ColorConverter.colors[name] = rgba for step in steps: - matplotlib.colors.ColorConverter.colors[ - f"{name}!{step * 100:.0f}" - ] = scale_white_amount(values, step) + matplotlib.colors.ColorConverter.colors[f"{name}!{step * 100:.0f}"] = ( + scale_white_amount(values, step) + ) matplotlib.cm.register_cmap( name="US", cmap=matplotlib.colors.ListedColormap(list_cmap) diff --git a/code/models/base.py b/code/models/base.py index 7f217ab..05238cb 100644 --- a/code/models/base.py +++ b/code/models/base.py @@ -34,7 +34,6 @@ def __init__( self.imbalance = self.imbalance / np.sum(self.imbalance) self.test_barycentric_subdivisions = 0 - def forward(self, batch): x = self.model(batch) return x