
Commit b333c76
tests are done
KaivuH committed Feb 22, 2023
1 parent b861bbd · commit b333c76
Showing 2 changed files with 20 additions and 17 deletions.
26 changes: 16 additions & 10 deletions elk/extraction/extraction_main.py
@@ -61,18 +61,24 @@ def extract(split: str):

     # AutoModel should do the right thing here in nearly all cases. We don't actually
     # care what head the model has, since we are just extracting hidden states.
-    print(f"Loading model '{cfg.model}'...")
-    model = AutoModel.from_pretrained(cfg.model, torch_dtype="auto")
-    print(f"Done. Model class: '{model.__class__.__name__}'")
 
     # Intelligently select a GPU with enough memory
-    if dist.is_initialized():
-        model.to(f"cuda:{dist.get_rank()}")
-    elif torch.cuda.is_available():
-        # We at least need enough VRAM to hold the model parameters
-        min_memory = sum(p.element_size() * p.numel() for p in model.parameters())
-        (device_idx,) = select_usable_gpus(max_gpus=1, min_memory=min_memory)
-        model.to(f"cuda:{device_idx}")
+    if dist.is_initialized() or torch.cuda.is_available():
+        print(f"Loading model '{cfg.model}'...")
+        model = AutoModel.from_pretrained(cfg.model, torch_dtype="auto")
+        print(f"Done. Model class: '{model.__class__.__name__}'")
+        if dist.is_initialized():
+            model.to(f"cuda:{dist.get_rank()}")
+        elif torch.cuda.is_available():
+            # We at least need enough VRAM to hold the model parameters
+            min_memory = sum(p.element_size() * p.numel() for p in model.parameters())
+            (device_idx,) = select_usable_gpus(max_gpus=1, min_memory=min_memory)
+            model.to(f"cuda:{device_idx}")
+    else:
+        print("No GPUs detected. Running on CPU.")
+        print(f"Loading model '{cfg.model}'...")
+        model = AutoModel.from_pretrained(cfg.model, torch_dtype=torch.float32)
+        print(f"Done. Model class: '{model.__class__.__name__}'")
 
     if cfg.use_encoder_states and not model.config.is_encoder_decoder:
         raise ValueError(
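For context, the device-selection change above can be read as a standalone sketch. This is a minimal sketch, assuming a Hugging Face checkpoint name; load_model is a hypothetical name, and the non-distributed branch is simplified to cuda:0, whereas the commit picks a GPU with enough free VRAM via select_usable_gpus.

import torch
import torch.distributed as dist
from transformers import AutoModel

def load_model(model_name: str):
    # GPU path: keep the checkpoint's native dtype (torch_dtype="auto",
    # often float16) and move the model onto a CUDA device.
    if dist.is_initialized() or torch.cuda.is_available():
        model = AutoModel.from_pretrained(model_name, torch_dtype="auto")
        if dist.is_initialized():
            # One process per GPU under torch.distributed.
            model.to(f"cuda:{dist.get_rank()}")
        else:
            # Simplified: the actual commit selects a device by free VRAM.
            model.to("cuda:0")
    else:
        # CPU path: half-precision kernels are poorly supported on CPU,
        # so force float32 rather than trusting torch_dtype="auto".
        model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float32)
    return model

The behavioral change is on CPU-only machines: the model is now explicitly loaded in float32 instead of whatever dtype torch_dtype="auto" picks (often float16, which many CPU ops don't support).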
11 changes: 4 additions & 7 deletions tests/test_elicit_extract.py
@@ -1,8 +1,3 @@
-from elk.math import stochastic_round_constrained
-from hypothesis import given, strategies as st
-from random import Random
-import math
-import numpy as np
 import pytest
 import os
 
@@ -16,7 +11,8 @@ def test_elk_run():
     assert (
         os.waitstatus_to_exitcode(
             os.system(
-                "elk elicit microsoft/deberta-v2-xxlarge-mnli imdb --max_examples 10"
+                "NO_CUDA=1 elk elicit"
+                + " microsoft/deberta-v2-xxlarge-mnli imdb --max_examples 10"
             )
         )
         == 0
@@ -25,7 +21,8 @@ def test_elk_run():
     assert (
         os.waitstatus_to_exitcode(
             os.system(
-                "elk extract microsoft/deberta-v2-xxlarge-mnli imdb --max_examples 10"
+                "NO_CUDA=1 elk extract"
+                + " microsoft/deberta-v2-xxlarge-mnli imdb --max_examples 10"
             )
         )
         == 0
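The test change prepends NO_CUDA=1 to both commands so the run exercises the new CPU path deterministically (note that os.waitstatus_to_exitcode requires Python 3.9+, and the VAR=value prefix only works because os.system goes through a POSIX shell). The diff doesn't show where NO_CUDA is consumed; one plausible guard, purely as an assumption, is to hide the GPUs before torch initializes CUDA:

import os

# Hypothetical NO_CUDA guard (not shown in this commit). Clearing
# CUDA_VISIBLE_DEVICES before CUDA is initialized makes
# torch.cuda.is_available() return False, steering extraction onto
# the CPU branch added above.
if os.environ.get("NO_CUDA") == "1":
    os.environ["CUDA_VISIBLE_DEVICES"] = ""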
