From b30a1ee94110c86b373009d953bb11da5b2e8f4c Mon Sep 17 00:00:00 2001
From: Andrew Parry
Date: Mon, 4 Nov 2024 13:13:31 +0000
Subject: [PATCH] idiot

---
 rankers/modelling/dot.py | 126 +++++++++++++++++++--------------------
 1 file changed, 63 insertions(+), 63 deletions(-)

diff --git a/rankers/modelling/dot.py b/rankers/modelling/dot.py
index 986d654..1e528f4 100644
--- a/rankers/modelling/dot.py
+++ b/rankers/modelling/dot.py
@@ -12,6 +12,69 @@ from more_itertools import chunked
 
 from ..train.loss import batched_dot_product, cross_dot_product, LOSS_REGISTRY
 
+class DotConfig(PretrainedConfig):
+    """Configuration for Dot Model
+
+    Parameters
+    ----------
+    model_name_or_path : str
+        the model name or path
+    mode : str
+        the pooling mode for the model
+    model_tied : bool
+        whether the model is tied
+    use_pooler : bool
+        whether to use the pooler
+    pooler_dim_in : int
+        the input dimension for the pooler
+    pooler_dim_out : int
+        the output dimension for the pooler
+    pooler_tied : bool
+        whether the pooler is tied
+    """
+    model_architecture = "Dot"
+    def __init__(self,
+                 model_name_or_path : str='bert-base-uncased',
+                 mode='cls',
+                 inbatch_loss=None,
+                 model_tied=True,
+                 use_pooler=False,
+                 pooler_dim_in=768,
+                 pooler_dim_out=768,
+                 pooler_tied=True,
+                 **kwargs):
+        self.model_name_or_path = model_name_or_path
+        self.mode = mode
+        self.inbatch_loss = inbatch_loss
+        self.model_tied = model_tied
+        self.use_pooler = use_pooler
+        self.pooler_dim_in = pooler_dim_in
+        self.pooler_dim_out = pooler_dim_out
+        self.pooler_tied = pooler_tied
+        super().__init__(**kwargs)
+
+    @classmethod
+    def from_pretrained(cls,
+                        model_name_or_path : str='bert-base-uncased',
+                        mode='cls',
+                        inbatch_loss=None,
+                        model_tied=True,
+                        use_pooler=False,
+                        pooler_dim_in=768,
+                        pooler_dim_out=768,
+                        pooler_tied=True,
+                        ) -> 'DotConfig':
+        config = super().from_pretrained(model_name_or_path)
+        config.model_name_or_path = model_name_or_path
+        config.mode = mode
+        config.inbatch_loss = inbatch_loss
+        config.model_tied = model_tied
+        config.use_pooler = use_pooler
+        config.pooler_dim_in = pooler_dim_in
+        config.pooler_dim_out = pooler_dim_out
+        config.pooler_tied = pooler_tied
+        return config
+
 class DotTransformer(pt.Transformer):
     cls_architecture = AutoModel
     def __init__(self,
@@ -191,69 +254,6 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
     def __repr__(self):
         return f'{repr(self.bi_model_model)}.scorer()'
 
-class DotConfig(PretrainedConfig):
-    """Configuration for Dot Model
-
-    Parameters
-    ----------
-    model_name_or_path : str
-        the model name or path
-    mode : str
-        the pooling mode for the model
-    model_tied : bool
-        whether the model is tied
-    use_pooler : bool
-        whether to use the pooler
-    pooler_dim_in : int
-        the input dimension for the pooler
-    pooler_dim_out : int
-        the output dimension for the pooler
-    pooler_tied : bool
-        whether the pooler is tied
-    """
-    model_architecture = "Dot"
-    def __init__(self,
-                 model_name_or_path : str='bert-base-uncased',
-                 mode='cls',
-                 inbatch_loss=None,
-                 model_tied=True,
-                 use_pooler=False,
-                 pooler_dim_in=768,
-                 pooler_dim_out=768,
-                 pooler_tied=True,
-                 **kwargs):
-        self.model_name_or_path = model_name_or_path
-        self.mode = mode
-        self.inbatch_loss = inbatch_loss
-        self.model_tied = model_tied
-        self.use_pooler = use_pooler
-        self.pooler_dim_in = pooler_dim_in
-        self.pooler_dim_out = pooler_dim_out
-        self.pooler_tied = pooler_tied
-        super().__init__(**kwargs)
-
-    @classmethod
-    def from_pretrained(cls,
-                        model_name_or_path : str='bert-base-uncased',
-                        mode='cls',
-                        inbatch_loss=None,
-                        model_tied=True,
-                        use_pooler=False,
-                        pooler_dim_in=768,
-                        pooler_dim_out=768,
-                        pooler_tied=True,
-                        ) -> 'DotConfig':
-        config = super().from_pretrained(model_name_or_path)
-        config.model_name_or_path = model_name_or_path
-        config.mode = mode
-        config.inbatch_loss = inbatch_loss
-        config.model_tied = model_tied
-        config.use_pooler = use_pooler
-        config.pooler_dim_in = pooler_dim_in
-        config.pooler_dim_out = pooler_dim_out
-        config.pooler_tied = pooler_tied
-        return config
-
 class Pooler(nn.Module):
     def __init__(self, config):
         super().__init__()
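
A minimal usage sketch (not applied by the patch itself), assuming the package is importable as `rankers` and that `DotConfig` keeps the signatures shown in the hunk above; the checkpoint name and argument values are illustrative only:

    from rankers.modelling.dot import DotConfig

    # Build a bi-encoder config directly; each keyword mirrors a parameter
    # of the __init__ added in the first hunk.
    config = DotConfig(
        model_name_or_path="bert-base-uncased",
        mode="cls",        # pooling mode
        model_tied=True,   # "whether the model is tied", per the docstring
        use_pooler=False,
    )

    # Or derive it from a base checkpoint via the overridden classmethod,
    # which loads the underlying PretrainedConfig and attaches the Dot fields.
    config = DotConfig.from_pretrained("bert-base-uncased", mode="cls")
    print(config.mode, config.model_tied)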