Skip to content

Commit

Permalink
Use monkeypatch to set env var in test
Browse files Browse the repository at this point in the history
Signed-off-by: Pooya Davoodi <[email protected]>
  • Loading branch information
pooyadavoodi committed Dec 12, 2024
1 parent 6666445 commit 657cf02
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions tests/models/embedding/language/test_gritlm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import importlib.util
import math
import os
from array import array
from typing import List

Expand All @@ -17,10 +16,8 @@
# GritLM embedding implementation is only supported by XFormers backend.
pytest.mark.skipif(not importlib.util.find_spec("xformers"),
reason="GritLM requires XFormers")
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

MODEL_NAME = "parasail-ai/GritLM-7B-vllm"

MAX_MODEL_LEN = 4000


Expand All @@ -31,7 +28,10 @@ def _arr(arr):
return array("i", arr)


def test_find_array():
def test_find_array(monkeypatch):
# GritLM embedding implementation is only supported by XFormers backend.
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")

from vllm.model_executor.models.gritlm import GritLMPooler

# Create an LLM object to get the model config.
Expand All @@ -51,9 +51,13 @@ def test_find_array():

@pytest.fixture(scope="module")
def server_embedding():
args = ["--task", "embedding", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
# GritLM embedding implementation is only supported by XFormers backend.
with pytest.MonkeyPatch.context() as mp:
mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")

args = ["--task", "embedding", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -131,7 +135,10 @@ def validate_embed_output(q_rep: List[float], d_rep: List[float]):
assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001)


def test_gritlm_offline_embedding():
def test_gritlm_offline_embedding(monkeypatch):
# GritLM embedding implementation is only supported by XFormers backend.
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")

queries, q_instruction, documents, d_instruction = get_test_data()

llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)
Expand Down

0 comments on commit 657cf02

Please sign in to comment.