
Commit

stupid
Parry-Parry committed Nov 13, 2024
1 parent abc62ee commit eb60442
Showing 2 changed files with 197 additions and 3 deletions.
190 changes: 190 additions & 0 deletions rankers/modelling/pyterrier/sparse.py
@@ -0,0 +1,190 @@
import re
import base64
import string
import numpy as np
from contextlib import ExitStack
import itertools
from more_itertools import chunked
import torch
import pandas as pd
import pyterrier as pt
from transformers import AutoTokenizer
from pyterrier.model import add_ranks
from ..sparse import Sparse

"""
Taken from https://github.com/thongnt99/learned-sparse-retrieval/blob/main/SparseTransformer/transformer.py
```
"""

class SparseTransformer(pt.Transformer):
    def __init__(self, model_name_or_path, device=None, batch_size=32, text_field='text', fp16=False, topk=None):
        self.model_name_or_path = model_name_or_path
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.fp16 = fp16
        self.device = device
        self.model = Sparse.from_pretrained(model_name_or_path).eval().to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # len(tokenizer) spans the full vocabulary (including added tokens); the
        # transformers tokenizer wrapper does not expose get_vocab_size().
        all_token_ids = list(range(len(self.tokenizer)))
        self.all_tokens = np.array(self.tokenizer.convert_ids_to_tokens(all_token_ids))
        self.batch_size = batch_size
        self.text_field = text_field
        self.topk = topk

    def encode_queries(self, texts, out_fmt='dict', topk=None):
        outputs = []
        if out_fmt != 'dict':
            assert topk is None, "topk only supported when out_fmt='dict'"
        with ExitStack() as stack:
            stack.enter_context(torch.no_grad())
            if self.fp16:
                stack.enter_context(torch.cuda.amp.autocast())
            for batch in chunked(texts, self.batch_size):
                enc = self.tokenizer(batch, padding=True, truncation=True, return_special_tokens_mask=True, return_tensors="pt")
                enc = {k: v.to(self.device) for k, v in enc.items()}
                res = self.model.encode_queries(**enc).cpu().float()
                if out_fmt == 'dict':
                    res = self.vec2dicts(res, topk=topk)
                    outputs.extend(res)
                else:
                    outputs.append(res.numpy())
        if out_fmt == 'np':
            outputs = np.concatenate(outputs, axis=0)
        elif out_fmt == 'np_list':
            outputs = list(itertools.chain.from_iterable(outputs))
        return outputs

    def encode_docs(self, texts, out_fmt='dict', topk=None):
        outputs = []
        if out_fmt != 'dict':
            assert topk is None, "topk only supported when out_fmt='dict'"
        with ExitStack() as stack:
            stack.enter_context(torch.no_grad())
            if self.fp16:
                stack.enter_context(torch.cuda.amp.autocast())
            for batch in chunked(texts, self.batch_size):
                enc = self.tokenizer(batch, padding=True, truncation=True, return_special_tokens_mask=True, return_tensors="pt")
                enc = {k: v.to(self.device) for k, v in enc.items()}
                res = self.model.encode_docs(**enc)
                if out_fmt == 'dict':
                    res = self.vec2dicts(res, topk=topk)
                    outputs.extend(res)
                else:
                    outputs.append(res.cpu().float().numpy())
        if out_fmt == 'np':
            outputs = np.concatenate(outputs, axis=0)
        elif out_fmt == 'np_list':
            outputs = list(itertools.chain.from_iterable(outputs))
        return outputs

    def vec2dicts(self, batch_output, topk=None):
        # Convert a (batch x vocab) matrix of term weights into one
        # {token: weight} dict per input, ordered by descending weight.
        rtr = []
        idxs, cols = torch.nonzero(batch_output, as_tuple=True)
        weights = batch_output[idxs, cols]
        args = weights.argsort(descending=True)
        idxs = idxs[args]
        cols = cols[args]
        weights = weights[args]
        for i in range(batch_output.shape[0]):
            mask = (idxs == i)
            col = cols[mask]
            w = weights[mask]
            if topk is not None:
                col = col[:topk]
                w = w[:topk]
            d = {self.all_tokens[k]: v for k, v in zip(col.cpu().tolist(), w.cpu().tolist())}
            rtr.append(d)
        return rtr

    def query_encoder(self, matchop=False, sparse=True, topk=None):
        return SparseQueryEncoder(self, matchop, sparse=sparse, topk=topk or self.topk)

    def doc_encoder(self, text_field=None, sparse=True, topk=None):
        return SparseDocEncoder(self, text_field or self.text_field, sparse=sparse, topk=topk or self.topk)

    def scorer(self, text_field=None):
        return SparseScorer(self, text_field or self.text_field)

    def transform(self, inp):
        if all(c in inp.columns for c in ['qid', 'query', self.text_field]):
            return self.scorer()(inp)
        elif 'query' in inp.columns:
            return self.query_encoder()(inp)
        elif self.text_field in inp.columns:
            return self.doc_encoder()(inp)
        raise ValueError(f'unsupported columns: {inp.columns}; expecting "query", {repr(self.text_field)}, or both.')


class SparseQueryEncoder(pt.Transformer):
    def __init__(self, transformer: SparseTransformer, matchop=False, sparse=True, topk=None):
        self.transformer = transformer
        if not sparse:
            assert not matchop, "matchop only supported when sparse=True"
            assert topk is None, "topk only supported when sparse=True"
        self.matchop = matchop
        self.sparse = sparse
        self.topk = topk

    def encode(self, texts):
        return self.transformer.encode_queries(texts, out_fmt='dict' if self.sparse else 'np_list', topk=self.topk)

    def transform(self, inp):
        res = self.encode(inp['query'])
        if self.matchop:
            res = [_matchop(r) for r in res]
            inp = pt.model.push_queries(inp)
            return inp.assign(query=res)
        if self.sparse:
            return inp.assign(query_toks=res)
        return inp.assign(query_vec=res)


class SparseDocEncoder(pt.Transformer):
    def __init__(self, transformer: SparseTransformer, text_field, sparse=True, topk=None):
        self.transformer = transformer
        self.text_field = text_field
        self.sparse = sparse
        if not sparse:
            assert topk is None, "topk only supported when sparse=True"
        self.topk = topk

    def encode(self, texts):
        return self.transformer.encode_docs(texts, out_fmt='dict' if self.sparse else 'np_list', topk=self.topk)

    def transform(self, inp):
        res = self.encode(inp[self.text_field])
        if self.sparse:
            return inp.assign(toks=res)
        return inp.assign(doc_vec=res)


class SparseScorer(pt.Transformer):
    def __init__(self, transformer: SparseTransformer, text_field):
        self.transformer = transformer
        self.text_field = text_field

    def score(self, query_texts, doc_texts):
        q, inv_q = np.unique(query_texts.values if isinstance(query_texts, pd.Series) else np.array(query_texts), return_inverse=True)
        q = self.transformer.encode_queries(q, out_fmt='np')[inv_q]
        d, inv_d = np.unique(doc_texts.values if isinstance(doc_texts, pd.Series) else np.array(doc_texts), return_inverse=True)
        d = self.transformer.encode_docs(d, out_fmt='np')[inv_d]
        return np.einsum('bd,bd->b', q, d)

    def transform(self, inp):
        res = inp.assign(score=self.score(inp['query'], inp[self.text_field]))
        return add_ranks(res)


_alphnum_exp = re.compile('^[' + re.escape(string.ascii_letters + string.digits) + ']+$')

def _matchop(d):
    # Render a {token: weight} dict as a Terrier matchop query string,
    # base64-encoding any token that is not purely alphanumeric.
    res = []
    for t, w in d.items():
        if not _alphnum_exp.match(t):
            encoded = base64.b64encode(t.encode('utf-8')).decode("utf-8")
            t = f'#base64({encoded})'
        if w != 1:
            t = f'#combine:0={w}({t})'
        res.append(t)
    return ' '.join(res)
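
For reference, a minimal usage sketch of the transformer added above. The checkpoint path is hypothetical, and any indexing/retrieval wiring is assumed to be provided elsewhere; only the encoder and scorer construction defined in SparseTransformer is exercised here.

from rankers.modelling.pyterrier.sparse import SparseTransformer

# Hypothetical checkpoint; anything loadable by Sparse.from_pretrained should work.
model = SparseTransformer('path/to/sparse-checkpoint', batch_size=32, topk=128)

# transform() dispatches on the input columns: query-only frames go to the
# query encoder, text-only frames to the document encoder, and frames with
# both to the scorer.
query_enc = model.query_encoder(matchop=True)  # emits Terrier matchop query strings
doc_enc = model.doc_encoder()                  # adds a 'toks' column of {token: weight} dicts
scorer = model.scorer()                        # scores (query, text) pairs with a dot product and adds ranks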
10 changes: 7 additions & 3 deletions rankers/modelling/sparse.py
@@ -1,5 +1,6 @@
+from rankers.modelling.dot import Pooler
 import torch
-from transformers import AutoModel, AutoConfig
+from transformers import AutoModel, AutoConfig, PreTrainedModel, PreTrainedTokenizer
 from .dot import DotConfig, Dot
 
 class SparseConfig(DotConfig):
@@ -15,8 +16,11 @@ class Sparse(Dot):
     model_type = "Sparse"
     transformer_class = None
 
-    def to_pyterrier(self) -> "SparseTransformer":
-        return self.transformer_class.from_model(self, self.tokenizer, text_field='text')
+    def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, config: DotConfig, model_d: PreTrainedModel = None, pooler: Pooler = None):
+        super().__init__(model, tokenizer, config, model_d, pooler)
+
+        from .pyterrier.sparse import SparseTransformer
+        self.transformer_class = SparseTransformer
 
 AutoConfig.register("Sparse", SparseConfig)
 AutoModel.register(SparseConfig, Sparse)
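
A minimal sketch of what this change enables, assuming a locally saved checkpoint (the path below is hypothetical) and that the parent Dot class supplies the from_pretrained machinery used here:

from transformers import AutoModel
from rankers.modelling.sparse import Sparse

# Direct load, as done inside SparseTransformer.__init__ above (hypothetical path).
model = Sparse.from_pretrained('path/to/sparse-checkpoint')

# The Auto-class registrations above also let the same checkpoint be resolved
# generically, provided its saved config declares model_type "Sparse".
model = AutoModel.from_pretrained('path/to/sparse-checkpoint')

# transformer_class now points at SparseTransformer, so the model can be wrapped
# as the PyTerrier transformer defined in the first file.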
