Skip to content

Commit

Permalink
webis-de#4 viewertokens to vocabulary
Browse files Browse the repository at this point in the history
  • Loading branch information
lisalehna committed Dec 2, 2024
1 parent 62854e4 commit fa59ac3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
7 changes: 2 additions & 5 deletions lightning_ir/models/mvr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@ class MVRConfig(BiEncoderConfig):
model_type = "mvr"

TOKENIZER_ARGS = BiEncoderConfig.TOKENIZER_ARGS.union({
"add_viewer_tokens",
"num_viewer_tokens",
})

def __init__(self,
add_viewer_tokens = True,
num_viewer_tokens = 8,
num_viewer_tokens: int | None = 8,
**kwargs):
super().__init__(**kwargs)
self.add_viewer_tokens = add_viewer_tokens
super().__init__(**kwargs)
self.num_viewer_tokens = num_viewer_tokens
18 changes: 16 additions & 2 deletions lightning_ir/models/mvr/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
from tokenizers.processors import TemplateProcessing

class MVRTokenizer(BiEncoderTokenizer):
def __init__(
self,
add_viewer_tokens: bool = True,
num_viewer_tokens: int = 8,
):
super().__init__()

if num_viewer_tokens is not None:
viewer_tokens = " ".join(f"[VIE{idx}]" for idx in num_viewer_tokens)
self.add_tokens(viewer_tokens, special_tokens=True)
self.doc_post_processor = TemplateProcessing(
single=f"{viewer_tokens} {self.DOC_TOKEN} $0 [SEP]",
pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] {self.DOC_TOKEN} $B:1 [SEP]:1",
special_tokens=[
("[CLS]", self.cls_token_id),
("[SEP]", self.sep_token_id),
(self.QUERY_TOKEN, self.query_token_id),
(self.DOC_TOKEN, self.doc_token_id),
],
)

0 comments on commit fa59ac3

Please sign in to comment.