Skip to content

Commit

Permalink
log CIs
Browse files (browse the repository at this point in the history)
  • Loading branch information
AlexTMallen committed May 4, 2023
1 parent 460941a commit 8864724
Showing 1 changed file with 39 additions and 34 deletions.
73 changes: 39 additions & 34 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def apply_to_layer(

reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
train_subs = defaultdict(list)
reporters = defaultdict(list)
train_sub_credences = defaultdict(list)
val_credences = defaultdict(list)
for num_train in self.num_trains:
for i in range(self.num_samples):
# get a random subset of the training data
Expand All @@ -100,10 +101,9 @@ def apply_to_layer(
len(train_dict) == 1
), "CCS only supports single-task training"

train_h_sub, train_gt_sub, train_lm_preds_sub = train_sub.popitem()[
1
]
train_h_sub, train_gt_sub, _ = train_sub.popitem()[1]
reporter = CcsReporter(self.net, d, device=device)
reporter.fit(train_h_sub, train_gt_sub)

(val_h, val_gt, _) = next(iter(val_dict.values()))

Expand Down Expand Up @@ -133,14 +133,28 @@ def apply_to_layer(
)
reporter.update(train_h_sub)

reporter.fit_streaming()
reporter.platt_scale(
torch.cat(label_list),
torch.cat(hidden_list),
)
else:
raise ValueError(f"Unknown reporter config type: {type(self.net)}")

reporters[num_train].append(reporter)
# grab credences before throwing away the reporter
val_credences[num_train].append(
{
ds_name: reporter(val_h)
for ds_name, (val_h, _, _) in val_dict.items()
}
)
train_sub_credences[num_train].append(
{
ds_name: reporter(train_h_sub)
for ds_name, (train_h_sub, _, _) in train_sub.items()
}
)

# Save reporter checkpoint to disk
reporter.save(
reporter_dir / f"layer_{layer}_num_train_{num_train}_sample_{i}.pt"
Expand Down Expand Up @@ -171,60 +185,51 @@ def apply_to_layer(
num_train_buf = defaultdict(list)

for i in range(self.num_samples):
train_h_sub, train_gt_sub, train_lm_preds_sub = train_subs[
num_train
][i][ds_name]
reporter = reporters[num_train][i]
train_h_sub, train_gt_sub, _ = train_subs[num_train][i][ds_name]
num_train_buf["train_eval"].append(
evaluate_preds(
train_gt_sub,
reporter(train_h_sub),
train_sub_credences[num_train][i][ds_name],
mode,
).to_dict()
)
num_train_buf["eval"].append(
evaluate_preds(
val_gt,
reporter(val_h),
val_credences[num_train][i][ds_name],
mode,
).to_dict()
)

num_train_dfs = {
k: pd.DataFrame(v) for k, v in num_train_buf.items()
}
nt_dfs = {k: pd.DataFrame(v) for k, v in num_train_buf.items()}
# get mean, std, 2.5%/97.5% quantiles, SEM, and a normal-approximation
# 95% CI (mean +/- 1.96 * SEM) of each of
# auroc_estimate, acc_estimate, cal_acc_estimate, and ece

for key in num_train_dfs:
for key in nt_dfs:
stats = dict()
for metric in (
"auroc_estimate",
"acc_estimate",
"cal_acc_estimate",
"ece",
):
short_metric = metric.replace("_estimate", "")
sname = metric.replace("_estimate", "")
sem = float(nt_dfs[key][metric].sem()) # type: ignore
stats.update(
{
f"{short_metric}_min": num_train_dfs[key][
metric
].min(),
f"{short_metric}_lower": num_train_dfs[key][
metric
].quantile(0.025),
f"{short_metric}_estimate": num_train_dfs[key][
metric
].mean(),
f"{short_metric}_upper": num_train_dfs[key][
metric
].quantile(0.975),
f"{short_metric}_max": num_train_dfs[key][
metric
].max(),
f"{short_metric}_std": num_train_dfs[key][
metric
].std(),
f"{sname}_estimate": nt_dfs[key][metric].mean(),
f"{sname}_lower": nt_dfs[key][metric].quantile(
0.025
),
f"{sname}_upper": nt_dfs[key][metric].quantile(
0.975
),
f"{sname}_std": nt_dfs[key][metric].std(),
f"{sname}_ci_lower": nt_dfs[key][metric].mean()
- 1.96 * sem,
f"{sname}_ci_upper": nt_dfs[key][metric].mean()
+ 1.96 * sem,
f"{sname}_sem": sem,
}
)
row_bufs[key].append({"num_train": num_train, **meta, **stats})
Expand Down

0 comments on commit 8864724

Please sign in to comment.