Commit ee97e27

minor changes for convenience
Parry-Parry committed Nov 5, 2024
1 parent 20a3a22 commit ee97e27
Showing 3 changed files with 27 additions and 15 deletions.
9 changes: 6 additions & 3 deletions rankers/modelling/cat.py
@@ -11,6 +11,7 @@
 
 class CatTransformer(pt.Transformer):
     cls_architecture = AutoModelForSequenceClassification
+    cls_config = AutoConfig
     def __init__(self,
                  model : PreTrainedModel,
                  tokenizer : PreTrainedTokenizer,
@@ -40,7 +41,7 @@ def from_pretrained(cls,
                         verbose : bool = False,
                         **kwargs
                         ):
-        config = AutoConfig.from_pretrained(model_name_or_path) if config is None else config
+        config = cls.cls_config.from_pretrained(model_name_or_path) if config is None else config
         model = cls.cls_architecture.from_pretrained(model_name_or_path, config=config, **kwargs)
         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         return cls(model, tokenizer, config, batch_size, text_field, device, verbose)
@@ -73,6 +74,7 @@ def transform(self, inp : pd.DataFrame) -> pd.DataFrame:
 
 class PairTransformer(pt.Transformer):
     cls_architecture = AutoModelForSequenceClassification
+    cls_config = AutoConfig
     def __init__(self,
                  model : PreTrainedModel,
                  tokenizer : PreTrainedTokenizer,
@@ -112,7 +114,7 @@ def from_pretrained(cls,
                         verbose : bool = False,
                         **kwargs
                         ):
-        config = AutoConfig.from_pretrained(model_name_or_path) if config is None else config
+        config = cls.cls_config.from_pretrained(model_name_or_path) if config is None else config
         model = cls.cls_architecture.from_pretrained(model_name_or_path, config=config, **kwargs).cuda().eval()
         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         return cls(model, tokenizer, config, batch_size, text_field, device, verbose)
@@ -145,6 +147,7 @@ class Cat(PreTrainedModel):
     """
     model_architecture = 'Cat'
     cls_architecture = AutoModelForSequenceClassification
+    cls_config = AutoConfig
     transformer_architecture = CatTransformer
     def __init__(
         self,
@@ -185,7 +188,7 @@ def to_pyterrier(self) -> "pt.Transformer":
     @classmethod
     def from_pretrained(cls, model_dir_or_name : str, num_labels=2, config=None, **kwargs) -> "Cat":
         """Load model from a directory"""
-        config = AutoConfig.from_pretrained(model_dir_or_name) if config is None else config
+        config = cls.cls_config.from_pretrained(model_dir_or_name) if config is None else config
         model = cls.cls_architecture.from_pretrained(model_dir_or_name, num_labels=num_labels, config=config, **kwargs)
         tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name)
         return cls(model, tokenizer, config)
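
The change above is the same in each class: the hard-coded AutoConfig call becomes a class attribute, so a subclass only has to override cls_config (and cls_architecture) for from_pretrained to resolve its config and model differently. A minimal sketch of that pattern, assuming CatTransformer is importable from rankers.modelling.cat, that "bert-base-uncased" is just a stand-in checkpoint, and that the remaining from_pretrained arguments have usable defaults:

from transformers import BertConfig, BertForSequenceClassification
from rankers.modelling.cat import CatTransformer

class BertCatTransformer(CatTransformer):
    # cls_config is the hook added in this commit; cls_architecture already
    # existed. from_pretrained builds the config via cls.cls_config and the
    # model via cls.cls_architecture, so no method needs re-implementing.
    cls_config = BertConfig
    cls_architecture = BertForSequenceClassification

# Hypothetical checkpoint, for illustration only.
ranker = BertCatTransformer.from_pretrained("bert-base-uncased")
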
8 changes: 5 additions & 3 deletions rankers/modelling/dot.py
@@ -77,6 +77,7 @@ def from_pretrained(cls,
 
 class DotTransformer(pt.Transformer):
     cls_architecture = AutoModel
+    cls_config = DotConfig
     def __init__(self,
                  model : PreTrainedModel,
                  tokenizer : PreTrainedTokenizer,
@@ -111,7 +112,7 @@ def from_pretrained(cls,
                         verbose : bool = False,
                         **kwargs
                         ):
-        config = DotConfig.from_pretrained(model_name_or_path) if config is None else config
+        config = cls.cls_config.from_pretrained(model_name_or_path) if config is None else config
         config.mode = pooling
         pooler = None if not config.use_pooler else Pooler.from_pretrained(model_name_or_path+"/pooler")
         model_d = None if config.model_tied else cls.cls_architecture.from_pretrained(model_name_or_path + "/model_d", **kwargs)
@@ -288,6 +289,7 @@ class Dot(PreTrainedModel):
     """
     model_architecture = 'Dot'
     cls_architecture = AutoModel
+    cls_config = DotConfig
     transformer_architecture = DotTransformer
     def __init__(
         self,
@@ -386,14 +388,14 @@ def load_state_dict(self, model_dir):
     def from_pretrained(cls, model_dir_or_name, config = None, **kwargs) -> "Dot":
         """Load model"""
         if os.path.isdir(model_dir_or_name):
-            config = DotConfig.from_pretrained(model_dir_or_name) if config is None else config
+            config = cls.cls_config.from_pretrained(model_dir_or_name) if config is None else config
             model = cls.cls_architecture.from_pretrained(model_dir_or_name, **kwargs)
             tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name)
             model_d = None if config.model_tied else cls.cls_architecture.from_pretrained(model_dir_or_name + "/model_d", **kwargs)
             pooler = None if not config.use_pooler else Pooler.from_pretrained(model_dir_or_name + "/pooler")
 
             return cls(model, tokenizer, config, model_d, pooler)
-        config = DotConfig(model_dir_or_name, **kwargs) if config is None else config
+        config = cls.cls_config(model_dir_or_name, **kwargs) if config is None else config
         tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name)
         model = cls.cls_architecture.from_pretrained(model_dir_or_name)
         return cls(model, tokenizer, config)
25 changes: 16 additions & 9 deletions rankers/modelling/seq2seq.py
@@ -1,7 +1,7 @@
 import pyterrier as pt
 if not pt.started():
     pt.init()
-from transformers import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM
+from transformers import PreTrainedModel, PretrainedConfig, PreTrainedTokenizer, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM
 from typing import Union
 import torch
 import pandas as pd
@@ -17,7 +17,7 @@ class Seq2SeqTransformer(pt.Transformer):
     def __init__(self,
                  model : PreTrainedModel,
                  tokenizer : PreTrainedTokenizer,
-                 config : AutoConfig,
+                 config : PretrainedConfig,
                  batch_size : int,
                  text_field : str = 'text',
                  device : Union[str, torch.device] = None,
@@ -85,7 +85,7 @@ class Seq2SeqDuoTransformer(Seq2SeqTransformer):
     def __init__(self,
                  model : PreTrainedModel,
                  tokenizer : PreTrainedTokenizer,
-                 config : AutoConfig,
+                 config : PretrainedConfig,
                  batch_size : int,
                  text_field : str = 'text',
                  device : Union[str, torch.device] = None,
@@ -127,11 +127,13 @@ class Seq2Seq(PreTrainedModel):
     """
     model_architecture = 'Seq2Seq'
     cls_architecture = AutoModelForSeq2SeqLM
+    cls_config = AutoConfig
+    transformer_architecture = Seq2SeqTransformer
     def __init__(
         self,
         model: AutoModelForSeq2SeqLM,
         tokenizer: PreTrainedTokenizer,
-        config: AutoConfig,
+        config: PretrainedConfig,
     ):
         super().__init__(config)
         self.model = model
@@ -161,21 +163,25 @@ def load_state_dict(self, model_dir):
         return self.model.load_state_dict(AutoModelForSeq2SeqLM.from_pretrained(model_dir).state_dict())
 
     def to_pyterrier(self) -> "Seq2SeqTransformer":
-        return Seq2SeqTransformer.from_model(self.model, self.tokenizer, text_field='text')
+        return self.transformer_architecture.from_model(self.model, self.tokenizer, text_field='text')
 
     @classmethod
-    def from_pretrained(cls, model_dir_or_name : str, **kwargs):
+    def from_pretrained(cls,
+                        model_dir_or_name : str,
+                        config : PretrainedConfig = None,
+                        **kwargs):
         """Load model from a directory"""
-        config = AutoConfig.from_pretrained(model_dir_or_name)
+        config = cls.cls_config.from_pretrained(model_dir_or_name)
         model = cls.cls_architecture.from_pretrained(model_dir_or_name, **kwargs)
         return cls(model, config)
 
 class CausalLMTransformer(Seq2SeqTransformer):
     cls_architecture = AutoModelForCausalLM
+    cls_config = AutoConfig
     def __init__(self,
                  model : PreTrainedModel,
                  tokenizer : PreTrainedTokenizer,
-                 config : AutoConfig,
+                 config : PretrainedConfig,
                  batch_size : int,
                  text_field : str = 'text',
                  device : Union[str, torch.device] = None,
@@ -190,14 +196,15 @@ def from_pretrained(cls,
                         model_name_or_path : str,
                         batch_size : int = 64,
                         text_field : str = 'text',
+                        config : PretrainedConfig = None,
                         device : Union[str, torch.device] = None,
                         prompt : str = None,
                         verbose : bool = False,
                         **kwargs
                         ):
+        config = cls.cls_config.from_pretrained(model_name_or_path) if config is None else config
         model = cls.cls_architecture.from_pretrained(model_name_or_path, **kwargs).to(device).eval()
         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-        config = AutoConfig.from_pretrained(model_name_or_path)
         return cls(model, tokenizer, config, batch_size, text_field, device, prompt, verbose=verbose)
 
     @classmethod
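
Beyond the cls_config hook, CausalLMTransformer.from_pretrained now accepts an optional config that is used instead of re-loading one from the checkpoint, and Seq2Seq.to_pyterrier resolves its wrapper through transformer_architecture. A hedged usage sketch for the new config argument, assuming CausalLMTransformer is importable from rankers.modelling.seq2seq and using "gpt2" purely as a stand-in checkpoint:

from transformers import AutoConfig
from rankers.modelling.seq2seq import CausalLMTransformer

# Build (and possibly tweak) a config up front; as of this commit it is passed
# through rather than re-loaded inside from_pretrained.
config = AutoConfig.from_pretrained("gpt2")
ranker = CausalLMTransformer.from_pretrained("gpt2", config=config, device="cpu")
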
