diff --git a/lightning_ir/models/mvr/tokenizer.py b/lightning_ir/models/mvr/tokenizer.py index 03f1cdc..e89593a 100644 --- a/lightning_ir/models/mvr/tokenizer.py +++ b/lightning_ir/models/mvr/tokenizer.py @@ -7,15 +7,26 @@ def __init__( ): super().__init__() if num_viewer_tokens is not None: - viewer_tokens = " ".join(f"[VIE{idx}]" for idx in num_viewer_tokens) + viewer_tokens = " ".join(f"[VIE{idx}]" for idx in range(num_viewer_tokens)) self.add_tokens(viewer_tokens, special_tokens=True) + viewer_token_ids = [(f"[VIE{viewer_token_id}]", self.viewer_token_id(viewer_token_id)) for viewer_token_id in range(num_viewer_tokens)] self.doc_post_processor = TemplateProcessing( - single=f"{viewer_tokens} {self.DOC_TOKEN} $0 [SEP]", - pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] {self.DOC_TOKEN} $B:1 [SEP]:1", + single=f"{viewer_tokens} $0 [SEP]", + pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] $B:1 [SEP]:1", special_tokens=[ ("[CLS]", self.cls_token_id), ("[SEP]", self.sep_token_id), - (self.QUERY_TOKEN, self.query_token_id), - (self.DOC_TOKEN, self.doc_token_id), + *viewer_token_ids, ], - ) \ No newline at end of file + ) + + + def viewer_token_id(self, viewer_token_id) -> int | None: + """The token id of the given viewer token if viewer tokens are added. 
+ + :return: Token id of the viewer token + :rtype: int | None + """ + if f"[VIE{viewer_token_id}]" in self.added_tokens_encoder: + return self.added_tokens_encoder[f"[VIE{viewer_token_id}]"] + return None \ No newline at end of file diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index a6aea9b..e8379c2 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -4,7 +4,7 @@ from lightning_ir.base.tokenizer import LightningIRTokenizerClassFactory from lightning_ir.bi_encoder.config import BiEncoderConfig from lightning_ir.cross_encoder.config import CrossEncoderConfig - +from lightning_ir.models.mvr import MVRConfig @pytest.mark.parametrize( "config", @@ -69,3 +69,19 @@ def test_cross_encoder_tokenizer(model_name_or_path: str): encoding = tokenizer.tokenize(query, doc)["encoding"] assert encoding is not None assert len(encoding.input_ids[0]) == tokenizer.query_length + tokenizer.doc_length + 3 + +def test_mvr_tokenizer(model_name_or_path: str): + Tokenizer = LightningIRTokenizerClassFactory(MVRConfig).from_pretrained(model_name_or_path) + tokenizer = Tokenizer.from_pretrained(model_name_or_path, query_length=2, doc_length=4) + + query = "What is the capital of France?" + doc = "Paris is the capital of France." + encoding = tokenizer.tokenize(query, doc)["encoding"] + assert encoding is not None + assert len(encoding.input_ids) == tokenizer.query_length + tokenizer.doc_length + 3 + + query = ["What is the capital of France?"] + doc = ["Paris is the capital of France."] + encoding = tokenizer.tokenize(query, doc)["encoding"] + assert encoding is not None + assert len(encoding.input_ids[0]) == tokenizer.query_length + tokenizer.doc_length + 3