Skip to content

Commit 657cf02

Browse files
committed
Use monkeypatch to set env var in test
Signed-off-by: Pooya Davoodi <[email protected]>
1 parent 6666445 commit 657cf02

File tree

1 file changed, with 15 additions and 8 deletions.

tests/models/embedding/language/test_gritlm.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
import importlib.util
22
import math
3-
import os
43
from array import array
54
from typing import List
65

@@ -17,10 +16,8 @@
1716
# GritLM embedding implementation is only supported by XFormers backend.
1817
pytest.mark.skipif(not importlib.util.find_spec("xformers"),
1918
reason="GritLM requires XFormers")
20-
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"
2119

2220
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
23-
2421
MAX_MODEL_LEN = 4000
2522

2623

@@ -31,7 +28,10 @@ def _arr(arr):
3128
return array("i", arr)
3229

3330

34-
def test_find_array():
31+
def test_find_array(monkeypatch):
32+
# GritLM embedding implementation is only supported by XFormers backend.
33+
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
34+
3535
from vllm.model_executor.models.gritlm import GritLMPooler
3636

3737
# Create an LLM object to get the model config.
@@ -51,9 +51,13 @@ def test_find_array():
5151

5252
@pytest.fixture(scope="module")
5353
def server_embedding():
54-
args = ["--task", "embedding", "--max_model_len", str(MAX_MODEL_LEN)]
55-
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
56-
yield remote_server
54+
# GritLM embedding implementation is only supported by XFormers backend.
55+
with pytest.MonkeyPatch.context() as mp:
56+
mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
57+
58+
args = ["--task", "embedding", "--max_model_len", str(MAX_MODEL_LEN)]
59+
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
60+
yield remote_server
5761

5862

5963
@pytest.fixture(scope="module")
@@ -131,7 +135,10 @@ def validate_embed_output(q_rep: List[float], d_rep: List[float]):
131135
assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001)
132136

133137

134-
def test_gritlm_offline_embedding():
138+
def test_gritlm_offline_embedding(monkeypatch):
139+
# GritLM embedding implementation is only supported by XFormers backend.
140+
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
141+
135142
queries, q_instruction, documents, d_instruction = get_test_data()
136143

137144
llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)

0 commit comments

Comments (0)