diff --git a/mteb/models/gte_models.py b/mteb/models/gte_models.py
index 301821b6e0..648fc18850 100644
--- a/mteb/models/gte_models.py
+++ b/mteb/models/gte_models.py
@@ -2,12 +2,21 @@
 
 from functools import partial
 
+import torch
+
+from mteb.encoder_interface import PromptType
 from mteb.model_meta import ModelMeta
 from mteb.models.instruct_wrapper import instruct_wrapper
 
 
-def instruction_template(instruction: str) -> str:
-    return f"Instruct: {instruction}\nQuery: " if instruction else ""
+def instruction_template(
+    instruction: str, prompt_type: PromptType | None = None
+) -> str:
+    return (
+        f"Instruct: {instruction}\nQuery: "
+        if (prompt_type is None or prompt_type == PromptType.query) and instruction
+        else ""
+    )
 
 
 gte_Qwen2_7B_instruct = ModelMeta(
@@ -15,13 +24,14 @@ def instruction_template(instruction: str) -> str:
         instruct_wrapper,
         model_name_or_path="Alibaba-NLP/gte-Qwen2-7B-instruct",
         instruction_template=instruction_template,
-        attn="cccc",
+        attn="bbcc",
         pooling_method="lasttoken",
         mode="embedding",
-        torch_dtype="auto",
+        torch_dtype=torch.float16,
         # The ST script does not normalize while the HF one does so unclear what to do
         # https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct#sentence-transformers
         normalized=True,
+        embed_eos="<|endoftext|>",
     ),
     name="Alibaba-NLP/gte-Qwen2-7B-instruct",
     languages=None,
@@ -44,11 +54,12 @@ def instruction_template(instruction: str) -> str:
         instruct_wrapper,
         model_name_or_path="Alibaba-NLP/gte-Qwen1.5-7B-instruct",
         instruction_template=instruction_template,
-        attn="cccc",
+        attn="bbcc",
         pooling_method="lasttoken",
         mode="embedding",
-        torch_dtype="auto",
+        torch_dtype=torch.float16,
         normalized=True,
+        embed_eos="<|endoftext|>",
     ),
     name="Alibaba-NLP/gte-Qwen1.5-7B-instruct",
     languages=["eng_Latn"],
@@ -72,11 +83,12 @@ def instruction_template(instruction: str) -> str:
         instruct_wrapper,
         model_name_or_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
         instruction_template=instruction_template,
-        attn="cccc",
+        attn="bbcc",
         pooling_method="lasttoken",
         mode="embedding",
-        torch_dtype="auto",
+        torch_dtype=torch.float16,
         normalized=True,
+        embed_eos="<|endoftext|>",
     ),
     name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
     languages=["eng_Latn"],
diff --git a/mteb/models/instruct_wrapper.py b/mteb/models/instruct_wrapper.py
index 96d71970b7..303a386836 100644
--- a/mteb/models/instruct_wrapper.py
+++ b/mteb/models/instruct_wrapper.py
@@ -66,7 +66,7 @@ def encode(
         instruction = self.get_instruction(task_name, prompt_type)
 
         if self.instruction_template:
-            instruction = self.format_instruction(instruction)
+            instruction = self.format_instruction(instruction, prompt_type)
 
         logger.info(f"Using instruction: '{instruction}' for task: '{task_name}'")
         embeddings = super().encode(
diff --git a/mteb/models/linq_models.py b/mteb/models/linq_models.py
index 48e86ac8d5..4babbf75cf 100644
--- a/mteb/models/linq_models.py
+++ b/mteb/models/linq_models.py
@@ -4,11 +4,14 @@
 
 import torch
 
+from mteb.encoder_interface import PromptType
 from mteb.model_meta import ModelMeta
 from mteb.models.instruct_wrapper import instruct_wrapper
 
 
-def instruction_template(instruction: str) -> str:
+def instruction_template(
+    instruction: str, prompt_type: PromptType | None = None
+) -> str:
     return f"Instruct: {instruction}\nQuery: " if instruction else ""
 
 
diff --git a/mteb/models/nvidia_models.py b/mteb/models/nvidia_models.py
index 0c0170de6e..72274b41de 100644
--- a/mteb/models/nvidia_models.py
+++ b/mteb/models/nvidia_models.py
@@ -16,7 +16,9 @@
 logger = logging.getLogger(__name__)
 
 
-def instruction_template(instruction: str) -> str:
+def instruction_template(
+    instruction: str, prompt_type: PromptType | None = None
+) -> str:
     return f"Instruct: {instruction}\nQuery: " if instruction else ""
 
 
diff --git a/mteb/models/salesforce_models.py b/mteb/models/salesforce_models.py
index e5c0973d5f..b1d45b949c 100644
--- a/mteb/models/salesforce_models.py
+++ b/mteb/models/salesforce_models.py
@@ -2,11 +2,14 @@
 
 from functools import partial
 
+from mteb.encoder_interface import PromptType
 from mteb.model_meta import ModelMeta
 from mteb.models.instruct_wrapper import instruct_wrapper
 
 
-def instruction_template(instruction: str) -> str:
+def instruction_template(
+    instruction: str, prompt_type: PromptType | None = None
+) -> str:
     return f"Instruct: {instruction}\nQuery: " if instruction else ""
 
 
diff --git a/mteb/models/wrapper.py b/mteb/models/wrapper.py
index c42a0d8db4..2a9fb20497 100644
--- a/mteb/models/wrapper.py
+++ b/mteb/models/wrapper.py
@@ -103,14 +103,16 @@ def get_instruction(task_name: str, prompt_type: PromptType | None) -> str:
             return task_metadata.prompt
         return task.abstask_prompt
 
-    def format_instruction(self, instruction: str) -> str:
+    def format_instruction(
+        self, instruction: str, prompt_type: PromptType | None = None
+    ) -> str:
         if isinstance(self.instruction_template, str):
             if "{instruction}" not in self.instruction_template:
                 raise ValueError(
                     "Instruction template must contain the string '{instruction}'."
                 )
             return self.instruction_template.format(instruction=instruction)
-        return self.instruction_template(instruction)
+        return self.instruction_template(instruction, prompt_type)
 
     def get_task_instruction(
         self, task_name: str, prompt_type: PromptType | None