Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 18, 2023
1 parent 8113317 commit 5306313
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
1 change: 0 additions & 1 deletion elk/training/platt_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def closure():

opt.step(closure)


print("platt losses", losses)
print("scale", self.scale.item())
print("bias", self.bias.item())
11 changes: 8 additions & 3 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,21 @@ def train_rep(net):
reporter = CcsReporter(net, d, device=device, num_variants=v)
train_loss = reporter.fit(first_train_h)
reporter.platt_scale(train_gt, first_train_h)

def eval_stats(gt, h):
cred = reporter(h)
stats = evaluate_preds(gt, cred, PromptEnsembling.FULL).to_dict()
stats["train_loss"] = train_loss
stats["scale"] = reporter.scale.item()
stats["bias"] = reporter.bias.item()
return stats
return reporter, eval_stats(train_gt, first_train_h), eval_stats(val_gt, first_val_h)


return (
reporter,
eval_stats(train_gt, first_train_h),
eval_stats(val_gt, first_val_h),
)

cpy_net = replace(self.net, platt_burns="hack")
vanilla_net = replace(self.net, platt_burns="vanilla")

Expand Down

0 comments on commit 5306313

Please sign in to comment.