Skip to content

Commit

Permalink
lm preds
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed Sep 21, 2023
1 parent 83c7642 commit 98a8dea
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 11 deletions.
22 changes: 18 additions & 4 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,35 @@ def apply_to_layer(
lr=dict(),
)
for mode in ("none", "full"):
- # TODO save lm logprobs and add to buf
+ if val_data.lm_log_odds is not None:
+ if self.save_logprobs:
+ out_logprobs[ds_name]["lm"][mode] = get_logprobs(
+ val_data.lm_log_odds, mode
+ ).cpu()
+ row_bufs["lm_eval"].append(
+ {
+ "ensembling": mode,
+ **meta,
+ **evaluate_preds(
+ val_data.labels, val_data.lm_log_odds, mode
+ ).to_dict(),
+ }
+ )
+
for i, model in enumerate(lr_models):
model.eval()
- val_credences = model(val_data.hiddens)
+ val_log_odds = model(val_data.hiddens)
if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode][i] = get_logprobs(
- val_credences, mode
+ val_log_odds, mode
).cpu()
row_bufs["lr_eval"].append(
{
"ensembling": mode,
"inlp_iter": i,
**meta,
**evaluate_preds(
- val_data.labels, val_credences, mode
+ val_data.labels, val_log_odds, mode
).to_dict(),
}
)
Expand Down
10 changes: 8 additions & 2 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
class LayerData:
hiddens: Tensor
labels: Tensor
- lm_preds: Tensor | None
+ lm_log_odds: Tensor | None
texts: list[list[str]] # (n, v)
row_ids: list[int] # (n,)
variant_ids: list[list[str]] # (n, v)
Expand Down Expand Up @@ -156,10 +156,16 @@ def prepare_data(
if self.prompt_indices:
hiddens = hiddens[:, self.prompt_indices]

+ if "lm_log_odds" in split.column_names:
+ with split.formatted_as("torch", device=device):
+ lm_preds = assert_type(Tensor, split["lm_log_odds"])
+ else:
+ lm_preds = None
+
out[ds_name] = LayerData(
hiddens=hiddens,
labels=labels,
- lm_preds=None, # TODO: implement
+ lm_log_odds=lm_preds,
texts=split["texts"],
row_ids=split["row_id"],
variant_ids=split["variant_ids"],
Expand Down
26 changes: 21 additions & 5 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,38 @@ def apply_to_layer(
)

for mode in ("none", "full"):
+ if val.lm_log_odds is not None:
+ if self.save_logprobs:
+ out_logprobs[ds_name]["lm"][mode] = (
+ get_logprobs(val.lm_log_odds, mode).detach().cpu()
+ )
+
+ row_bufs["lm_eval"].append(
+ {
+ **meta,
+ "ensembling": mode,
+ **evaluate_preds(
+ val.labels, val.lm_log_odds, mode
+ ).to_dict(),
+ }
+ )
+
for i, model in enumerate(lr_models):
model.eval()
- val_credences = model(val.hiddens)
- train_credences = model(train.hiddens)
+ val_log_odds = model(val.hiddens)
+ train_log_odds = model(train.hiddens)

if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode][i] = (
- get_logprobs(val_credences, mode).detach().cpu()
+ get_logprobs(val_log_odds, mode).detach().cpu()
)

row_bufs["lr_eval"].append(
{
**meta,
"ensembling": mode,
"inlp_iter": i,
- **evaluate_preds(val.labels, val_credences, mode).to_dict(),
+ **evaluate_preds(val.labels, val_log_odds, mode).to_dict(),
}
)

Expand All @@ -114,7 +130,7 @@ def apply_to_layer(
"ensembling": mode,
"inlp_iter": i,
**evaluate_preds(
- train.labels, train_credences, mode
+ train.labels, train_log_odds, mode
).to_dict(),
}
)
Expand Down

0 comments on commit 98a8dea

Please sign in to comment.