From 657cf024c3ef8fe9d4bae3330e28a549bb76296d Mon Sep 17 00:00:00 2001
From: Pooya Davoodi
Date: Thu, 12 Dec 2024 13:40:35 +0900
Subject: [PATCH] Use monkeypatch to set env var in test

Signed-off-by: Pooya Davoodi
---
 .../models/embedding/language/test_gritlm.py | 23 ++++++++++++-------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py
index b6bc828b46426..b947265be9e9d 100644
--- a/tests/models/embedding/language/test_gritlm.py
+++ b/tests/models/embedding/language/test_gritlm.py
@@ -1,6 +1,5 @@
 import importlib.util
 import math
-import os
 from array import array
 from typing import List
 
@@ -17,10 +16,8 @@
 # GritLM embedding implementation is only supported by XFormers backend.
 pytest.mark.skipif(not importlib.util.find_spec("xformers"),
                    reason="GritLM requires XFormers")
-os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"
 
 MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
-
 MAX_MODEL_LEN = 4000
 
 
@@ -31,7 +28,10 @@ def _arr(arr):
     return array("i", arr)
 
 
-def test_find_array():
+def test_find_array(monkeypatch):
+    # GritLM embedding implementation is only supported by XFormers backend.
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
+
     from vllm.model_executor.models.gritlm import GritLMPooler
 
     # Create an LLM object to get the model config.
@@ -51,9 +51,13 @@ def test_find_array():
 
 @pytest.fixture(scope="module")
 def server_embedding():
-    args = ["--task", "embedding", "--max_model_len", str(MAX_MODEL_LEN)]
-    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
-        yield remote_server
+    # GritLM embedding implementation is only supported by XFormers backend.
+    with pytest.MonkeyPatch.context() as mp:
+        mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
+
+        args = ["--task", "embedding", "--max_model_len", str(MAX_MODEL_LEN)]
+        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+            yield remote_server
 
 
 @pytest.fixture(scope="module")
@@ -131,7 +135,10 @@ def validate_embed_output(q_rep: List[float], d_rep: List[float]):
     assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001)
 
 
-def test_gritlm_offline_embedding():
+def test_gritlm_offline_embedding(monkeypatch):
+    # GritLM embedding implementation is only supported by XFormers backend.
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
+
     queries, q_instruction, documents, d_instruction = get_test_data()
 
     llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)