Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbinschmid committed Sep 13, 2024
1 parent af9db6b commit 0fea8dc
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 26 deletions.
20 changes: 10 additions & 10 deletions code/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,18 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError:
if connectivity_info == "incidence":
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1],
n=practical_shape[rank_idx],
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1],
n=practical_shape[rank_idx],
)
)
else:
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx],
n=practical_shape[rank_idx],
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx],
n=practical_shape[rank_idx],
)
)
"""
Not needed right now according to TopoBenchmarkX
Expand Down
19 changes: 10 additions & 9 deletions code/experiments/utils/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@ def get_wandb_logger(
wandb_logger = WandbLogger(project=project_id, save_dir=save_dir)
return wandb_logger


def update_wandb_logger(
wandb_logger,
task_name: TaskType,
save_dir="./lightning_logs",
model_name: str = None,
node_features: str = None,
run_id: str = None,
project_id: str = "mantra-dev",
):
wandb_logger,
task_name: TaskType,
save_dir="./lightning_logs",
model_name: str = None,
node_features: str = None,
run_id: str = None,
project_id: str = "mantra-dev",
):
wandb_logger.experiment.config["task"] = task_name.lower()
wandb_logger.experiment.config["run_id"] = run_id
wandb_logger.experiment.config["node_features"] = node_features

if model_name is not None:
wandb_logger.experiment.config["model_name"] = model_name
wandb_logger.experiment.config["model_name"] = model_name
12 changes: 6 additions & 6 deletions code/experiments/vis/us_cmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def scale_white_amount(rgb, percent):
def register_name(new_name, color):
try:
# just a relabeling of the name
matplotlib.colors.ColorConverter.colors[
new_name
] = matplotlib.colors.ColorConverter.colors[color]
matplotlib.colors.ColorConverter.colors[new_name] = (
matplotlib.colors.ColorConverter.colors[color]
)
except KeyError:
if type(color) is str:
color = get_rgba(color)
Expand All @@ -66,9 +66,9 @@ def activate():
list_cmap.append(rgba)
matplotlib.colors.ColorConverter.colors[name] = rgba
for step in steps:
matplotlib.colors.ColorConverter.colors[
f"{name}!{step * 100:.0f}"
] = scale_white_amount(values, step)
matplotlib.colors.ColorConverter.colors[f"{name}!{step * 100:.0f}"] = (
scale_white_amount(values, step)
)

matplotlib.cm.register_cmap(
name="US", cmap=matplotlib.colors.ListedColormap(list_cmap)
Expand Down
1 change: 0 additions & 1 deletion code/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(
self.imbalance = self.imbalance / np.sum(self.imbalance)
self.test_barycentric_subdivisions = 0


def forward(self, batch):
x = self.model(batch)
return x
Expand Down

0 comments on commit 0fea8dc

Please sign in to comment.