Skip to content

Commit

Permalink
webis-de#5 viewer token id, test
Browse files Browse the repository at this point in the history
  • Loading branch information
lisalehna committed Dec 2, 2024
1 parent fa59ac3 commit 6c56367
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
23 changes: 17 additions & 6 deletions lightning_ir/models/mvr/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,26 @@ def __init__(
):
super().__init__()
if num_viewer_tokens is not None:
viewer_tokens = " ".join(f"[VIE{idx}]" for idx in num_viewer_tokens)
viewer_tokens = " ".join(f"[VIE{idx}]" for idx in range(num_viewer_tokens))
self.add_tokens(viewer_tokens, special_tokens=True)
viewer_token_ids = [(f"[VIE{viewer_token_id}]", self.viewer_token_id(f"[VIE{viewer_token_id}]")) for viewer_token_id in range(num_viewer_tokens)]
self.doc_post_processor = TemplateProcessing(
single=f"{viewer_tokens} {self.DOC_TOKEN} $0 [SEP]",
pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] {self.DOC_TOKEN} $B:1 [SEP]:1",
single=f"{viewer_tokens} $0 [SEP]",
pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] $B:1 [SEP]:1",
special_tokens=[
("[CLS]", self.cls_token_id),
("[SEP]", self.sep_token_id),
(self.QUERY_TOKEN, self.query_token_id),
(self.DOC_TOKEN, self.doc_token_id),
*viewer_token_ids,
],
)
)


def viewer_token_id(self, viewer_token_id) -> int | None:
"""The token id of the query token if marker tokens are added.
:return: Token id of the query token
:rtype: int | None
"""
if f"[VIE{viewer_token_id}]" in self.added_tokens_encoder:
return self.added_tokens_encoder[f"[VIE{viewer_token_id}]"]
return None
18 changes: 17 additions & 1 deletion tests/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from lightning_ir.base.tokenizer import LightningIRTokenizerClassFactory
from lightning_ir.bi_encoder.config import BiEncoderConfig
from lightning_ir.cross_encoder.config import CrossEncoderConfig

from lightning_ir.models.mvr import MVRConfig

@pytest.mark.parametrize(
"config",
Expand Down Expand Up @@ -69,3 +69,19 @@ def test_cross_encoder_tokenizer(model_name_or_path: str):
encoding = tokenizer.tokenize(query, doc)["encoding"]
assert encoding is not None
assert len(encoding.input_ids[0]) == tokenizer.query_length + tokenizer.doc_length + 3

def test_mvr_tokenizer(model_name_or_path: str):
    """Tokenize a query/doc pair with the MVR tokenizer and check encoding lengths."""
    Tokenizer = LightningIRTokenizerClassFactory(MVRConfig).from_pretrained(model_name_or_path)
    tokenizer = Tokenizer.from_pretrained(model_name_or_path, query_length=2, doc_length=4)

    # query_length + doc_length + 3 — presumably the 3 extra special tokens
    # inserted by the post-processor template; confirm against the tokenizer.
    expected_length = tokenizer.query_length + tokenizer.doc_length + 3

    # Unbatched (plain string) input: input_ids is a single flat sequence.
    query = "What is the capital of France?"
    doc = "Paris is the capital of France."
    encoding = tokenizer.tokenize(query, doc)["encoding"]
    assert encoding is not None
    assert len(encoding.input_ids) == expected_length

    # Batched (list) input: input_ids is a batch; inspect the first sequence.
    query = ["What is the capital of France?"]
    doc = ["Paris is the capital of France."]
    encoding = tokenizer.tokenize(query, doc)["encoding"]
    assert encoding is not None
    assert len(encoding.input_ids[0]) == expected_length

0 comments on commit 6c56367

Please sign in to comment.