diff --git a/lightning_ir/models/mvr/config.py b/lightning_ir/models/mvr/config.py
index 56aa1d3..44e34da 100644
--- a/lightning_ir/models/mvr/config.py
+++ b/lightning_ir/models/mvr/config.py
@@ -4,14 +4,11 @@ class MVRConfig(BiEncoderConfig):
     model_type = "mvr"
 
     TOKENIZER_ARGS = BiEncoderConfig.TOKENIZER_ARGS.union({
-        "add_viewer_tokens",
         "num_viewer_tokens",
     })
 
     def __init__(self,
-                 add_viewer_tokens = True,
-                 num_viewer_tokens = 8,
+                 num_viewer_tokens: int | None = 8,
                  **kwargs):
-        super().__init__(**kwargs)
-        self.add_viewer_tokens = add_viewer_tokens
+        super().__init__(**kwargs)
         self.num_viewer_tokens = num_viewer_tokens
\ No newline at end of file
diff --git a/lightning_ir/models/mvr/tokenizer.py b/lightning_ir/models/mvr/tokenizer.py
index 655e5b4..03f1cdc 100644
--- a/lightning_ir/models/mvr/tokenizer.py
+++ b/lightning_ir/models/mvr/tokenizer.py
@@ -1,7 +1,32 @@
+from tokenizers.processors import TemplateProcessing
+
 class MVRTokenizer(BiEncoderTokenizer):
     def __init__(
         self,
-        add_viewer_tokens: bool = True,
+        num_viewer_tokens: int | None = 8,
     ):
+        """Initialize the tokenizer and, optionally, its viewer tokens.
+
+        Args:
+            num_viewer_tokens: Number of ``[VIE{idx}]`` tokens to register and
+                place in front of every document. ``None`` skips this setup.
+        """
         super().__init__()
-        
\ No newline at end of file
+        if num_viewer_tokens is not None:
+            # The argument is a count, so enumerate it with range().
+            viewer_tokens = [f"[VIE{idx}]" for idx in range(num_viewer_tokens)]
+            # A list registers one vocabulary entry per viewer token.
+            self.add_tokens(viewer_tokens, special_tokens=True)
+            self.doc_post_processor = TemplateProcessing(
+                # Joining pieces avoids an empty template piece when there are no viewer tokens.
+                single=" ".join([*viewer_tokens, self.DOC_TOKEN, "$0", "[SEP]"]),
+                pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] {self.DOC_TOKEN} $B:1 [SEP]:1",
+                special_tokens=[
+                    ("[CLS]", self.cls_token_id),
+                    ("[SEP]", self.sep_token_id),
+                    (self.QUERY_TOKEN, self.query_token_id),
+                    (self.DOC_TOKEN, self.doc_token_id),
+                    # Template tokens need ids; assumes the base class provides convert_tokens_to_ids.
+                    *((token, self.convert_tokens_to_ids(token)) for token in viewer_tokens),
+                ],
+            )
\ No newline at end of file