Skip to content

Commit

Permalink
Add option to skip `check_separability` (new `skip_check_separability` flag on `RunConfig`)
Browse files: browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed Apr 9, 2023
1 parent 3f7187b commit 158a754
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class RunConfig(Serializable):
max_gpus: int = -1
normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly"
skip_baseline: bool = False
skip_check_separability: bool = False
debug: bool = False


Expand Down Expand Up @@ -104,15 +105,18 @@ def train_reporter(
val_x0, val_x1 = val_h.unbind(dim=-2)

with torch.no_grad():
pseudo_auroc = Reporter.check_separability(
train_pair=(x0, x1), val_pair=(val_x0, val_x1)
)
if pseudo_auroc > 0.6:
warnings.warn(
f"The pseudo-labels at layer {layer} are linearly separable with "
f"an AUROC of {pseudo_auroc:.3f}. This may indicate that the "
f"algorithm will not converge to a good solution."
if not cfg.skip_check_separability:
pseudo_auroc = Reporter.check_separability(
train_pair=(x0, x1), val_pair=(val_x0, val_x1)
)
if pseudo_auroc > 0.6:
warnings.warn(
f"The pseudo-labels at layer {layer} are linearly separable "
f"with AUROC of {pseudo_auroc:.3f}. This may indicate that "
f"the algorithm will not converge to a good solution."
)
else:
pseudo_auroc = None

if isinstance(cfg.net, CcsReporterConfig):
reporter = CcsReporter(x0.shape[-1], cfg.net, device=device)
Expand All @@ -133,7 +137,9 @@ def train_reporter(

lr_dir.mkdir(parents=True, exist_ok=True)
reporter_dir.mkdir(parents=True, exist_ok=True)
stats = [layer, pseudo_auroc, train_loss, *val_result]
stats = [layer, train_loss, *val_result]
if pseudo_auroc is not None:
stats = [pseudo_auroc, *stats]

if not cfg.skip_baseline:
# repeat_interleave makes `num_variants` copies of each label, all within a
Expand Down Expand Up @@ -194,7 +200,6 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None):

cols = [
"layer",
"pseudo_auroc",
"train_loss",
"acc",
"cal_acc",
Expand All @@ -203,6 +208,8 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None):
]
if not cfg.skip_baseline:
cols += ["lr_auroc", "lr_acc"]
if not cfg.skip_check_separability:
cols = ["pseudo_auroc"] + cols

layers = [
int(feat[len("hidden_") :])
Expand Down

0 comments on commit 158a754

Please sign in to comment.