diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py
index b947265be9e9d..55c2e5d4ed412 100644
--- a/tests/models/embedding/language/test_gritlm.py
+++ b/tests/models/embedding/language/test_gritlm.py
@@ -35,7 +35,7 @@ def test_find_array(monkeypatch):
     from vllm.model_executor.models.gritlm import GritLMPooler
 
     # Create an LLM object to get the model config.
-    llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)
+    llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
     pooler = GritLMPooler(model_config=llm.llm_engine.model_config)
 
     arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
@@ -55,7 +55,7 @@ def server_embedding():
     with pytest.MonkeyPatch.context() as mp:
         mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
 
-        args = ["--task", "embedding", "--max_model_len", str(MAX_MODEL_LEN)]
+        args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
         with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
             yield remote_server
 
@@ -141,7 +141,7 @@ def test_gritlm_offline_embedding(monkeypatch):
 
     queries, q_instruction, documents, d_instruction = get_test_data()
 
-    llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)
+    llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
 
     d_rep = run_llm_encode(
         llm,
diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
index ec01a07c16a62..34c1332ac4a66 100644
--- a/vllm/model_executor/models/gritlm.py
+++ b/vllm/model_executor/models/gritlm.py
@@ -203,12 +203,12 @@ def __init__(
     ) -> None:
         super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
-        self.task = vllm_config.model_config.task
+        self.runner_type = vllm_config.model_config.runner_type
 
         self._pooler = GritLMPooler(vllm_config.model_config)
 
         for layer in self.model.layers:
-            if self.task == "embedding" and hasattr(layer, "self_attn"):
+            if self.runner_type == "pooling" and hasattr(layer, "self_attn"):
                 assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
                     "GritLM embedding is only supported by XFormers backend, "
                     "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")
@@ -222,8 +222,8 @@ def forward(
         **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
 
-        # Change attention to non-causal for embedding task.
-        if self.task == "embedding":
+        # Change attention to non-causal for pooling tasks.
+        if self.runner_type == "pooling":
            assert attn_metadata.prefill_metadata.attn_bias is None
            attn_metadata.prefill_metadata.attn_bias = [
                BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
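For reference, a minimal offline sketch of the renamed task value ("embedding" -> "embed"). It is not part of the patch: MODEL_NAME, MAX_MODEL_LEN, and the prompt format are illustrative assumptions rather than values taken from this diff.

# Minimal sketch, assuming a GritLM checkpoint and the renamed task value.
import os

import vllm

# The model code above still requires the XFormers backend for the
# non-causal (bidirectional) prefill attention used when pooling.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

MODEL_NAME = "parasail-ai/GritLM-7B-vllm"  # assumed checkpoint name
MAX_MODEL_LEN = 4000                       # assumed length limit

llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)

# GritLM-style embedding prompt (illustrative formatting).
prompt = "<|user|>\nRepresent this query for retrieval:\n<|embed|>\nWhat is vLLM?\n"
outputs = llm.encode([prompt])

# The pooled vector lives on the request output; the exact attribute name
# (.embedding vs .data) varies slightly across vLLM versions.
print(len(outputs[0].outputs.embedding))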