Skip to content

Commit

Permalink
force spawn start method
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed Apr 9, 2023
1 parent 158a754 commit ccdc190
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
3 changes: 2 additions & 1 deletion elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def evaluate_reporters(cfg: EvaluateConfig, out_dir: Optional[Path] = None):

cols = ["layer", "loss", "acc", "cal_acc", "auroc"]
# Evaluate reporters for each layer in parallel
with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
ctx = mp.get_context("spawn")
with ctx.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
fn = partial(
evaluate_reporter, cfg, ds, devices=devices, world_size=num_devices
)
Expand Down
3 changes: 3 additions & 0 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ def get_splits() -> SplitDict:
)
for (split_name, split_info) in splits.items()
}
import multiprocess as mp

mp.set_start_method("spawn", force=True) # type: ignore[attr-defined]

ds = dict()
for split, builder in builders.items():
Expand Down
3 changes: 2 additions & 1 deletion elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None):
if feat.startswith("hidden_")
]
# Train reporters for each layer in parallel
with mp.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
ctx = mp.get_context("spawn")
with ctx.Pool(num_devices) as pool, open(out_dir / "eval.csv", "w") as f:
fn = partial(
train_reporter, cfg, ds, out_dir, devices=devices, world_size=num_devices
)
Expand Down

0 comments on commit ccdc190

Please sign in to comment.