Commit 8672104
held-out Platt scaling; random baseline reporter; fix encoder-only
AlexTMallen committed May 5, 2023
1 parent 8864724 commit 8672104
Showing 2 changed files with 85 additions and 12 deletions.
elk/extraction/extraction.py: 3 additions, 1 deletion
@@ -282,7 +282,9 @@ def extract_hiddens(
             # Record the EXACT question we fed to the model
             variant_questions.append(text)

-            inputs = dict(input_ids=ids.long(), labels=labels)
+            inputs = dict(input_ids=ids.long())
+            if labels is not None:
+                inputs["labels"] = labels
             outputs = model(**inputs, output_hidden_states=True)

             # Compute the log probability of the answer tokens if available
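The deleted line passed labels unconditionally. For encoder-only models loaded without a language-modeling head, forward() has no labels parameter, so even labels=None raises a TypeError; that is a plausible reading of the "fix encoder-only" part of the commit message. A minimal standalone sketch of the guarded-kwargs pattern, outside the elk codebase (the DeBERTa checkpoint and prompt are illustrative assumptions, not from the commit):

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
model = AutoModel.from_pretrained("microsoft/deberta-v3-base")  # encoder-only, no LM head

ids = tokenizer("Is the sky blue?", return_tensors="pt").input_ids
labels = None  # encoder-only extraction produces no next-token labels

# Only pass `labels` when it exists; base models' forward() rejects the kwarg.
inputs = dict(input_ids=ids.long())
if labels is not None:
    inputs["labels"] = labels

outputs = model(**inputs, output_hidden_states=True)
hiddens = outputs.hidden_states  # tuple of (batch, seq, d) tensors, one per layer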
elk/training/train.py: 82 additions, 11 deletions
@@ -25,7 +25,8 @@ class Elicit(Run):
     """Full specification of a reporter training run."""

     net: ReporterConfig = subgroups(
-        {"ccs": CcsReporterConfig, "eigen": EigenReporterConfig}, default="eigen"
+        {"ccs": CcsReporterConfig, "eigen": EigenReporterConfig},
+        default="eigen",  # type: ignore
     )
     """Config for building the reporter network."""
@@ -55,6 +56,7 @@ def apply_to_layer(
         layer: int,
         devices: list[str],
         world_size: int,
+        num_platt: int = 100,
     ) -> dict[str, pd.DataFrame]:
         """Train a single reporter on a single layer."""

@@ -76,22 +78,52 @@ def apply_to_layer(
         if not all(other_h.shape[-2] == k for other_h, _, _ in rest):
             raise ValueError("All datasets must have the same number of classes")

+        # carve out a subset of the training data for Platt scaling
+        original_train_dict = train_dict
+        train_dict = dict()
+        platt_set_h = []
+        platt_set_gt = []
+        for ds_name, (
+            o_train_h,
+            o_train_gt,
+            o_train_lm_preds,
+        ) in original_train_dict.items():
+            perm = torch.randperm(o_train_h.shape[0])
+            (_, v, _, _) = o_train_h.shape
+            sub_idx = perm[num_platt:]
+            platt_set_idx = perm[:num_platt]
+            platt_set_h.append(
+                rearrange(o_train_h[platt_set_idx], "n v k d -> (n v k) d")
+            )
+            platt_set_gt.append(
+                to_one_hot(
+                    repeat(o_train_gt[platt_set_idx], "n -> (n v)", v=v), k
+                ).flatten()
+            )
+            train_dict[ds_name] = (
+                o_train_h[sub_idx],
+                o_train_gt[sub_idx],
+                o_train_lm_preds[sub_idx] if o_train_lm_preds is not None else None,
+            )
+        platt_set_h = torch.cat(platt_set_h)
+        platt_set_gt = torch.cat(platt_set_gt)
+
         reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
         train_subs = defaultdict(list)
         train_sub_credences = defaultdict(list)
         val_credences = defaultdict(list)
+        train_rand_credences = defaultdict(list)
+        val_rand_credences = defaultdict(list)
         for num_train in self.num_trains:
             for i in range(self.num_samples):
                 # get a random subset of the training data
                 train_sub = dict()
                 for ds_name, (train_h, train_gt, train_lm_preds) in train_dict.items():
-                    train_idx = torch.randperm(train_h.shape[0])[:num_train]
-                    train_h_sub = train_h[train_idx]
-                    train_gt_sub = train_gt[train_idx]
+                    sub_idx = torch.randperm(train_h.shape[0])[:num_platt]
+                    train_h_sub = train_h[sub_idx]
+                    train_gt_sub = train_gt[sub_idx]
                     train_lm_preds_sub = (
-                        train_lm_preds[train_idx]
-                        if train_lm_preds is not None
-                        else None
+                        train_lm_preds[sub_idx] if train_lm_preds is not None else None
                     )
                     train_sub[ds_name] = (train_h_sub, train_gt_sub, train_lm_preds_sub)
                     train_subs[num_train].append(train_sub)
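The carve-out above reserves num_platt examples (default 100) per training dataset, flattened with einops to shape (n v k, d) with matching one-hot labels, so that Platt scaling is fit on data the reporter never trains on; previously the calibration reused the training hiddens themselves. For reference, a self-contained sketch of the standard Platt-scaling fit; this is the textbook technique, not necessarily the exact body of Reporter.platt_scale:

import torch

def platt_scale(raw_scores: torch.Tensor, labels: torch.Tensor):
    """Fit scale and bias so sigmoid(scale * s + bias) is calibrated."""
    scale = torch.ones(1, requires_grad=True)
    bias = torch.zeros(1, requires_grad=True)
    opt = torch.optim.LBFGS([scale, bias], max_iter=100)

    def closure():
        opt.zero_grad()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            scale * raw_scores + bias, labels.float()
        )
        loss.backward()
        return loss

    opt.step(closure)
    return scale.detach(), bias.detach()

# Calibrate held-out scores against flattened one-hot labels, as in the diff.
scores = torch.randn(400)             # stand-in for reporter(platt_set_h)
labels = torch.randint(0, 2, (400,))  # stand-in for platt_set_gt
scale, bias = platt_scale(scores, labels)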
@@ -108,10 +140,11 @@ def apply_to_layer(
                 (val_h, val_gt, _) = next(iter(val_dict.values()))

                 # TODO: Enable Platt scaling for CCS once normalization is fixed
+                # And add random reporter
                 # (_, v, k, _) = first_train_h.shape
                 # reporter.platt_scale(
-                #     to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
-                #     rearrange(first_train_h, "n v k d -> (n v k) d"),
+                #     to_one_hot(repeat(platt_gt, "n -> (n v)", v=v), k).flatten(),
+                #     rearrange(platt_h, "n v k d -> (n v k) d"),
                 # )

             elif isinstance(self.net, EigenReporterConfig):
@@ -135,9 +168,33 @@ def apply_to_layer(

                 reporter.fit_streaming()
                 reporter.platt_scale(
-                    torch.cat(label_list),
-                    torch.cat(hidden_list),
+                    platt_set_gt,
+                    platt_set_h,
                 )

+                rand_reporter = EigenReporter(
+                    self.net, d, num_classes=k, device=device
+                )
+                rand_reporter.weight = torch.randn_like(rand_reporter.weight)
+                rand_reporter.norm = reporter.norm
+                rand_reporter.platt_scale(
+                    platt_set_gt,
+                    platt_set_h,
+                )
+
+                train_rand_credences[num_train].append(
+                    {
+                        ds_name: rand_reporter(train_h)
+                        for ds_name, (train_h, _, _) in train_dict.items()
+                    }
+                )
+                val_rand_credences[num_train].append(
+                    {
+                        ds_name: rand_reporter(val_h)
+                        for ds_name, (val_h, _, _) in val_dict.items()
+                    }
+                )
+
             else:
                 raise ValueError(f"Unknown reporter config type: {type(self.net)}")
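The random reporter is a null baseline: it shares the trained reporter's normalization (rand_reporter.norm = reporter.norm) and is Platt-scaled on the same held-out set, but its probe direction is pure Gaussian noise. Whatever accuracy it reaches is what normalization plus calibration alone can buy, so the trained reporter is judged by its margin over it. A schematic sketch of the comparison; the names here are illustrative, not the EigenReporter internals:

import torch

d = 4096
hiddens = torch.randn(64, d)             # stand-in for normalized val hiddens
learned_w = torch.randn(d)               # stand-in for reporter.weight
random_w = torch.randn_like(learned_w)   # same init as rand_reporter.weight

learned_scores = hiddens @ learned_w     # credences from the trained probe
random_scores = hiddens @ random_w       # credences from the random probe
# Both score sets are then Platt-scaled on (platt_set_gt, platt_set_h) and
# passed to evaluate_preds, yielding the *_rand_eval rows added below.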

@@ -200,6 +257,20 @@ def apply_to_layer(
                             mode,
                         ).to_dict()
                     )
+                    num_train_buf["train_rand_eval"].append(
+                        evaluate_preds(
+                            train_gt,
+                            train_rand_credences[num_train][i][ds_name],
+                            mode,
+                        ).to_dict()
+                    )
+                    num_train_buf["rand_eval"].append(
+                        evaluate_preds(
+                            val_gt,
+                            val_rand_credences[num_train][i][ds_name],
+                            mode,
+                        ).to_dict()
+                    )

             nt_dfs = {k: pd.DataFrame(v) for k, v in num_train_buf.items()}
             # get mean, std, min, max, and 95% CI of each of
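The aggregation the trailing comment describes turns each buffer of per-sample metric dicts into a DataFrame and summarizes it across the num_samples runs. A sketch with toy numbers, assuming evaluate_preds(...).to_dict() yields flat metric dicts (the metric names and values here are invented for illustration):

import numpy as np
import pandas as pd

rows = [{"auroc": 0.91, "acc": 0.85}, {"auroc": 0.88, "acc": 0.83}]  # toy values
df = pd.DataFrame(rows)

summary = df.agg(["mean", "std", "min", "max"])
ci95 = 1.96 * df.std() / np.sqrt(len(df))  # 95% CI of the mean
print(summary)
print(ci95)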
