Commit

idiot
Parry-Parry committed Nov 4, 2024
1 parent e0e7e4f commit b30a1ee
Showing 1 changed file with 63 additions and 63 deletions.
126 changes: 63 additions & 63 deletions rankers/modelling/dot.py
@@ -12,6 +12,69 @@
from more_itertools import chunked
from ..train.loss import batched_dot_product, cross_dot_product, LOSS_REGISTRY

class DotConfig(PretrainedConfig):
    """Configuration for Dot Model

    Parameters
    ----------
    model_name_or_path : str
        the model name or path
    mode : str
        the pooling mode for the model
    inbatch_loss : str, optional
        the in-batch loss to apply, if any
    model_tied : bool
        whether the model is tied
    use_pooler : bool
        whether to use the pooler
    pooler_dim_in : int
        the input dimension for the pooler
    pooler_dim_out : int
        the output dimension for the pooler
    pooler_tied : bool
        whether the pooler is tied
    """
    model_architecture = "Dot"

    def __init__(self,
                 model_name_or_path: str = 'bert-base-uncased',
                 mode='cls',
                 inbatch_loss=None,
                 model_tied=True,
                 use_pooler=False,
                 pooler_dim_in=768,
                 pooler_dim_out=768,
                 pooler_tied=True,
                 **kwargs):
        self.model_name_or_path = model_name_or_path
        self.mode = mode
        self.inbatch_loss = inbatch_loss
        self.model_tied = model_tied
        self.use_pooler = use_pooler
        self.pooler_dim_in = pooler_dim_in
        self.pooler_dim_out = pooler_dim_out
        self.pooler_tied = pooler_tied
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls,
                        model_name_or_path: str = 'bert-base-uncased',
                        mode='cls',
                        inbatch_loss=None,
                        model_tied=True,
                        use_pooler=False,
                        pooler_dim_in=768,
                        pooler_dim_out=768,
                        pooler_tied=True,
                        ) -> 'DotConfig':
        config = super().from_pretrained(model_name_or_path)
        config.model_name_or_path = model_name_or_path
        config.mode = mode
        config.inbatch_loss = inbatch_loss
        config.model_tied = model_tied
        config.use_pooler = use_pooler
        config.pooler_dim_in = pooler_dim_in
        config.pooler_dim_out = pooler_dim_out
        config.pooler_tied = pooler_tied
        return config
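Not part of the diff: a brief usage sketch of the configuration above, assuming the package layout matches this file's path (rankers/modelling/dot.py) so that DotConfig is importable as shown; the field values are arbitrary examples.

# Illustrative only; assumes the `rankers` package is installed and exposes DotConfig at this path.
from rankers.modelling.dot import DotConfig

config = DotConfig(
    model_name_or_path='bert-base-uncased',
    mode='cls',            # pool the [CLS] token representation
    inbatch_loss=None,     # no additional in-batch loss
    model_tied=False,      # separate query and document encoders
    use_pooler=True,       # attach a projection pooler
    pooler_dim_in=768,
    pooler_dim_out=768,
)
print(config.model_architecture)  # -> "Dot"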

class DotTransformer(pt.Transformer):
    cls_architecture = AutoModel

    def __init__(self,
@@ -191,69 +254,6 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
    def __repr__(self):
        return f'{repr(self.bi_model_model)}.scorer()'

class Pooler(nn.Module):
    def __init__(self, config):
        super().__init__()
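The body of Pooler is collapsed in this diff view. Purely as an illustration of how the pooler_dim_in, pooler_dim_out, and use_pooler fields of DotConfig could drive such a module, a hypothetical minimal pooler might look like the sketch below; this is an assumption, not the repository's actual implementation.

# Hypothetical sketch only; not the Pooler defined in rankers/modelling/dot.py.
import torch.nn as nn

class MinimalPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Project encoder outputs to the target embedding dimension.
        self.dense = nn.Linear(config.pooler_dim_in, config.pooler_dim_out)

    def forward(self, hidden_states):
        return self.dense(hidden_states)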
