Skip to content

Commit

Permalink
rename probs to logprobs
Browse files — browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed Aug 31, 2023
1 parent 080174b commit e6def70
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 46 deletions.
30 changes: 15 additions & 15 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,25 @@ def apply_to_layer(
reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter = torch.load(reporter_path, map_location=device)

out_probs = defaultdict(dict)
out_logprobs = defaultdict(dict)
row_bufs = defaultdict(list)
for ds_name, val_data in val_output.items():
meta = {"dataset": ds_name, "layer": layer}

if self.save_probs:
out_probs[ds_name]["texts"] = val_data.text_questions
out_probs[ds_name]["labels"] = val_data.labels.cpu()
out_probs[ds_name]["reporter"] = dict()
out_probs[ds_name]["lr"] = dict()
out_probs[ds_name]["lm"] = dict()
if self.save_logprobs:
out_logprobs[ds_name]["texts"] = val_data.text_questions
out_logprobs[ds_name]["labels"] = val_data.labels.cpu()
out_logprobs[ds_name]["reporter"] = dict()
out_logprobs[ds_name]["lr"] = dict()
out_logprobs[ds_name]["lm"] = dict()

val_credences = reporter(val_data.hiddens)
for mode in ("none", "partial", "full"):
if self.save_probs:
out_probs[ds_name]["reporter"][mode] = get_logprobs(
if self.save_logprobs:
out_logprobs[ds_name]["reporter"][mode] = get_logprobs(
val_credences, mode
).cpu()
out_probs[ds_name]["lm"][mode] = (
out_logprobs[ds_name]["lm"][mode] = (
get_logprobs(val_data.lm_preds, mode).cpu()
if val_data.lm_preds is not None
else None
Expand Down Expand Up @@ -88,8 +88,8 @@ def apply_to_layer(

lr_dir = experiment_dir / "lr_models"
if not self.skip_supervised and lr_dir.exists():
if self.save_probs:
out_probs[ds_name]["lr"][mode] = dict()
if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode] = dict()

with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_models = torch.load(f, map_location=device)
Expand All @@ -99,8 +99,8 @@ def apply_to_layer(
for i, model in enumerate(lr_models):
model.eval()
val_lr_credences = model(val_data.hiddens)
if self.save_probs:
out_probs[ds_name]["lr"][mode][i] = get_logprobs(
if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode][i] = get_logprobs(
val_lr_credences, mode
).cpu()
row_bufs["lr_eval"].append(
Expand All @@ -114,4 +114,4 @@ def apply_to_layer(
}
)

return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_probs
return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_logprobs
36 changes: 20 additions & 16 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ class Run(ABC, Serializable):
num_gpus: int = -1
out_dir: Path | None = None
disable_cache: bool = field(default=False, to_dict=False)
save_probs: bool = field(default=False, to_dict=False)
""" saves probs.pt containing {<dsname>: {"texts": [n, v, 2], "labels": [n,]
save_logprobs: bool = field(default=False, to_dict=False)
""" saves logprobs.pt containing {<dsname>: {"texts": [n, v, 2], "labels": [n,]
"lm": {"none": [n, v, 2], "partial": [n, v], "full": [n,]},
"reporter": {<layer>: {"none": [n, v, 2], "partial": [n, v], "full": [n,]}},
"lr": {<layer>: {<inlp_iter>: {"none": ..., "partial": ..., "full": ...}}}
Expand Down Expand Up @@ -196,34 +196,38 @@ def apply_to_layers(
with ctx.Pool(num_devices) as pool:
mapper = pool.imap_unordered if num_devices > 1 else map
df_buffers = defaultdict(list)
probs_dicts = defaultdict(dict)
logprobs_dicts = defaultdict(dict)

try:
for layer, (df_dict, probs_dict) in tqdm(
for layer, (df_dict, logprobs_dict) in tqdm(
zip(layers, mapper(func, layers)), total=len(layers)
):
for k, v in df_dict.items():
df_buffers[k].append(v)
for k, v in probs_dict.items():
probs_dicts[k][layer] = probs_dict[k]
for k, v in logprobs_dict.items():
logprobs_dicts[k][layer] = logprobs_dict[k]
finally:
# Make sure the CSVs are written even if we crash or get interrupted
for name, dfs in df_buffers.items():
df = pd.concat(dfs).sort_values(by=["layer", "ensembling"])
df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False)
if self.debug:
save_debug_log(self.datasets, self.out_dir)
if self.save_probs:
if self.save_logprobs:
save_dict = defaultdict(dict)
for ds_name, probs_dict in probs_dicts.items():
save_dict[ds_name]["texts"] = probs_dict[layers[0]]["texts"]
save_dict[ds_name]["labels"] = probs_dict[layers[0]]["labels"]
save_dict[ds_name]["lm"] = probs_dict[layers[0]]["lm"]
for ds_name, logprobs_dict in logprobs_dicts.items():
save_dict[ds_name]["texts"] = logprobs_dict[layers[0]]["texts"]
save_dict[ds_name]["labels"] = logprobs_dict[layers[0]][
"labels"
]
save_dict[ds_name]["lm"] = logprobs_dict[layers[0]]["lm"]
save_dict[ds_name]["reporter"] = dict()
save_dict[ds_name]["lr"] = dict()
for layer, probs_dict_by_mode in probs_dict.items():
save_dict[ds_name]["reporter"][layer] = probs_dict_by_mode[
"reporter"
for layer, logprobs_dict_by_mode in logprobs_dict.items():
save_dict[ds_name]["reporter"][
layer
] = logprobs_dict_by_mode["reporter"]
save_dict[ds_name]["lr"][layer] = logprobs_dict_by_mode[
"lr"
]
save_dict[ds_name]["lr"][layer] = probs_dict_by_mode["lr"]
torch.save(save_dict, self.out_dir / "probs.pt")
torch.save(save_dict, self.out_dir / "logprobs.pt")
30 changes: 15 additions & 15 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,26 +137,26 @@ def apply_to_layer(
lr_models = []

row_bufs = defaultdict(list)
out_probs = defaultdict(dict)
out_logprobs = defaultdict(dict)
for ds_name in val_dict:
val, train = val_dict[ds_name], train_dict[ds_name]
meta = {"dataset": ds_name, "layer": layer}

if self.save_probs:
out_probs[ds_name]["texts"] = val.text_questions
out_probs[ds_name]["labels"] = val.labels.cpu()
out_probs[ds_name]["reporter"] = dict()
out_probs[ds_name]["lr"] = dict()
out_probs[ds_name]["lm"] = dict()
if self.save_logprobs:
out_logprobs[ds_name]["texts"] = val.text_questions
out_logprobs[ds_name]["labels"] = val.labels.cpu()
out_logprobs[ds_name]["reporter"] = dict()
out_logprobs[ds_name]["lr"] = dict()
out_logprobs[ds_name]["lm"] = dict()

val_credences = reporter(val.hiddens)
train_credences = reporter(train.hiddens)
for mode in ("none", "partial", "full"):
if self.save_probs:
out_probs[ds_name]["reporter"][mode] = (
if self.save_logprobs:
out_logprobs[ds_name]["reporter"][mode] = (
get_logprobs(val_credences, mode).detach().cpu()
)
out_probs[ds_name]["lm"][mode] = (
out_logprobs[ds_name]["lm"][mode] = (
get_logprobs(val.lm_preds, mode).detach().cpu()
if val.lm_preds is not None
else None
Expand Down Expand Up @@ -201,15 +201,15 @@ def apply_to_layer(
)

if self.supervised != "none":
if self.save_probs:
out_probs[ds_name]["lr"][mode] = dict()
if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode] = dict()

for i, model in enumerate(lr_models):
model.eval()
val_lr_credences = model(val.hiddens)
train_lr_credences = model(train.hiddens)
if self.save_probs:
out_probs[ds_name]["lr"][mode][i] = (
if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode][i] = (
get_logprobs(val_lr_credences, mode).detach().cpu()
)

Expand All @@ -234,4 +234,4 @@ def apply_to_layer(
}
)

return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_probs
return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_logprobs

0 comments on commit e6def70

Please sign in to comment.