diff --git a/C_MTEB/eval_C-MTEB.py b/C_MTEB/eval_C-MTEB.py
index 0d176393..810d6abf 100644
--- a/C_MTEB/eval_C-MTEB.py
+++ b/C_MTEB/eval_C-MTEB.py
@@ -19,6 +19,8 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--model_name_or_path', default="BAAI/bge-large-zh", type=str)
     parser.add_argument('--task_type', default=None, type=str)
+    parser.add_argument('--add_instruction', action='store_true', help="whether to add instruction for query")
+    parser.add_argument('--pooling_method', default='cls', type=str)
     return parser.parse_args()
 
 
@@ -27,7 +29,8 @@ def get_args():
     args = get_args()
 
     model = FlagDRESModel(model_name_or_path=args.model_name_or_path,
-                          query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:")
+                          query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                          pooling_method=args.pooling_method)
 
     task_names = [t.description["name"] for t in MTEB(task_types=args.task_type, task_langs=['zh', 'zh-CN']).tasks]
 
@@ -40,9 +43,11 @@ def get_args():
                                              'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
                                              'T2Reranking', 'MMarcoReranking', 'CMedQAv1', 'CMedQAv2']:
             if args.model_name_or_path not in query_instruction_for_retrieval_dict:
-                instruction = "为这个句子生成表示以用于检索相关文章:"
-                # instruction = None
-                print(f"{args.model_name_or_path} not in query_instruction_for_retrieval_dict, set instruction=为这个句子生成表示以用于检索相关文章:")
+                if args.add_instruction:
+                    instruction = "为这个句子生成表示以用于检索相关文章:"
+                else:
+                    instruction = None
+                print(f"{args.model_name_or_path} not in query_instruction_for_retrieval_dict, set instruction={instruction}")
             else:
                 instruction = query_instruction_for_retrieval_dict[args.model_name_or_path]
         else:
diff --git a/C_MTEB/eval_MTEB.py b/C_MTEB/eval_MTEB.py
index d9159e52..f44cb9e9 100644
--- a/C_MTEB/eval_MTEB.py
+++ b/C_MTEB/eval_MTEB.py
@@ -18,6 +18,7 @@ def get_args():
     parser.add_argument('--model_name_or_path', default="BAAI/bge-large-en", type=str)
     parser.add_argument('--task_type', default=None, type=str, help="task type. Default is None, which means using all task types")
     parser.add_argument('--add_instruction', action='store_true', help="whether to add instruction for query")
+    parser.add_argument('--pooling_method', default='cls', type=str)
     return parser.parse_args()
 
 
@@ -27,7 +28,8 @@ def get_args():
 
     model = FlagDRESModel(model_name_or_path=args.model_name_or_path,
                           normalize_embeddings=False,  # normlize embedding will harm the performance of classification task
-                          query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ")
+                          query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
+                          pooling_method=args.pooling_method)
 
     task_names = [t.description["name"] for t in MTEB(task_types=args.task_type, task_langs=['en']).tasks]
diff --git a/C_MTEB/flag_dres_model.py b/C_MTEB/flag_dres_model.py
index 67d95776..ce82cfa6 100644
--- a/C_MTEB/flag_dres_model.py
+++ b/C_MTEB/flag_dres_model.py
@@ -73,7 +73,7 @@ def encode(self, sentences: List[str], **kwargs) -> np.ndarray:
                 max_length=512,
             ).to(self.device)
             last_hidden_state = self.model(**inputs, return_dict=True).last_hidden_state
-            embeddings = last_hidden_state[:, 0]
+            embeddings = self.pooling(last_hidden_state, inputs['attention_mask'])
             if self.normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
             embeddings = cast(torch.Tensor, embeddings)
@@ -81,9 +81,14 @@ def encode(self, sentences: List[str], **kwargs) -> np.ndarray:
 
         return np.concatenate(all_embeddings, axis=0)
 
-
-
-
-
+    def pooling(self,
+                last_hidden_state: torch.Tensor,
+                attention_mask: torch.Tensor=None):
+        if self.pooling_method == 'cls':
+            return last_hidden_state[:, 0]
+        elif self.pooling_method == 'mean':
+            s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
+            d = attention_mask.sum(dim=1, keepdim=True).float()
+            return s / d