Skip to content

Commit

Permalink
train and val
Browse files Browse the repository at this point in the history
  • Loading branch information
derpyplops committed Aug 18, 2023
1 parent ecb5e2b commit 8113317
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 28 deletions.
8 changes: 3 additions & 5 deletions elk/training/platt_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

import torch
from einops import rearrange, repeat
from rich import print
from torch import Tensor, nn, optim

from elk.metrics import to_one_hot

from rich import print


class PlattMixin(ABC):
"""Mixin for classifier-like objects that can be Platt scaled."""
Expand Down Expand Up @@ -58,9 +57,8 @@ def closure():
return float(loss)

opt.step(closure)
from elk.utils.write_print_all import write_print_all


print("platt losses", losses)
print("scale", self.scale.item())
print("bias", self.bias.item())

56 changes: 33 additions & 23 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
import torch
from einops import rearrange, repeat
from rich import print
from simple_parsing import subgroups

from ..evaluation import Eval
Expand All @@ -21,8 +22,6 @@
from .eigen_reporter import EigenFitter, EigenFitterConfig
from .multi_reporter import MultiReporter, ReporterWithInfo, SingleReporter

from rich import print


def evaluate_and_save(
train_loss: float | None,
Expand Down Expand Up @@ -148,9 +147,10 @@ def make_eval(self, model, eval_dataset):

# Create a separate function to handle the reporter training.
def train_and_save_reporter(
self, device, layer, out_dir, train_dict, prompt_index=None
self, device, layer, out_dir, train_dict, val_dict, prompt_index=None
) -> ReporterWithInfo:
(first_train_h, train_gt, _), *rest = train_dict.values() # TODO can remove?
(first_val_h, val_gt, _), *_ = val_dict.values()
(_, v, k, d) = first_train_h.shape
if not all(other_h.shape[-1] == d for other_h, _, _ in rest):
raise ValueError("All datasets must have the same hidden state size")
Expand All @@ -165,31 +165,41 @@ def train_and_save_reporter(
train_loss = None
if isinstance(self.net, CcsConfig):
assert len(train_dict) == 1, "CCS only supports single-task training"

def train_rep(net):
print(f"train and eval ({net.platt_burns})")
reporter = CcsReporter(net, d, device=device, num_variants=v)
train_loss = reporter.fit(first_train_h)
reporter.platt_scale(train_gt, first_train_h)
cred = reporter(first_train_h)
stats = evaluate_preds(train_gt, cred, PromptEnsembling.FULL).to_dict()
stats['train_loss'] = train_loss
stats['scale'] = reporter.scale.item()
stats['bias'] = reporter.bias.item()
print(stats)
return reporter, stats
jdi_reporter, jdi_stats = train_rep(self.net)

def eval_stats(gt, h):
cred = reporter(h)
stats = evaluate_preds(gt, cred, PromptEnsembling.FULL).to_dict()
stats["train_loss"] = train_loss
stats["scale"] = reporter.scale.item()
stats["bias"] = reporter.bias.item()
return stats
return reporter, eval_stats(train_gt, first_train_h), eval_stats(val_gt, first_val_h)

cpy_net = replace(self.net, platt_burns="hack")
hack_reporter, hack_stats = train_rep(cpy_net)
vanilla_net = replace(self.net, platt_burns="")
vanilla_reporter, vanilla_stats = train_rep(vanilla_net)

df_b = pd.DataFrame([jdi_stats], index=['jdi_stats'])
df_c = pd.DataFrame([hack_stats], index=['hack_stats'])
df_a = pd.DataFrame([vanilla_stats], index=['vanilla_stats'])
df = pd.concat([df_a, df_b, df_c])
print(df)
df.to_csv(out_dir / f"stats_layer_{layer}.csv")
reporter = jdi_reporter
vanilla_net = replace(self.net, platt_burns="vanilla")

trains = []
vals = []
reporters = []
for net in [self.net, cpy_net, vanilla_net]:
reporter, train_stats, val_stats = train_rep(net)
train_df = pd.DataFrame([train_stats], index=[net.platt_burns])
val_df = pd.DataFrame([val_stats], index=[net.platt_burns])
trains.append(train_df)
vals.append(val_df)
reporters.append(reporter)

dfs = [pd.concat(part_dfs).T for part_dfs in [trains, vals]]
out_dir.mkdir(parents=True, exist_ok=True)
dfs[0].to_csv(out_dir / f"stats_train_layer_{layer}.csv")
dfs[1].to_csv(out_dir / f"stats_val_layer_{layer}.csv")
reporter = reporters[0]
elif isinstance(self.net, EigenFitterConfig):
fitter = EigenFitter(
self.net, d, num_classes=k, num_variants=v, device=device
Expand Down Expand Up @@ -295,7 +305,7 @@ def apply_to_layer(
train_loss = maybe_multi_reporter.train_loss
else:
reporter_train_result = self.train_and_save_reporter(
device, layer, self.out_dir / "reporters", train_dict
device, layer, self.out_dir / "reporters", train_dict, val_dict
)

maybe_multi_reporter = reporter_train_result.model
Expand Down

0 comments on commit 8113317

Please sign in to comment.