1
1
import importlib .util
2
2
import math
3
- import os
4
3
from array import array
5
4
from typing import List
6
5
17
16
# GritLM embedding implementation is only supported by XFormers backend.
18
17
pytest .mark .skipif (not importlib .util .find_spec ("xformers" ),
19
18
reason = "GritLM requires XFormers" )
20
- os .environ ["VLLM_ATTENTION_BACKEND" ] = "XFORMERS"
21
19
22
20
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
23
-
24
21
MAX_MODEL_LEN = 4000
25
22
26
23
@@ -31,7 +28,10 @@ def _arr(arr):
31
28
return array ("i" , arr )
32
29
33
30
34
- def test_find_array ():
31
+ def test_find_array (monkeypatch ):
32
+ # GritLM embedding implementation is only supported by XFormers backend.
33
+ monkeypatch .setenv ("VLLM_ATTENTION_BACKEND" , "XFORMERS" )
34
+
35
35
from vllm .model_executor .models .gritlm import GritLMPooler
36
36
37
37
# Create an LLM object to get the model config.
@@ -51,9 +51,13 @@ def test_find_array():
51
51
52
52
@pytest .fixture (scope = "module" )
53
53
def server_embedding ():
54
- args = ["--task" , "embedding" , "--max_model_len" , str (MAX_MODEL_LEN )]
55
- with RemoteOpenAIServer (MODEL_NAME , args ) as remote_server :
56
- yield remote_server
54
+ # GritLM embedding implementation is only supported by XFormers backend.
55
+ with pytest .MonkeyPatch .context () as mp :
56
+ mp .setenv ("VLLM_ATTENTION_BACKEND" , "XFORMERS" )
57
+
58
+ args = ["--task" , "embedding" , "--max_model_len" , str (MAX_MODEL_LEN )]
59
+ with RemoteOpenAIServer (MODEL_NAME , args ) as remote_server :
60
+ yield remote_server
57
61
58
62
59
63
@pytest .fixture (scope = "module" )
@@ -131,7 +135,10 @@ def validate_embed_output(q_rep: List[float], d_rep: List[float]):
131
135
assert math .isclose (cosine_sim_q1_d1 , 0.532 , abs_tol = 0.001 )
132
136
133
137
134
- def test_gritlm_offline_embedding ():
138
+ def test_gritlm_offline_embedding (monkeypatch ):
139
+ # GritLM embedding implementation is only supported by XFormers backend.
140
+ monkeypatch .setenv ("VLLM_ATTENTION_BACKEND" , "XFORMERS" )
141
+
135
142
queries , q_instruction , documents , d_instruction = get_test_data ()
136
143
137
144
llm = vllm .LLM (MODEL_NAME , task = "embedding" , max_model_len = MAX_MODEL_LEN )
0 commit comments