Skip to content

Commit

Permalink
force spawn start method
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed Apr 9, 2023
1 parent 158a754 commit 7d5cb1b
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 18 deletions.
File renamed without changes.
22 changes: 15 additions & 7 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from elk.training.preprocessing import normalize

from ..extraction import ExtractionConfig, extract
from ..training import EigenReporter, CcsReporter, Reporter
from ..files import elk_reporter_dir, memorably_named_dir
from ..utils import (
assert_type,
Expand Down Expand Up @@ -61,16 +62,22 @@ def evaluate_reporter(
method=cfg.normalization,
)

reporter_path = elk_reporter_dir() / cfg.source / "reporters" / f"layer_{layer}.pt"
reporter = torch.load(reporter_path, map_location=device)
# load yaml to dict
with open(elk_reporter_dir() / cfg.source / "cfg.yaml", "r") as f:
source_cfg = yaml.safe_load(f)
is_eigen = "neg_cov_weight" in source_cfg["net"] # TODO: this is a hack

reporter_path = elk_reporter_dir() / cfg.source / "reporters" / f"layer_{layer}"
reporter = (
EigenReporter.load(reporter_path, device=device)
if is_eigen
else CcsReporter.load(reporter_path, device=device)
)
reporter.eval()

test_x0, test_x1 = test_h.unbind(dim=-2)

test_result = reporter.score(
(test_x0, test_x1),
test_labels,
)
test_result = reporter.score(test_labels, test_x0, test_x1)

stats = [layer, *test_result]
return stats
Expand Down Expand Up @@ -104,7 +111,8 @@ def evaluate_reporters(cfg: EvaluateConfig, out_dir: Optional[Path] = None):

cols = ["layer", "loss", "acc", "cal_acc", "auroc"]
# Evaluate reporters for each layer in parallel
with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
ctx = mp.get_context("spawn")
with ctx.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
fn = partial(
evaluate_reporter, cfg, ds, devices=devices, world_size=num_devices
)
Expand Down
3 changes: 3 additions & 0 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ def get_splits() -> SplitDict:
)
for (split_name, split_info) in splits.items()
}
import multiprocess as mp

mp.set_start_method("spawn", force=True) # type: ignore[attr-defined]

ds = dict()
for split, builder in builders.items():
Expand Down
4 changes: 4 additions & 0 deletions elk/training/ccs_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def __init__(
)
)

@classmethod
def _cfg_from_dict(cls, d: dict) -> CcsReporterConfig:
return CcsReporterConfig.from_dict(d)

def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
loss = sum(
LOSSES[name](logit0, logit1, coef)
Expand Down
4 changes: 4 additions & 0 deletions elk/training/eigen_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ def supervised_interclass_cov(self) -> Tensor:
) # 2 x d
return cat_mat.T @ cat_mat / 2

@classmethod
def _cfg_from_dict(cls, d: dict) -> EigenReporterConfig:
return EigenReporterConfig.from_dict(d)

def clear(self) -> None:
"""Clear the running statistics of the reporter."""
self.contrastive_xcov_M2.zero_()
Expand Down
43 changes: 36 additions & 7 deletions elk/training/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .classifier import Classifier
from abc import ABC, abstractmethod
from dataclasses import dataclass
import json
from pathlib import Path
from simple_parsing.helpers import Serializable
from sklearn.metrics import roc_auc_score
Expand Down Expand Up @@ -147,15 +148,43 @@ def update(self, x_pos: Tensor, x_neg: Tensor) -> None:
self.neg_mean += (x_neg.sum(dim=0) - self.neg_mean) / self.n
self.pos_mean += (x_pos.sum(dim=0) - self.pos_mean) / self.n

# TODO: These methods will do something fancier in the future
def save(self, path: Union[Path, str]):
"""Save separate JSON and PT files for the reporter."""

# json state
json_path = Path(path).with_suffix(".json")
with open(json_path, "w") as f:
j_dict = {"cfg": self.config.to_dict(), "in_features": self.in_features}
f.write(json.dumps(j_dict))

# state dict
state_path = Path(path).with_suffix(".pt")
torch.save(self.state_dict(), state_path)

@classmethod
def load(cls, path: Union[Path, str]):
"""Load a reporter from a file."""
return torch.load(path)
def load(cls, path: Union[Path, str], device: Optional[str] = None):
"""Load a reporter from a directory containing a JSON and a state dict."""

def save(self, path: Union[Path, str]):
# TODO: Save separate JSON and PT files for the reporter.
torch.save(self, path)
# json state
json_path = Path(path).with_suffix(".json")
with open(json_path, "r") as f:
j_dict = json.loads(f.read())
cfg = cls._cfg_from_dict(j_dict["cfg"])
in_features = j_dict["in_features"]

# state dict
state_path = Path(path).with_suffix(".pt")
state_dict = torch.load(state_path)

reporter = cls(in_features, cfg, device=device)
reporter.load_state_dict(state_dict)
return reporter

@classmethod
def _cfg_from_dict(cls, d: dict):
"""Create a ReporterConfig from a dictionary.
This is a separate method so that subclasses can override it."""
return ReporterConfig.from_dict(d)

@abstractmethod
def fit(
Expand Down
8 changes: 4 additions & 4 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
select_usable_devices,
int16_to_float32,
)
from ..logging import save_debug_log
from ..debug_log import save_debug_log
from .classifier import Classifier
from .ccs_reporter import CcsReporter, CcsReporterConfig
from .eigen_reporter import EigenReporter, EigenReporterConfig
Expand Down Expand Up @@ -168,8 +168,7 @@ def train_reporter(
with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
pickle.dump(lr_model, file)

with open(reporter_dir / f"layer_{layer}.pt", "wb") as file:
torch.save(reporter, file)
reporter.save(reporter_dir / f"layer_{layer}")

return stats

Expand Down Expand Up @@ -217,7 +216,8 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None):
if feat.startswith("hidden_")
]
# Train reporters for each layer in parallel
with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
ctx = mp.get_context("spawn")
with ctx.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
fn = partial(
train_reporter, cfg, ds, out_dir, devices=devices, world_size=num_devices
)
Expand Down

0 comments on commit 7d5cb1b

Please sign in to comment.