From ccdc190cc09b63754e1e7bbcb6943fc501426b74 Mon Sep 17 00:00:00 2001 From: Alex Mallen Date: Sun, 9 Apr 2023 13:32:35 -0400 Subject: [PATCH] force spawn start method --- elk/evaluation/evaluate.py | 3 ++- elk/extraction/extraction.py | 3 +++ elk/training/train.py | 3 ++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index fe711891..a50290aa 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -104,7 +104,8 @@ def evaluate_reporters(cfg: EvaluateConfig, out_dir: Optional[Path] = None): cols = ["layer", "loss", "acc", "cal_acc", "auroc"] # Evaluate reporters for each layer in parallel - with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f: + ctx = mp.get_context("spawn") + with ctx.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f: fn = partial( evaluate_reporter, cfg, ds, devices=devices, world_size=num_devices ) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index b68a9ed4..3ab6c291 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -289,6 +289,9 @@ def get_splits() -> SplitDict: ) for (split_name, split_info) in splits.items() } + import multiprocess as mp + + mp.set_start_method("spawn", force=True) # type: ignore[attr-defined] ds = dict() for split, builder in builders.items(): diff --git a/elk/training/train.py b/elk/training/train.py index 2b4b215b..2d809880 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -217,7 +217,8 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None): if feat.startswith("hidden_") ] # Train reporters for each layer in parallel - with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f: + ctx = mp.get_context("spawn") + with ctx.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f: fn = partial( train_reporter, cfg, ds, out_dir, devices=devices, world_size=num_devices )