diff --git a/elk/extraction/extraction_main.py b/elk/extraction/extraction_main.py index 3f5a223d..2d60e2a3 100644 --- a/elk/extraction/extraction_main.py +++ b/elk/extraction/extraction_main.py @@ -61,18 +61,24 @@ def extract(split: str): # AutoModel should do the right thing here in nearly all cases. We don't actually # care what head the model has, since we are just extracting hidden states. - print(f"Loading model '{cfg.model}'...") - model = AutoModel.from_pretrained(cfg.model, torch_dtype="auto") - print(f"Done. Model class: '{model.__class__.__name__}'") # Intelligently select a GPU with enough memory - if dist.is_initialized(): - model.to(f"cuda:{dist.get_rank()}") - elif torch.cuda.is_available(): - # We at least need enough VRAM to hold the model parameters - min_memory = sum(p.element_size() * p.numel() for p in model.parameters()) - (device_idx,) = select_usable_gpus(max_gpus=1, min_memory=min_memory) - model.to(f"cuda:{device_idx}") + if dist.is_initialized() or torch.cuda.is_available(): + print(f"Loading model '{cfg.model}'...") + model = AutoModel.from_pretrained(cfg.model, torch_dtype="auto") + print(f"Done. Model class: '{model.__class__.__name__}'") + if dist.is_initialized(): + model.to(f"cuda:{dist.get_rank()}") + elif torch.cuda.is_available(): + # We at least need enough VRAM to hold the model parameters + min_memory = sum(p.element_size() * p.numel() for p in model.parameters()) + (device_idx,) = select_usable_gpus(max_gpus=1, min_memory=min_memory) + model.to(f"cuda:{device_idx}") + else: + print("No GPUs detected. Running on CPU.") + print(f"Loading model '{cfg.model}'...") + model = AutoModel.from_pretrained(cfg.model, torch_dtype=torch.float32) + print(f"Done. Model class: '{model.__class__.__name__}'") if cfg.use_encoder_states and not model.config.is_encoder_decoder: raise ValueError( diff --git a/tests/test_elicit_extract.py b/tests/test_elicit_extract.py index a2ebe2fc..935180f7 100644 --- a/tests/test_elicit_extract.py +++ b/tests/test_elicit_extract.py @@ -1,8 +1,3 @@ -from elk.math import stochastic_round_constrained -from hypothesis import given, strategies as st -from random import Random -import math -import numpy as np import pytest import os @@ -16,7 +11,8 @@ def test_elk_run(): assert ( os.waitstatus_to_exitcode( os.system( - "elk elicit microsoft/deberta-v2-xxlarge-mnli imdb --max_examples 10" + "NO_CUDA=1 elk elicit" + + " microsoft/deberta-v2-xxlarge-mnli imdb --max_examples 10" ) ) == 0 @@ -25,7 +21,8 @@ def test_elk_run(): assert ( os.waitstatus_to_exitcode( os.system( - "elk extract microsoft/deberta-v2-xxlarge-mnli imdb --max_examples 10" + "NO_CUDA=1 elk extract" + + " microsoft/deberta-v2-xxlarge-mnli imdb --max_examples 10" ) ) == 0