
Commit

fixed
Parry-Parry committed Nov 28, 2024
1 parent cb55c12 commit f873bed
Showing 1 changed file with 80 additions and 18 deletions.
98 changes: 80 additions & 18 deletions rankers/modelling/sparse.py
@@ -1,5 +1,6 @@
from rankers.modelling.dot import Pooler
import torch
import torch.nn as nn
from transformers import (
AutoModel,
AutoConfig,
@@ -11,14 +12,67 @@
from torch.nn import functional as F


class ProcessingConstructor:
def __init__(self,
norm = "none",
activation = "relu",
aggregation = "max",
) -> None:
self._norm = {
'none': nn.Identity(),
'log1p': lambda x: torch.log(1 + x),
}[norm]
self._activation = {
'none': nn.Identity(),
'relu': nn.ReLU(),
}[activation]
self._aggregation = {
'max': lambda x: torch.max(x, dim=1).values,
'mean': lambda x: torch.mean(x, dim=1),
'sum': lambda x: torch.sum(x, dim=1),
}[aggregation]

def _get_norm(self, norm_value):
if norm_value == "log1p":
return lambda x: torch.log(1 + x)
else:
return nn.Identity()

def _get_activation(self, activation_value):
if activation_value == "relu":
return nn.ReLU()
else:
return nn.Identity()

def _get_aggregation(self, aggregation_value):
if aggregation_value == "max":
return lambda x: torch.max(x, dim=1).values
elif aggregation_value == "mean":
return lambda x: torch.mean(x, dim=1)
elif aggregation_value == "sum":
return lambda x: torch.sum(x, dim=1)
else:
return nn.Identity()

def __call__(self, x, mask):
post_act = self._activation(x)
norm = self._norm(post_act)
masked = norm * mask.unsqueeze(-1)
return self._aggregation(masked)


class SparseConfig(DotConfig):
model_type = "Sparse"

def __init__(
self,
model_name_or_path: str = "bert-base-uncased",
query_processing: str = "splade_max",
doc_processing: str = "splade_max",
query_norm="log1p",
query_activation="relu",
query_aggregation="max",
doc_norm="log1p",
doc_activation="relu",
doc_aggregation="max",
pooling_type="none",
inbatch_loss=None,
model_tied=True,
@@ -39,15 +93,23 @@ def __init__(
pooler_tied,
**kwargs,
)
self.query_processing = query_processing
self.doc_processing = doc_processing
self.query_activation = query_activation
self.query_norm = query_norm
self.query_aggregation = query_aggregation
self.doc_activation = doc_activation
self.doc_norm = doc_norm
self.doc_aggregation = doc_aggregation

@classmethod
def from_pretrained(
cls,
model_name_or_path: str = "bert-base-uncased",
query_processing: str = "splade_max",
doc_processing: str = "splade_max",
query_activation="relu",
query_norm="log1p",
query_aggregation="max",
doc_activation="relu",
doc_norm="log1p",
doc_aggregation="max",
pooling_type="none",
inbatch_loss=None,
model_tied=True,
@@ -66,8 +128,12 @@ def from_pretrained(
pooler_dim_out,
pooler_tied,
)
config.query_processing = query_processing
config.doc_processing = doc_processing
config.query_activation = query_activation
config.query_norm = query_norm
config.query_aggregation = query_aggregation
config.doc_activation = doc_activation
config.doc_norm = doc_norm
config.doc_aggregation = doc_aggregation
return config


@@ -95,16 +161,12 @@ def __init__(
):
super().__init__(model, tokenizer, config, model_d, pooler)

self.query_processing = (
splade_max
if config.query_processing == "splade_max"
else lambda x, y: x.logits
)
self.doc_processing = (
splade_max
if config.doc_processing == "splade_max"
else lambda x, y: x.logits
)
self.query_processing = ProcessingConstructor(norm=config.query_norm,
activation=config.query_activation,
aggregation=config.query_aggregation)
self.doc_processing = ProcessingConstructor(norm=config.doc_norm,
activation=config.doc_activation,
aggregation=config.doc_aggregation)

from .pyterrier.sparse import SparseTransformer

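The commit swaps the fixed splade_max processing hook for a configurable ProcessingConstructor (activation, then normalisation, then masked aggregation over the sequence dimension). Below is a minimal sketch of what the new class computes, assuming the import path follows the file shown above; the tensor names, shapes, and vocabulary size are illustrative and not part of the diff:

import torch
from rankers.modelling.sparse import ProcessingConstructor

# SPLADE-style pooling: relu, then log(1 + x), masked, then max over the sequence dimension
proc = ProcessingConstructor(norm="log1p", activation="relu", aggregation="max")
logits = torch.randn(2, 5, 30522)  # (batch, seq_len, vocab) token-level term logits
mask = torch.ones(2, 5)            # attention mask, 1 for real tokens
reps = proc(logits, mask)          # -> (2, 30522) sparse lexical representations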
