From f873bed8b79e8e75d38a8c0db1771e0dbad70b5b Mon Sep 17 00:00:00 2001
From: Andrew
Date: Thu, 28 Nov 2024 16:52:39 +0000
Subject: [PATCH] Make sparse query/doc processing configurable

---
 rankers/modelling/sparse.py | 98 ++++++++++++++++++++++++++++++------
 1 file changed, 80 insertions(+), 18 deletions(-)

diff --git a/rankers/modelling/sparse.py b/rankers/modelling/sparse.py
index a06544d..94e3142 100644
--- a/rankers/modelling/sparse.py
+++ b/rankers/modelling/sparse.py
@@ -1,5 +1,6 @@
 from rankers.modelling.dot import Pooler
 import torch
+import torch.nn as nn
 from transformers import (
     AutoModel,
     AutoConfig,
@@ -11,14 +12,67 @@
 from torch.nn import functional as F
 
 
+class ProcessingConstructor:
+    def __init__(self,
+                 norm="none",
+                 activation="relu",
+                 aggregation="max",
+                 ) -> None:
+        self._norm = {
+            'none': nn.Identity(),
+            'log1p': lambda x: torch.log(1 + x),
+        }[norm]
+        self._activation = {
+            'none': nn.Identity(),
+            'relu': nn.ReLU(),
+        }[activation]
+        self._aggregation = {
+            'max': lambda x: torch.max(x, dim=1).values,
+            'mean': lambda x: torch.mean(x, dim=1),
+            'sum': lambda x: torch.sum(x, dim=1),
+        }[aggregation]
+
+    def _get_norm(self, norm_value):
+        if norm_value == "log1p":
+            return lambda x: torch.log(1 + x)
+        else:
+            return nn.Identity()
+
+    def _get_activation(self, activation_value):
+        if activation_value == "relu":
+            return nn.ReLU()
+        else:
+            return nn.Identity()
+
+    def _get_aggregation(self, aggregation_value):
+        if aggregation_value == "max":
+            return lambda x: torch.max(x, dim=1).values
+        elif aggregation_value == "mean":
+            return lambda x: torch.mean(x, dim=1)
+        elif aggregation_value == "sum":
+            return lambda x: torch.sum(x, dim=1)
+        else:
+            return nn.Identity()
+
+    def __call__(self, x, mask):
+        post_act = self._activation(x)      # e.g. relu
+        norm = self._norm(post_act)         # e.g. log(1 + x)
+        masked = norm * mask.unsqueeze(-1)  # zero out padded positions
+        return self._aggregation(masked)    # e.g. max over the token dimension
+
+
 class SparseConfig(DotConfig):
     model_type = "Sparse"
 
     def __init__(
         self,
         model_name_or_path: str = "bert-base-uncased",
-        query_processing: str = "splade_max",
-        doc_processing: str = "splade_max",
+        query_norm="log1p",
+        query_activation="relu",
+        query_aggregation="max",
+        doc_norm="log1p",
+        doc_activation="relu",
+        doc_aggregation="max",
         pooling_type="none",
         inbatch_loss=None,
         model_tied=True,
@@ -39,15 +93,23 @@ def __init__(
             pooler_tied,
             **kwargs,
         )
-        self.query_processing = query_processing
-        self.doc_processing = doc_processing
+        self.query_activation = query_activation
+        self.query_norm = query_norm
+        self.query_aggregation = query_aggregation
+        self.doc_activation = doc_activation
+        self.doc_norm = doc_norm
+        self.doc_aggregation = doc_aggregation
 
     @classmethod
     def from_pretrained(
         cls,
         model_name_or_path: str = "bert-base-uncased",
-        query_processing: str = "splade_max",
-        doc_processing: str = "splade_max",
+        query_activation="relu",
+        query_norm="log1p",
+        query_aggregation="max",
+        doc_activation="relu",
+        doc_norm="log1p",
+        doc_aggregation="max",
         pooling_type="none",
         inbatch_loss=None,
         model_tied=True,
@@ -66,8 +128,12 @@ def from_pretrained(
             pooler_dim_out,
             pooler_tied,
         )
-        config.query_processing = query_processing
-        config.doc_processing = doc_processing
+        config.query_activation = query_activation
+        config.query_norm = query_norm
+        config.query_aggregation = query_aggregation
+        config.doc_activation = doc_activation
+        config.doc_norm = doc_norm
+        config.doc_aggregation = doc_aggregation
 
         return config
 
@@ -95,16 +161,12 @@ def __init__(
     ):
         super().__init__(model, tokenizer, config, model_d, pooler)
 
-        self.query_processing = (
-            splade_max
-            if config.query_processing == "splade_max"
-            else lambda x, y: x.logits
-        )
-        self.doc_processing = (
-            splade_max
-            if config.doc_processing == "splade_max"
-            else lambda x, y: x.logits
-        )
+        self.query_processing = ProcessingConstructor(norm=config.query_norm,
+                                                      activation=config.query_activation,
+                                                      aggregation=config.query_aggregation)
+        self.doc_processing = ProcessingConstructor(norm=config.doc_norm,
+                                                    activation=config.doc_activation,
+                                                    aggregation=config.doc_aggregation)
 
 
 from .pyterrier.sparse import SparseTransformer
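
As a sanity check, a minimal usage sketch of the new ProcessingConstructor with
the SparseConfig defaults (norm="log1p", activation="relu", aggregation="max").
The import path and tensor shapes below are illustrative assumptions, not code
taken from this patch.

    import torch
    from rankers.modelling.sparse import ProcessingConstructor  # assumed import path

    # Fake term logits for 2 sequences of 4 tokens over a 5-term vocabulary,
    # plus an attention mask marking padded positions with 0.
    logits = torch.randn(2, 4, 5)
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]])

    splade = ProcessingConstructor(norm="log1p", activation="relu", aggregation="max")
    reps = splade(logits, mask)  # (batch, vocab) sparse lexical vectors

    # With these settings the call should reproduce the old splade_max behaviour:
    # max over tokens of log(1 + relu(logits)), with padding zeroed out.
    expected = torch.max(
        torch.log(1 + torch.relu(logits)) * mask.unsqueeze(-1), dim=1
    ).values
    assert torch.allclose(reps, expected)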