Commit

unifying naming convention
Parry-Parry committed Nov 4, 2024
1 parent ea62e31 commit 04ece91
Showing 4 changed files with 106 additions and 93 deletions.
29 changes: 14 additions & 15 deletions rankers/modelling/cat.py
@@ -15,20 +15,20 @@ class Cat(PreTrainedModel):
Parameters
----------
classifier : PreTrainedModel
the classifier model
model : PreTrainedModel
the underlying HF model
config : AutoConfig
the configuration for the model
"""
model_architecture = 'Cat'
def __init__(
self,
classifier: PreTrainedModel,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
config: AutoConfig,
):
super().__init__(config)
self.classifier = classifier
self.model = model
self.tokenizer = tokenizer

def prepare_outputs(self, logits, labels=None):
@@ -37,35 +37,34 @@ def prepare_outputs(self, logits, labels=None):

def forward(self, loss, sequences, labels=None):
"""Compute the loss given (pairs, labels)"""
sequences = {k: v.to(self.classifier.device) for k, v in sequences.items()}
labels = labels.to(self.classifier.device) if labels is not None else None
logits = self.classifier(**sequences).logits
sequences = {k: v.to(self.model.device) for k, v in sequences.items()}
labels = labels.to(self.model.device) if labels is not None else None
logits = self.model(**sequences).logits
pred, labels = self.prepare_outputs(logits, labels)
loss_value = loss(pred) if labels is None else loss(pred, labels)
return (loss_value, pred)

def save_pretrained(self, model_dir, **kwargs):
"""Save classifier"""
"""Save model"""
self.config.save_pretrained(model_dir)
self.classifier.save_pretrained(model_dir)
self.model.save_pretrained(model_dir)
self.tokenizer.save_pretrained(model_dir)


def load_state_dict(self, model_dir):
"""Load state dict from a directory"""
return self.classifier.load_state_dict(AutoModelForSequenceClassification.from_pretrained(model_dir).state_dict())

return self.model.load_state_dict(AutoModelForSequenceClassification.from_pretrained(model_dir).state_dict())

def to_pyterrier(self) -> "pt.Transformer":
return CatTransformer.from_model(self.classifier, self.tokenizer, text_field='text')
return CatTransformer.from_model(self.model, self.tokenizer, text_field='text')

@classmethod
def from_pretrained(cls, model_dir_or_name : str, num_labels=2):
"""Load classifier from a directory"""
"""Load model from a directory"""
config = AutoConfig.from_pretrained(model_dir_or_name)
classifier = AutoModelForSequenceClassification.from_pretrained(model_dir_or_name, num_labels=num_labels)
model = AutoModelForSequenceClassification.from_pretrained(model_dir_or_name, num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name)
return cls(classifier, tokenizer, config)
return cls(model, tokenizer, config)

class CatTransformer(pt.Transformer):
def __init__(self,
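As a usage sketch (not part of this commit): with the classifier attribute renamed to model, a Cat cross-encoder would be loaded and converted to a PyTerrier transformer roughly as below. The import path and checkpoint name are assumptions; only Cat.from_pretrained, .model, and .to_pyterrier() appear in the diff above.

from rankers.modelling.cat import Cat  # assumed module path, mirroring rankers/modelling/cat.py

cat = Cat.from_pretrained("bert-base-uncased", num_labels=2)  # wraps AutoModelForSequenceClassification
print(type(cat.model).__name__)   # underlying HF model, previously cat.classifier
ranker = cat.to_pyterrier()       # CatTransformer scoring over the 'text' field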
Empty file removed rankers/modelling/causallm.py
126 changes: 63 additions & 63 deletions rankers/modelling/dot.py
@@ -21,8 +21,8 @@ class DotConfig(PretrainedConfig):
the model name or path
mode : str
the pooling mode for the model
encoder_tied : bool
whether the encoder is tied
model_tied : bool
whether the model is tied
use_pooler : bool
whether to use the pooler
pooler_dim_in : int
@@ -37,7 +37,7 @@ def __init__(self,
model_name_or_path : str='bert-base-uncased',
mode='cls',
inbatch_loss=None,
encoder_tied=True,
model_tied=True,
use_pooler=False,
pooler_dim_in=768,
pooler_dim_out=768,
@@ -46,7 +46,7 @@ def from_pretrained(cls,
self.model_name_or_path = model_name_or_path
self.mode = mode
self.inbatch_loss = inbatch_loss
self.encoder_tied = encoder_tied
self.model_tied = model_tied
self.use_pooler = use_pooler
self.pooler_dim_in = pooler_dim_in
self.pooler_dim_out = pooler_dim_out
@@ -58,7 +58,7 @@ def from_pretrained(cls,
model_name_or_path : str='bert-base-uncased',
mode='cls',
inbatch_loss=None,
encoder_tied=True,
model_tied=True,
use_pooler=False,
pooler_dim_in=768,
pooler_dim_out=768,
@@ -68,7 +68,7 @@ def from_pretrained(cls,
config.model_name_or_path = model_name_or_path
config.mode = mode
config.inbatch_loss = inbatch_loss
config.encoder_tied = encoder_tied
config.model_tied = model_tied
config.use_pooler = use_pooler
config.pooler_dim_in = pooler_dim_in
config.pooler_dim_out = pooler_dim_out
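To illustrate the renamed flag, a minimal sketch of building a DotConfig with model_tied (formerly encoder_tied); the keyword values are placeholders and the import path is an assumption, while the parameter names come from the hunks above.

from rankers.modelling.dot import DotConfig  # assumed module path, mirroring rankers/modelling/dot.py

config = DotConfig.from_pretrained(
    "bert-base-uncased",
    mode="cls",          # pooling mode
    model_tied=False,    # was encoder_tied: keep separate query/document weights
    use_pooler=False,
)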
@@ -96,29 +96,29 @@ class Dot(PreTrainedModel):
Parameters
----------
encoder : PreTrainedModel
the encoder model
model : PreTrainedModel
the model model
config : DotConfig
the configuration for the model
encoder_d : PreTrainedModel
the document encoder model
model_d : PreTrainedModel
the document model model
pooler : Pooler
the pooling layer
"""
model_architecture = 'Dot'
def __init__(
self,
encoder : PreTrainedModel,
model : PreTrainedModel,
tokenizer : PreTrainedTokenizer,
config : DotConfig,
encoder_d : PreTrainedModel = None,
model_d : PreTrainedModel = None,
pooler : Pooler = None,
):
super().__init__(config)
self.encoder = encoder
self.model = model
self.tokenizer = tokenizer
if encoder_d: self.encoder_d = encoder_d
else: self.encoder_d = self.encoder if config.encoder_tied else deepcopy(self.encoder)
if model_d: self.model_d = model_d
else: self.model_d = self.model if config.model_tied else deepcopy(self.model)
self.pooling = {
'mean': self._mean,
'cls' : self._cls,
@@ -158,20 +158,20 @@ def _mean(self, x : torch.Tensor) -> torch.Tensor:
return self.pooler(x.mean(dim=1))

def _encode_d(self, **text):
return self.pooling(self.encoder_d(**text).last_hidden_state)
return self.pooling(self.model_d(**text).last_hidden_state)

def _encode_q(self, **text):
return self.pooling(self.encoder(**text).last_hidden_state)
return self.pooling(self.model(**text).last_hidden_state)

def forward(self,
loss = None,
queries = None,
docs_batch = None,
labels=None):
"""Compute the loss given (queries, docs, labels)"""
queries = {k: v.to(self.encoder.device) for k, v in queries.items()} if queries is not None else None
docs_batch = {k: v.to(self.encoder_d.device) for k, v in docs_batch.items()} if docs_batch is not None else None
labels = labels.to(self.encoder_d.device) if labels is not None else None
queries = {k: v.to(self.model.device) for k, v in queries.items()} if queries is not None else None
docs_batch = {k: v.to(self.model_d.device) for k, v in docs_batch.items()} if docs_batch is not None else None
labels = labels.to(self.model_d.device) if labels is not None else None

query_reps = self._encode_q(**queries) if queries is not None else None
docs_batch_reps = self._encode_d(**docs_batch) if docs_batch is not None else None
@@ -184,36 +184,36 @@ def forward(self,
return (loss_value, pred)

def save_pretrained(self, model_dir, **kwargs):
"""Save both query and document encoder"""
"""Save both query and document model"""
self.config.save_pretrained(model_dir)
self.encoder.save_pretrained(model_dir)
if not self.config.encoder_tied: self.encoder_d.save_pretrained(model_dir + "/encoder_d")
self.model.save_pretrained(model_dir)
if not self.config.model_tied: self.model_d.save_pretrained(model_dir + "/model_d")
if self.config.use_pooler: self.pooler.save_pretrained(model_dir + "/pooler")
self.tokenizer.save_pretrained(model_dir)


def load_state_dict(self, model_dir):
"""Load state dict from a directory"""
self.config = DotConfig.from_pretrained(model_dir)
self.encoder.load_state_dict(AutoModel.from_pretrained(model_dir).state_dict())
if not self.config.encoder_tied: self.encoder_d.load_state_dict(AutoModel.from_pretrained(model_dir + "/encoder_d").state_dict())
self.model.load_state_dict(AutoModel.from_pretrained(model_dir).state_dict())
if not self.config.model_tied: self.model_d.load_state_dict(AutoModel.from_pretrained(model_dir + "/model_d").state_dict())
if self.config.use_pooler: self.pooler.load_state_dict(AutoModel.from_pretrained(model_dir + "/pooler").state_dict())

@classmethod
def from_pretrained(cls, model_dir_or_name, **kwargs):
"""Load encoder"""
"""Load model"""
if os.path.isdir(model_dir_or_name):
config = DotConfig.from_pretrained(model_dir_or_name, **kwargs)
encoder = AutoModel.from_pretrained(model_dir_or_name)
model = AutoModel.from_pretrained(model_dir_or_name)
tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name)
encoder_d = None if config.encoder_tied else AutoModel.from_pretrained(model_dir_or_name + "/encoder_d")
model_d = None if config.model_tied else AutoModel.from_pretrained(model_dir_or_name + "/model_d")
pooler = None if not config.use_pooler else Pooler.from_pretrained(model_dir_or_name + "/pooler")

return cls(encoder, tokenizer, config, encoder_d, pooler)
return cls(model, tokenizer, config, model_d, pooler)
config = DotConfig(model_dir_or_name, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name)
encoder = AutoModel.from_pretrained(model_dir_or_name)
return cls(encoder, tokenizer, config)
model = AutoModel.from_pretrained(model_dir_or_name)
return cls(model, tokenizer, config)

def to_pyterrier(self) -> "DotTransformer":
return DotTransformer.from_model(self, self.tokenizer, text_field='text')
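A minimal sketch of the save/load round trip with the renamed directories (model_d instead of encoder_d); the checkpoint and output paths are placeholders and the import path is an assumption, while the method and subdirectory names come from the diff above.

from rankers.modelling.dot import Dot  # assumed module path

dot = Dot.from_pretrained("bert-base-uncased", model_tied=False)
dot.save_pretrained("checkpoints/dot")       # untied document weights go to checkpoints/dot/model_d
reloaded = Dot.from_pretrained("checkpoints/dot")
retriever = reloaded.to_pyterrier()          # DotTransformer over the 'text' field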
@@ -254,9 +254,9 @@ def from_pretrained(cls,
config = DotConfig.from_pretrained(model_name_or_path)
config.mode = pooling
pooler = None if not config.use_pooler else Pooler.from_pretrained(model_name_or_path+"/pooler")
encoder_d = None if config.encoder_tied else AutoModel.from_pretrained(model_name_or_path + "/encoder_d")
encoder_q = AutoModel.from_pretrained(model_name_or_path)
model = Dot(encoder_q, config, encoder_d, pooler)
model_d = None if config.model_tied else AutoModel.from_pretrained(model_name_or_path + "/model_d")
model_q = AutoModel.from_pretrained(model_name_or_path)
model = Dot(model_q, config, model_d, pooler)
return cls(model, AutoTokenizer.from_pretrained(model_name_or_path), config, batch_size, text_field, device, verbose)

@classmethod
@@ -301,8 +301,8 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
(['qid', 'query_vec', self.text_field], self.scorer),
(['qid', 'query', 'doc_vec'], self.scorer),
(['qid', 'query_vec', 'doc_vec'], self.scorer),
(['query'], self.query_encoder),
([self.text_field], self.doc_encoder),
(['query'], self.query_model),
([self.text_field], self.doc_model),
]
for fields, fn in modes:
if all(f in columns for f in fields):
@@ -312,32 +312,32 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
message += f'\n - {fn.__doc__.strip()}: {fields}'
raise RuntimeError(message)

def query_encoder(self, verbose=None, batch_size=None) -> pt.Transformer:
def query_model(self, verbose=None, batch_size=None) -> pt.Transformer:
"""
Query encoding
"""
return BiQueryEncoder(self, verbose=verbose, batch_size=batch_size)
return BiQuerymodel(self, verbose=verbose, batch_size=batch_size)

def doc_encoder(self, verbose=None, batch_size=None) -> pt.Transformer:
def doc_model(self, verbose=None, batch_size=None) -> pt.Transformer:
"""
Doc encoding
"""
return BiDocEncoder(self, verbose=verbose, batch_size=batch_size)
return BiDocmodel(self, verbose=verbose, batch_size=batch_size)

def scorer(self, verbose=None, batch_size=None) -> pt.Transformer:
"""
Scoring (re-ranking)
"""
return BiScorer(self, verbose=verbose, batch_size=batch_size)

class BiQueryEncoder(pt.Transformer):
def __init__(self, bi_encoder_model: DotTransformer, verbose=None, batch_size=None):
self.bi_encoder_model = bi_encoder_model
self.verbose = verbose if verbose is not None else bi_encoder_model.verbose
self.batch_size = batch_size if batch_size is not None else bi_encoder_model.batch_size
class BiQuerymodel(pt.Transformer):
def __init__(self, bi_model_model: DotTransformer, verbose=None, batch_size=None):
self.bi_model_model = bi_model_model
self.verbose = verbose if verbose is not None else bi_model_model.verbose
self.batch_size = batch_size if batch_size is not None else bi_model_model.batch_size

def encode(self, texts, batch_size=None) -> np.array:
return self.bi_encoder_model.encode_queries(texts, batch_size=batch_size or self.batch_size)
return self.bi_model_model.encode_queries(texts, batch_size=batch_size or self.batch_size)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert all(c in inp.columns for c in ['query'])
@@ -349,17 +349,17 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
return inp.assign(query_vec=[enc[i] for i in inv])

def __repr__(self):
return f'{repr(self.bi_encoder_model)}.query_encoder()'
return f'{repr(self.bi_model_model)}.query_model()'

class BiDocEncoder(pt.Transformer):
def __init__(self, bi_encoder_model: DotTransformer, verbose=None, batch_size=None, text_field=None):
self.bi_encoder_model = bi_encoder_model
self.verbose = verbose if verbose is not None else bi_encoder_model.verbose
self.batch_size = batch_size if batch_size is not None else bi_encoder_model.batch_size
self.text_field = text_field if text_field is not None else bi_encoder_model.text_field
class BiDocmodel(pt.Transformer):
def __init__(self, bi_model_model: DotTransformer, verbose=None, batch_size=None, text_field=None):
self.bi_model_model = bi_model_model
self.verbose = verbose if verbose is not None else bi_model_model.verbose
self.batch_size = batch_size if batch_size is not None else bi_model_model.batch_size
self.text_field = text_field if text_field is not None else bi_model_model.text_field

def encode(self, texts, batch_size=None) -> np.array:
return self.bi_encoder_model.encode_docs(texts, batch_size=batch_size or self.batch_size)
return self.bi_model_model.encode_docs(texts, batch_size=batch_size or self.batch_size)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert all(c in inp.columns for c in [self.text_field])
@@ -369,29 +369,29 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
return inp.assign(doc_vec=list(self.encode(it)))

def __repr__(self):
return f'{repr(self.bi_encoder_model)}.doc_encoder()'
return f'{repr(self.bi_model_model)}.doc_model()'

class BiScorer(pt.Transformer):
def __init__(self, bi_encoder_model: DotTransformer, verbose=None, batch_size=None, text_field=None):
self.bi_encoder_model = bi_encoder_model
self.verbose = verbose if verbose is not None else bi_encoder_model.verbose
self.batch_size = batch_size if batch_size is not None else bi_encoder_model.batch_size
self.text_field = text_field if text_field is not None else bi_encoder_model.text_field
def __init__(self, bi_model_model: DotTransformer, verbose=None, batch_size=None, text_field=None):
self.bi_model_model = bi_model_model
self.verbose = verbose if verbose is not None else bi_model_model.verbose
self.batch_size = batch_size if batch_size is not None else bi_model_model.batch_size
self.text_field = text_field if text_field is not None else bi_model_model.text_field

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
assert 'query_vec' in inp.columns or 'query' in inp.columns
assert 'doc_vec' in inp.columns or self.text_field in inp.columns
if 'query_vec' in inp.columns:
query_vec = inp['query_vec']
else:
query_vec = self.bi_encoder_model.query_encoder(batch_size=self.batch_size, verbose=self.verbose)(inp)['query_vec']
query_vec = self.bi_model_model.query_model(batch_size=self.batch_size, verbose=self.verbose)(inp)['query_vec']
if 'doc_vec' in inp.columns:
doc_vec = inp['doc_vec']
else:
doc_vec = self.bi_encoder_model.doc_encoder(batch_size=self.batch_size, verbose=self.verbose)(inp)['doc_vec']
doc_vec = self.bi_model_model.doc_model(batch_size=self.batch_size, verbose=self.verbose)(inp)['doc_vec']
scores = (query_vec * doc_vec).apply(np.sum)
outp = inp.assign(score=scores)
return pt.model.add_ranks(outp)

def __repr__(self):
return f'{repr(self.bi_encoder_model)}.scorer()'
return f'{repr(self.bi_model_model)}.scorer()'
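Finally, a sketch of the renamed DotTransformer accessors (query_model / doc_model, formerly query_encoder / doc_encoder) and the BiScorer path; the DataFrame contents are invented and the import path is an assumption, only the method and column names come from the diff above.

import pandas as pd
from rankers.modelling.dot import Dot  # assumed module path

transformer = Dot.from_pretrained("bert-base-uncased").to_pyterrier()

queries = pd.DataFrame([{"qid": "q1", "query": "what is dense retrieval"}])
encoded = transformer.query_model()(queries)   # adds a 'query_vec' column (was query_encoder)

docs = queries.assign(docno="d1", text="dense retrieval matches query and document vectors")
reranked = transformer.scorer()(docs)          # BiScorer adds 'score' and ranks the rows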
