diff --git a/elk/logging.py b/elk/debug_log.py
similarity index 100%
rename from elk/logging.py
rename to elk/debug_log.py
diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index fe711891..21d58897 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -18,6 +18,7 @@ from elk.training.preprocessing import normalize
 
 from ..extraction import ExtractionConfig, extract
+from ..training import EigenReporter, CcsReporter, Reporter
 from ..files import elk_reporter_dir, memorably_named_dir
 from ..utils import (
     assert_type,
@@ -61,16 +62,22 @@ def evaluate_reporter(
         method=cfg.normalization,
     )
 
-    reporter_path = elk_reporter_dir() / cfg.source / "reporters" / f"layer_{layer}.pt"
-    reporter = torch.load(reporter_path, map_location=device)
+    # load yaml to dict
+    with open(elk_reporter_dir() / cfg.source / "cfg.yaml", "r") as f:
+        source_cfg = yaml.safe_load(f)
+    is_eigen = "neg_cov_weight" in source_cfg["net"]  # TODO: this is a hack
+
+    reporter_path = elk_reporter_dir() / cfg.source / "reporters" / f"layer_{layer}"
+    reporter = (
+        EigenReporter.load(reporter_path, device=device)
+        if is_eigen
+        else CcsReporter.load(reporter_path, device=device)
+    )
     reporter.eval()
 
     test_x0, test_x1 = test_h.unbind(dim=-2)
 
-    test_result = reporter.score(
-        (test_x0, test_x1),
-        test_labels,
-    )
+    test_result = reporter.score(test_labels, test_x0, test_x1)
 
     stats = [layer, *test_result]
     return stats
@@ -104,7 +111,8 @@ def evaluate_reporters(cfg: EvaluateConfig, out_dir: Optional[Path] = None):
     cols = ["layer", "loss", "acc", "cal_acc", "auroc"]
 
     # Evaluate reporters for each layer in parallel
-    with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
         fn = partial(
             evaluate_reporter, cfg, ds, devices=devices, world_size=num_devices
         )
diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index b68a9ed4..3ab6c291 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -289,6 +289,9 @@ def get_splits() -> SplitDict:
         )
         for (split_name, split_info) in splits.items()
     }
+    import multiprocess as mp
+
+    mp.set_start_method("spawn", force=True)  # type: ignore[attr-defined]
 
     ds = dict()
     for split, builder in builders.items():
diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py
index 54fc2292..2fc66349 100644
--- a/elk/training/ccs_reporter.py
+++ b/elk/training/ccs_reporter.py
@@ -118,6 +118,10 @@ def __init__(
                 )
             )
 
+    @classmethod
+    def _cfg_from_dict(cls, d: dict) -> CcsReporterConfig:
+        return CcsReporterConfig.from_dict(d)
+
     def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
         loss = sum(
             LOSSES[name](logit0, logit1, coef)
diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py
index 27d54ddc..e1a84eaf 100644
--- a/elk/training/eigen_reporter.py
+++ b/elk/training/eigen_reporter.py
@@ -151,6 +151,10 @@ def supervised_interclass_cov(self) -> Tensor:
         )  # 2 x d
         return cat_mat.T @ cat_mat / 2
 
+    @classmethod
+    def _cfg_from_dict(cls, d: dict) -> EigenReporterConfig:
+        return EigenReporterConfig.from_dict(d)
+
     def clear(self) -> None:
         """Clear the running statistics of the reporter."""
         self.contrastive_xcov_M2.zero_()
diff --git a/elk/training/reporter.py b/elk/training/reporter.py
index ea8a4406..fa2a1059 100644
--- a/elk/training/reporter.py
+++ b/elk/training/reporter.py
@@ -4,6 +4,7 @@
 from .classifier import Classifier
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+import json
 from pathlib import Path
 from simple_parsing.helpers import Serializable
 from sklearn.metrics import roc_auc_score
@@ -147,15 +148,43 @@ def update(self, x_pos: Tensor, x_neg: Tensor) -> None:
         self.neg_mean += (x_neg.sum(dim=0) - self.neg_mean) / self.n
         self.pos_mean += (x_pos.sum(dim=0) - self.pos_mean) / self.n
 
-    # TODO: These methods will do something fancier in the future
+    def save(self, path: Union[Path, str]):
+        """Save separate JSON and PT files for the reporter."""
+
+        # json state
+        json_path = Path(path).with_suffix(".json")
+        with open(json_path, "w") as f:
+            j_dict = {"cfg": self.config.to_dict(), "in_features": self.in_features}
+            f.write(json.dumps(j_dict))
+
+        # state dict
+        state_path = Path(path).with_suffix(".pt")
+        torch.save(self.state_dict(), state_path)
+
     @classmethod
-    def load(cls, path: Union[Path, str]):
-        """Load a reporter from a file."""
-        return torch.load(path)
+    def load(cls, path: Union[Path, str], device: Optional[str] = None):
+        """Load a reporter from a directory containing a JSON and a state dict."""
 
-    def save(self, path: Union[Path, str]):
-        # TODO: Save separate JSON and PT files for the reporter.
-        torch.save(self, path)
+        # json state
+        json_path = Path(path).with_suffix(".json")
+        with open(json_path, "r") as f:
+            j_dict = json.loads(f.read())
+        cfg = cls._cfg_from_dict(j_dict["cfg"])
+        in_features = j_dict["in_features"]
+
+        # state dict
+        state_path = Path(path).with_suffix(".pt")
+        state_dict = torch.load(state_path)
+
+        reporter = cls(in_features, cfg, device=device)
+        reporter.load_state_dict(state_dict)
+        return reporter
+
+    @classmethod
+    def _cfg_from_dict(cls, d: dict):
+        """Create a ReporterConfig from a dictionary.
+        This is a separate method so that subclasses can override it."""
+        return ReporterConfig.from_dict(d)
 
     @abstractmethod
     def fit(
diff --git a/elk/training/train.py b/elk/training/train.py
index 2b4b215b..9236500b 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -8,7 +8,7 @@
     select_usable_devices,
     int16_to_float32,
 )
-from ..logging import save_debug_log
+from ..debug_log import save_debug_log
 from .classifier import Classifier
 from .ccs_reporter import CcsReporter, CcsReporterConfig
 from .eigen_reporter import EigenReporter, EigenReporterConfig
@@ -168,8 +168,7 @@ def train_reporter(
         with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
             pickle.dump(lr_model, file)
 
-    with open(reporter_dir / f"layer_{layer}.pt", "wb") as file:
-        torch.save(reporter, file)
+    reporter.save(reporter_dir / f"layer_{layer}")
 
     return stats
 
@@ -217,7 +216,8 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None):
         if feat.startswith("hidden_")
     ]
     # Train reporters for each layer in parallel
-    with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
         fn = partial(
             train_reporter, cfg, ds, out_dir, devices=devices, world_size=num_devices
         )
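
For quick reference, a minimal usage sketch of the save/load round trip this patch introduces (not part of the patch itself). The import path, file stem, and feature dimension below are illustrative assumptions, and it presumes EigenReporterConfig can be constructed from its defaults and that fitting happens elsewhere.

    # Hypothetical usage of the new Reporter.save / Reporter.load API.
    from elk.training.eigen_reporter import EigenReporter, EigenReporterConfig

    # 4096 is a made-up hidden size; cls(in_features, cfg, device=...) matches
    # the constructor call used inside Reporter.load in this patch.
    reporter = EigenReporter(4096, EigenReporterConfig(), device="cuda:0")
    # ... reporter.update(...) / reporter.fit(...) on contrast pairs, omitted ...

    # save() writes two sibling files next to the given stem:
    #   reporters/layer_12.json  (config + in_features)
    #   reporters/layer_12.pt    (the state dict)
    reporter.save("reporters/layer_12")

    # load() reads both files back and rebuilds the reporter on the requested device.
    restored = EigenReporter.load("reporters/layer_12", device="cpu")
    restored.eval()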