Commit
update C-mteb
shitao committed Sep 25, 2023
1 parent d1abaeb commit 982f810
Showing 3 changed files with 22 additions and 10 deletions.
13 changes: 9 additions & 4 deletions C_MTEB/eval_C-MTEB.py
@@ -19,6 +19,8 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--model_name_or_path', default="BAAI/bge-large-zh", type=str)
     parser.add_argument('--task_type', default=None, type=str)
+    parser.add_argument('--add_instruction', action='store_true', help="whether to add instruction for query")
+    parser.add_argument('--pooling_method', default='cls', type=str)
     return parser.parse_args()


@@ -27,7 +29,8 @@ def get_args():
     args = get_args()

     model = FlagDRESModel(model_name_or_path=args.model_name_or_path,
-                          query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:")
+                          query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                          pooling_method=args.pooling_method)

     task_names = [t.description["name"] for t in MTEB(task_types=args.task_type,
                                                        task_langs=['zh', 'zh-CN']).tasks]
@@ -40,9 +43,11 @@ def get_args():
                        'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
                        'T2Reranking', 'MMarcoReranking', 'CMedQAv1', 'CMedQAv2']:
             if args.model_name_or_path not in query_instruction_for_retrieval_dict:
-                instruction = "为这个句子生成表示以用于检索相关文章:"
-                # instruction = None
-                print(f"{args.model_name_or_path} not in query_instruction_for_retrieval_dict, set instruction=为这个句子生成表示以用于检索相关文章:")
+                if args.add_instruction:
+                    instruction = "为这个句子生成表示以用于检索相关文章:"
+                else:
+                    instruction = None
+                print(f"{args.model_name_or_path} not in query_instruction_for_retrieval_dict, set instruction={instruction}")
             else:
                 instruction = query_instruction_for_retrieval_dict[args.model_name_or_path]
         else:
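As a point of reference, here is a minimal sketch of how the updated model wrapper could be exercised with the new option. It is not part of the commit; the example sentences and the choice of pooling_method='mean' are purely illustrative (the eval scripts default to 'cls').

# Sketch only, not part of this commit: trying out the new pooling_method
# argument on FlagDRESModel directly. Assumes the working directory is C_MTEB/
# so that flag_dres_model.py is importable.
from flag_dres_model import FlagDRESModel

model = FlagDRESModel(
    model_name_or_path="BAAI/bge-large-zh",
    query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
    pooling_method="mean",  # new argument; 'cls' remains the default in the eval scripts
)

# encode() now routes the hidden states through self.pooling(...); see the
# flag_dres_model.py change below.
embeddings = model.encode(["如何缓解头痛?", "太阳能电池板的工作原理是什么?"])
print(embeddings.shape)  # (2, hidden_size), e.g. (2, 1024) for bge-large-zh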
4 changes: 3 additions & 1 deletion C_MTEB/eval_MTEB.py
@@ -18,6 +18,7 @@ def get_args():
     parser.add_argument('--model_name_or_path', default="BAAI/bge-large-en", type=str)
     parser.add_argument('--task_type', default=None, type=str, help="task type. Default is None, which means using all task types")
     parser.add_argument('--add_instruction', action='store_true', help="whether to add instruction for query")
+    parser.add_argument('--pooling_method', default='cls', type=str)

     return parser.parse_args()

@@ -27,7 +28,8 @@ def get_args():

     model = FlagDRESModel(model_name_or_path=args.model_name_or_path,
                           normalize_embeddings=False,  # normalizing embeddings will harm the performance of classification tasks
-                          query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ")
+                          query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
+                          pooling_method=args.pooling_method)

     task_names = [t.description["name"] for t in MTEB(task_types=args.task_type,
                                                        task_langs=['en']).tasks]
15 changes: 10 additions & 5 deletions C_MTEB/flag_dres_model.py
@@ -73,17 +73,22 @@ def encode(self, sentences: List[str], **kwargs) -> np.ndarray:
                 max_length=512,
             ).to(self.device)
             last_hidden_state = self.model(**inputs, return_dict=True).last_hidden_state
-            embeddings = last_hidden_state[:, 0]
+            embeddings = self.pooling(last_hidden_state, inputs['attention_mask'])
             if self.normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
             embeddings = cast(torch.Tensor, embeddings)
             all_embeddings.append(embeddings.cpu().numpy())

         return np.concatenate(all_embeddings, axis=0)

-
-
-
-
+    def pooling(self,
+                last_hidden_state: torch.Tensor,
+                attention_mask: torch.Tensor = None):
+        if self.pooling_method == 'cls':
+            return last_hidden_state[:, 0]
+        elif self.pooling_method == 'mean':
+            s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
+            d = attention_mask.sum(dim=1, keepdim=True).float()
+            return s / d
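To make the two branches concrete, here is a small self-contained illustration (not part of the commit) of what 'cls' and 'mean' pooling compute, using toy tensors and the same masked-average formula as above.

# Toy illustration of the two pooling branches added above.
import torch

last_hidden_state = torch.tensor([[[1.0, 2.0],    # token 0 ([CLS])
                                   [3.0, 4.0],    # token 1
                                   [5.0, 6.0]]])  # token 2 (padding)
attention_mask = torch.tensor([[1, 1, 0]])        # padding token masked out

# 'cls' pooling: take the hidden state of the first token.
cls_emb = last_hidden_state[:, 0]                 # tensor([[1., 2.]])

# 'mean' pooling: average only over non-padded tokens.
s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
d = attention_mask.sum(dim=1, keepdim=True).float()
mean_emb = s / d                                  # tensor([[2., 3.]])

Note that pooling() has no fallback branch, so any pooling_method other than 'cls' or 'mean' would make it silently return None.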

