diff --git a/elk/__main__.py b/elk/__main__.py
index 2df1c2f3..2bb7c090 100644
--- a/elk/__main__.py
+++ b/elk/__main__.py
@@ -24,7 +24,7 @@ def run():
         required=True,
     )
     extract_parser.add_argument(
-        "--max_gpus",
+        "--num_gpus",
        type=int,
        help="Maximum number of GPUs to use.",
        required=False,
@@ -55,7 +55,7 @@ def run():
     args = parser.parse_args()
 
     if args.command == "extract":
-        extract(args.extraction, args.max_gpus).save_to_disk(args.output)
+        extract(args.extraction, args.num_gpus).save_to_disk(args.output)
     elif args.command == "elicit":
         train(args.run, args.output)
     elif args.command == "eval":
diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index 9d06efaf..8633a8b3 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -32,7 +32,7 @@ class EvaluateConfig(Serializable):
     target: ExtractionConfig
     source: str = field(positional=True)
     normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly"
-    max_gpus: int = -1
+    num_gpus: int = -1
 
 
 def evaluate_reporter(
@@ -74,7 +74,7 @@ def evaluate_reporter(
 
 
 def evaluate_reporters(cfg: EvaluateConfig, out_dir: Optional[Path] = None):
-    ds = extract(cfg.target, max_gpus=cfg.max_gpus)
+    ds = extract(cfg.target, num_gpus=cfg.num_gpus)
 
     layers = [
         int(feat[len("hidden_") :])
@@ -82,7 +82,7 @@ def evaluate_reporters(cfg: EvaluateConfig, out_dir: Optional[Path] = None):
         if feat.startswith("hidden_")
     ]
 
-    devices = select_usable_devices(cfg.max_gpus)
+    devices = select_usable_devices(cfg.num_gpus)
     num_devices = len(devices)
 
     transfer_eval = elk_reporter_dir() / cfg.source / "transfer_eval"
diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index b68a9ed4..57e7b5bb 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -209,7 +209,7 @@ def _extraction_worker(**kwargs):
     yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})
 
 
-def extract(cfg: ExtractionConfig, max_gpus: int = -1) -> DatasetDict:
+def extract(cfg: ExtractionConfig, num_gpus: int = -1) -> DatasetDict:
     """Extract hidden states from a model and return a `DatasetDict` containing them."""
 
     def get_splits() -> SplitDict:
@@ -271,7 +271,7 @@ def get_splits() -> SplitDict:
             length=num_variants,
         ),
     }
-    devices = select_usable_devices(max_gpus)
+    devices = select_usable_devices(num_gpus)
     builders = {
         split_name: _GeneratorBuilder(
             cache_dir=None,
diff --git a/elk/training/train.py b/elk/training/train.py
index ce47d9a3..747af4b8 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -42,7 +42,7 @@ class RunConfig(Serializable):
         data: Config specifying hidden states on which the reporter will be trained.
         net: Config for building the reporter network.
         optim: Config for the `.fit()` loop.
-        max_gpus: The maximum number of GPUs to use. Defaults to -1, which means
+        num_gpus: The maximum number of GPUs to use. Defaults to -1, which means
             "use all available GPUs".
         normalization: The normalization method to use. Defaults to "meanonly". See
             `elk.training.preprocessing.normalize()` for details.
@@ -58,7 +58,7 @@ class RunConfig(Serializable):
     )
     optim: OptimConfig = field(default_factory=OptimConfig)
 
-    max_gpus: int = -1
+    num_gpus: int = -1
     normalization: Literal["legacy", "none", "elementwise", "meanonly"] = "meanonly"
     skip_baseline: bool = False
     debug: bool = False
@@ -170,7 +170,7 @@ def train_reporter(
 
 def train(cfg: RunConfig, out_dir: Optional[Path] = None):
     # Extract the hidden states first if necessary
-    ds = extract(cfg.data, max_gpus=cfg.max_gpus)
+    ds = extract(cfg.data, num_gpus=cfg.num_gpus)
 
     if out_dir is None:
         out_dir = memorably_named_dir(elk_reporter_dir())
@@ -189,7 +189,7 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None):
     with open(out_dir / "metadata.yaml", "w") as meta_f:
         yaml.dump(meta, meta_f)
 
-    devices = select_usable_devices(cfg.max_gpus)
+    devices = select_usable_devices(cfg.num_gpus)
     num_devices = len(devices)
 
     cols = [
diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py
index 96e4390c..5f9d576c 100644
--- a/elk/utils/gpu_utils.py
+++ b/elk/utils/gpu_utils.py
@@ -1,12 +1,16 @@
 """Utilities that use PyNVML to get GPU usage info, and select GPUs accordingly."""
+from .typing import assert_type
 
 import os
 import pynvml
 import torch
 import warnings
+import time
 
 
-def select_usable_devices(max_gpus: int = -1, *, min_memory: int = 0) -> list[str]:
+def select_usable_devices(
+    num_gpus: int = -1, *, min_memory: int = -1, max_wait_time: int = 2 * 60 * 60
+) -> list[str]:
     """Select a set of devices that have at least `min_memory` bytes of free memory.
 
     When there are more than enough GPUs to satisfy the request, the GPUs with the
@@ -20,50 +24,54 @@ def select_usable_devices(max_gpus: int = -1, *, min_memory: int = 0) -> list[st
     only recently (commit `dc4f2af` on 9 Feb. 2023) implemented in PyTorch `master`.
     We can't depend on PyTorch nightly and we also don't want to copy-paste the code
     here.
 
-    For now, we simply return `list(range(max_gpus))` whenever `CUDA_VISIBLE_DEVICES`
+    For now, we simply return `list(range(num_gpus))` whenever `CUDA_VISIBLE_DEVICES`
     is set. Arguably this is expected behavior. If the user set `CUDA_VISIBLE_DEVICES`,
     they probably want to use all & only those GPUs.
 
     Args:
-        num_gpus: Maximum number of GPUs to select. If negative, all available GPUs
+        num_gpus: Number of GPUs to select. If negative, all available GPUs
             meeting the criteria will be selected.
         min_memory: Minimum amount of free memory (in bytes) required to select a GPU.
+            If negative, `min_memory` is set to 90% of the per-GPU memory.
+        max_wait_time: Maximum time (in seconds) to wait for the requested number of
+            GPUs to become available. Defaults to 2 hours.
 
     Returns:
-        A list of suitable PyTorch device strings, in ascending numerical order.
+        A list of suitable PyTorch device strings, in ascending numerical order, with
+        exactly `num_gpus` elements.
 
     Raises:
-        ValueError: If `max_gpus` is greater than the number of visible GPUs.
+        ValueError: If `num_gpus` is greater than the number of visible GPUs.
     """
     # Trivial case: no GPUs requested or available
     num_visible = torch.cuda.device_count()
-    if max_gpus == 0 or num_visible == 0:
+    if num_gpus == 0 or num_visible == 0:
         return ["cpu"]
 
     # Sanity checks
-    if max_gpus > num_visible:
+    if num_gpus > num_visible:
         raise ValueError(
-            f"Requested {max_gpus} GPUs, but only {num_visible} are visible."
+            f"Requested {num_gpus} GPUs, but only {num_visible} are visible."
         )
-    elif max_gpus < 0:
-        max_gpus = num_visible  # No limits, so try to use all installed GPUs
+    elif num_gpus < 0:
+        num_gpus = num_visible  # No limits, so try to use all installed GPUs
 
-    if max_gpus == num_visible and min_memory <= 0:
+    if num_gpus == num_visible and min_memory == 0:
         print(f"Using all {num_visible} GPUs.")
-        return [f"cuda:{i}" for i in range(max_gpus)]
+        return [f"cuda:{i}" for i in range(num_gpus)]
 
     # The user set CUDA_VISIBLE_DEVICES and also requested a specific number of GPUs.
     # The environment variable takes precedence, so we'll just use all visible GPUs.
-    count_msg = "all" if max_gpus == num_visible else f"first {max_gpus}"
+    count_msg = "all" if num_gpus == num_visible else f"first {num_gpus}"
     if "CUDA_VISIBLE_DEVICES" in os.environ:
         warnings.warn(
             f"Smart GPU selection not supported when CUDA_VISIBLE_DEVICES is set. "
             f"Will use {count_msg} visible devices."
         )
-        return [f"cuda:{i}" for i in range(max_gpus)]
+        return [f"cuda:{i}" for i in range(num_gpus)]
 
-    # pynvml.nvmlInit() will raise if we're using non-NVIDIA GPUs
+    # Initialize PyNVML
     try:
         pynvml.nvmlInit()
     except pynvml.NVMLError:
@@ -71,46 +79,83 @@ def select_usable_devices(max_gpus: int = -1, *, min_memory: int = 0) -> list[st
             f"Unable to initialize PyNVML; are you using non-NVIDIA GPUs? Will use "
             f"{count_msg} visible devices."
         )
-        return [f"cuda:{i}" for i in range(max_gpus)]
+        return [f"cuda:{i}" for i in range(num_gpus)]
 
-    try:
-        # PyNVML and PyTorch device indices should agree when CUDA_VISIBLE_DEVICES is
-        # not set. We need them to agree so that the PyNVML indices match the PyTorch
-        # indices, and we don't have to do any complex error-prone conversions.
-        num_installed = pynvml.nvmlDeviceGetCount()
-        assert num_installed == num_visible, "PyNVML and PyTorch disagree on GPU count"
-
-        # List of (-free memory, GPU index) tuples. Sorted descending by free memory,
-        # then ascending by GPU index.
-        memories_and_indices = sorted(
+    # PyNVML and PyTorch device indices should agree when CUDA_VISIBLE_DEVICES is
+    # not set. We need them to agree so that the PyNVML indices match the PyTorch
+    # indices, and we don't have to do any complex error-prone conversions.
+    num_installed = pynvml.nvmlDeviceGetCount()
+    assert num_installed == num_visible, "PyNVML and PyTorch disagree on GPU count"
+
+    # Set default value for `min_memory`
+    if min_memory < 0:
+        min_device_ram = min(
             (
-                -int(pynvml.nvmlDeviceGetMemoryInfo(handle).free),
-                pynvml.nvmlDeviceGetIndex(handle),
+                assert_type(
+                    int,
+                    pynvml.nvmlDeviceGetMemoryInfo(
+                        pynvml.nvmlDeviceGetHandleByIndex(idx)
+                    ).total,
+                )
+                for idx in range(num_installed)
             )
-            for handle in map(pynvml.nvmlDeviceGetHandleByIndex, range(num_installed))
         )
-        usable_indices = [
-            index for neg_mem, index in memories_and_indices if -neg_mem >= min_memory
-        ]
-    finally:
-        # Make sure we always shut down PyNVML
+        min_memory = int(0.9 * min_device_ram)
+
+    # Get free memory for each GPU
+    num_tries = 1
+    start_time = time.time()
+    while (time.time() - start_time) < max_wait_time:
+        # check if at least `num_gpus` GPUs have at least `min_memory`
+        # bytes of free memory
+
+        try:
+            # List of (-free memory, GPU index) tuples. Sorted descending by
+            # free memory, then ascending by GPU index.
+            memories_and_indices = sorted(
+                (
+                    -int(pynvml.nvmlDeviceGetMemoryInfo(handle).free),
+                    pynvml.nvmlDeviceGetIndex(handle),
+                )
+                for handle in map(
+                    pynvml.nvmlDeviceGetHandleByIndex, range(num_installed)
+                )
+            )
+            usable_indices = [
+                index
+                for neg_mem, index in memories_and_indices
+                if -neg_mem >= min_memory
+            ]
+            if len(usable_indices) >= num_gpus:
+                break
+            elif num_tries % 60 == 0:  # Print every 10 minutes
+                print(
+                    f"Waiting for {num_gpus} GPUs with "
+                    f"at least {min_memory / 10 ** 9:.2f} GB "
+                    f"of free memory. {len(usable_indices)} GPUs currently available."
+                )
+        except Exception as e:
+            warnings.warn(
+                f"Unable to query GPU memory: {e}. Will try again in 10 seconds."
+            )
+
+        # Wait a bit before trying again
+        time.sleep(10)
+        num_tries += 1
+    else:
         pynvml.nvmlShutdown()
+        raise RuntimeError(
+            f"Unable to find {num_gpus} GPUs "
+            f"with at least {min_memory / 10 ** 9:.2f} GB "
+            f"of free memory after {max_wait_time} seconds."
+        )
+    pynvml.nvmlShutdown()
 
-    # Indices are sorted descending by free memory, so we want the first `max_gpus`
+    # Indices are sorted descending by free memory, so we want the first `num_gpus`
     # items. For printing purposes, though, we sort the indices numerically.
-    selection = sorted(usable_indices[:max_gpus])
+    selection = sorted(usable_indices[:num_gpus])
 
-    # Did we get the maximum number of GPUs requested?
-    if len(selection) == max_gpus:
-        print(f"Using {len(selection)} of {num_visible} GPUs: {selection}")
-    else:
-        print(f"Using {len(selection)} of {max_gpus} requested GPUs: {selection}")
-        print(
-            f"{num_visible - len(selection)} GPUs have insufficient free memory "
-            f"({min_memory / 10 ** 9:.2f} GB needed)."
-        )
+    assert len(selection) == num_gpus
+    print(f"Using {len(selection)} of {num_visible} GPUs: {selection}")
 
-    if len(selection) > 0:
-        return [f"cuda:{i}" for i in selection]
-    else:
-        return ["cpu"]
+    return [f"cuda:{i}" for i in selection]
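
Reviewer note (not part of the diff): the new `from .typing import assert_type` import is only used to treat `nvmlDeviceGetMemoryInfo(...).total` as an `int`. The helper itself is not shown here; a minimal sketch of such a runtime type guard, assuming the real `elk/utils/typing.py` may differ, looks like:

```python
from typing import Type, TypeVar, cast

T = TypeVar("T")


def assert_type(typ: Type[T], obj: object) -> T:
    """Sketch of a runtime type guard: raise if `obj` is not a `typ`, else return it.

    Hypothetical implementation for illustration only; the actual helper in
    `elk/utils/typing.py` may differ in details.
    """
    if not isinstance(obj, typ):
        raise TypeError(f"Expected {typ.__name__}, got {type(obj).__name__}")
    return cast(T, obj)
```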
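Also for reviewers: the contract of `select_usable_devices` changes from "return up to `max_gpus` GPUs, falling back to CPU" to "block until exactly `num_gpus` suitable GPUs are free, or raise after `max_wait_time` seconds". A usage sketch (illustrative only; the values below are made up, and the import path is inferred from the file location `elk/utils/gpu_utils.py`):

```python
from elk.utils.gpu_utils import select_usable_devices

# Wait up to 30 minutes for 4 GPUs that each have at least 20 GB free.
devices = select_usable_devices(4, min_memory=20 * 10**9, max_wait_time=30 * 60)
assert len(devices) == 4  # e.g. ["cuda:0", "cuda:2", "cuda:5", "cuda:7"]

# Defaults: num_gpus=-1 selects every visible GPU, and min_memory=-1 now means
# "90% of the smallest GPU's total memory must be free", so this call waits
# until all visible GPUs are essentially idle (or raises after 2 hours).
devices = select_usable_devices()
```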