Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
lisalehna committed Nov 29, 2024
1 parent b1895a7 commit 62854e4
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
7 changes: 5 additions & 2 deletions lightning_ir/models/mvr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
class MVRConfig(BiEncoderConfig):
model_type = "mvr"

ADDED_ARGS = BiEncoderConfig.ADDED_ARGS.union({"additional_linear_layer"})
TOKENIZER_ARGS = BiEncoderConfig.TOKENIZER_ARGS.union({
"add_viewer_tokens",
"num_viewer_tokens",
})

def __init__(self,
add_viewer_tokens = True,
num_viewer_tokens = 8,
**kwargs):
super().__init__(**kwargs)
self.add_viewer_tokens = add_viewer_tokens
self.num_viewer_tokens = num_viewer_tokens
self.num_viewer_tokens = num_viewer_tokens
2 changes: 1 addition & 1 deletion lightning_ir/models/mvr/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from lightning_ir import BiEncoderModel, BiEncoderOutput

class MVRModel(BiEncoderModel):
config_class = MVRConfig
config = MVRConfig

def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)
Expand Down
3 changes: 1 addition & 2 deletions lightning_ir/models/mvr/module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
class MVRModule(BiEncoderModule):
def __init__(
self,

self,
):
super().__init__()
if self.config.add_viewer_tokens and len(self.tokenizer) > self.config.vocab_size:
Expand Down

0 comments on commit 62854e4

Please sign in to comment.