diff --git a/lightning_ir/models/mvr/config.py b/lightning_ir/models/mvr/config.py index 3566c9f..56aa1d3 100644 --- a/lightning_ir/models/mvr/config.py +++ b/lightning_ir/models/mvr/config.py @@ -3,7 +3,10 @@ class MVRConfig(BiEncoderConfig): model_type = "mvr" - ADDED_ARGS = BiEncoderConfig.ADDED_ARGS.union({"additional_linear_layer"}) + TOKENIZER_ARGS = BiEncoderConfig.TOKENIZER_ARGS.union({ + "add_viewer_tokens", + "num_viewer_tokens", + }) def __init__(self, add_viewer_tokens = True, @@ -11,4 +14,4 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.add_viewer_tokens = add_viewer_tokens - self.num_viewer_tokens = num_viewer_tokens \ No newline at end of file + self.num_viewer_tokens = num_viewer_tokens \ No newline at end of file diff --git a/lightning_ir/models/mvr/model.py b/lightning_ir/models/mvr/model.py index 1111993..a0f5744 100644 --- a/lightning_ir/models/mvr/model.py +++ b/lightning_ir/models/mvr/model.py @@ -8,7 +8,7 @@ from lightning_ir import BiEncoderModel, BiEncoderOutput class MVRModel(BiEncoderModel): - config_class = MVRConfig + config = MVRConfig def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) diff --git a/lightning_ir/models/mvr/module.py b/lightning_ir/models/mvr/module.py index c6d1dde..7b47bf8 100644 --- a/lightning_ir/models/mvr/module.py +++ b/lightning_ir/models/mvr/module.py @@ -1,7 +1,6 @@ class MVRModule(BiEncoderModule): def __init__( - self, - + self, ): super().__init__() if self.config.add_viewer_tokens and len(self.tokenizer) > self.config.vocab_size: