From 81133178698e1f32b0fd96134dd26a3231918f07 Mon Sep 17 00:00:00 2001
From: jon
Date: Fri, 18 Aug 2023 20:23:25 +0000
Subject: [PATCH] train and val

---
 elk/training/platt_scaling.py |  8 ++---
 elk/training/train.py         | 56 +++++++++++++++++++++--------------
 2 files changed, 36 insertions(+), 28 deletions(-)

diff --git a/elk/training/platt_scaling.py b/elk/training/platt_scaling.py
index 387011f3..395a1987 100644
--- a/elk/training/platt_scaling.py
+++ b/elk/training/platt_scaling.py
@@ -3,12 +3,11 @@
 import torch
 from einops import rearrange, repeat
+from rich import print
 from torch import Tensor, nn, optim
 
 from elk.metrics import to_one_hot
 
-from rich import print
-
 
 class PlattMixin(ABC):
     """Mixin for classifier-like objects that can be Platt scaled."""
@@ -58,9 +57,8 @@ def closure():
             return float(loss)
 
         opt.step(closure)
-
-        from elk.utils.write_print_all import write_print_all
+
+        print("platt losses", losses)
         print("scale", self.scale.item())
         print("bias", self.bias.item())
-
\ No newline at end of file
diff --git a/elk/training/train.py b/elk/training/train.py
index fb0432a4..ebd04f30 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -7,6 +7,7 @@
 import pandas as pd
 import torch
 from einops import rearrange, repeat
+from rich import print
 from simple_parsing import subgroups
 
 from ..evaluation import Eval
@@ -21,8 +22,6 @@
 from .eigen_reporter import EigenFitter, EigenFitterConfig
 from .multi_reporter import MultiReporter, ReporterWithInfo, SingleReporter
 
-from rich import print
-
 
 def evaluate_and_save(
     train_loss: float | None,
@@ -148,9 +147,10 @@ def make_eval(self, model, eval_dataset):
 
     # Create a separate function to handle the reporter training.
     def train_and_save_reporter(
-        self, device, layer, out_dir, train_dict, prompt_index=None
+        self, device, layer, out_dir, train_dict, val_dict, prompt_index=None
     ) -> ReporterWithInfo:
         (first_train_h, train_gt, _), *rest = train_dict.values()  # TODO can remove?
+        (first_val_h, val_gt, _), *_ = val_dict.values()
         (_, v, k, d) = first_train_h.shape
         if not all(other_h.shape[-1] == d for other_h, _, _ in rest):
             raise ValueError("All datasets must have the same hidden state size")
@@ -165,31 +165,41 @@ def train_and_save_reporter(
         train_loss = None
         if isinstance(self.net, CcsConfig):
             assert len(train_dict) == 1, "CCS only supports single-task training"
+
             def train_rep(net):
                 print(f"train and eval ({net.platt_burns})")
                 reporter = CcsReporter(net, d, device=device, num_variants=v)
                 train_loss = reporter.fit(first_train_h)
                 reporter.platt_scale(train_gt, first_train_h)
-                cred = reporter(first_train_h)
-                stats = evaluate_preds(train_gt, cred, PromptEnsembling.FULL).to_dict()
-                stats['train_loss'] = train_loss
-                stats['scale'] = reporter.scale.item()
-                stats['bias'] = reporter.bias.item()
-                print(stats)
-                return reporter, stats
-            jdi_reporter, jdi_stats = train_rep(self.net)
+
+                def eval_stats(gt, h):
+                    cred = reporter(h)
+                    stats = evaluate_preds(gt, cred, PromptEnsembling.FULL).to_dict()
+                    stats["train_loss"] = train_loss
+                    stats["scale"] = reporter.scale.item()
+                    stats["bias"] = reporter.bias.item()
+                    return stats
+                return reporter, eval_stats(train_gt, first_train_h), eval_stats(val_gt, first_val_h)
+
             cpy_net = replace(self.net, platt_burns="hack")
-            hack_reporter, hack_stats = train_rep(cpy_net)
-            vanilla_net = replace(self.net, platt_burns="")
-            vanilla_reporter, vanilla_stats = train_rep(vanilla_net)
-
-            df_b = pd.DataFrame([jdi_stats], index=['jdi_stats'])
-            df_c = pd.DataFrame([hack_stats], index=['hack_stats'])
-            df_a = pd.DataFrame([vanilla_stats], index=['vanilla_stats'])
-            df = pd.concat([df_a, df_b, df_c])
-            print(df)
-            df.to_csv(out_dir / f"stats_layer_{layer}.csv")
-            reporter = jdi_reporter
+            vanilla_net = replace(self.net, platt_burns="vanilla")
+
+            trains = []
+            vals = []
+            reporters = []
+            for net in [self.net, cpy_net, vanilla_net]:
+                reporter, train_stats, val_stats = train_rep(net)
+                train_df = pd.DataFrame([train_stats], index=[net.platt_burns])
+                val_df = pd.DataFrame([val_stats], index=[net.platt_burns])
+                trains.append(train_df)
+                vals.append(val_df)
+                reporters.append(reporter)
+
+            dfs = [pd.concat(part_dfs).T for part_dfs in [trains, vals]]
+            out_dir.mkdir(parents=True, exist_ok=True)
+            dfs[0].to_csv(out_dir / f"stats_train_layer_{layer}.csv")
+            dfs[1].to_csv(out_dir / f"stats_val_layer_{layer}.csv")
+            reporter = reporters[0]
         elif isinstance(self.net, EigenFitterConfig):
             fitter = EigenFitter(
                 self.net, d, num_classes=k, num_variants=v, device=device
@@ -295,7 +305,7 @@ def apply_to_layer(
             train_loss = maybe_multi_reporter.train_loss
         else:
             reporter_train_result = self.train_and_save_reporter(
-                device, layer, self.out_dir / "reporters", train_dict
+                device, layer, self.out_dir / "reporters", train_dict, val_dict
             )
             maybe_multi_reporter = reporter_train_result.model