Commit b7987b4
all fixed
Parry-Parry committed Nov 5, 2024
1 parent a9584d7 commit b7987b4
Showing 2 changed files with 14 additions and 11 deletions.
22 changes: 12 additions & 10 deletions rankers/modelling/cat.py
@@ -1,7 +1,7 @@
 import pyterrier as pt
 if not pt.started():
     pt.init()
-from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
+from transformers import PreTrainedModel, PretrainedConfig, PreTrainedTokenizer, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
 from typing import Union
 import torch
 import pandas as pd
@@ -14,7 +14,7 @@ class CatTransformer(pt.Transformer):
     def __init__(self,
                  model : PreTrainedModel,
                  tokenizer : PreTrainedTokenizer,
-                 config : AutoConfig,
+                 config : PretrainedConfig,
                  batch_size : int,
                  text_field : str = 'text',
                  device : Union[str, torch.device] = None,
@@ -35,13 +35,14 @@ def from_pretrained(cls,
                         model_name_or_path : str,
                         batch_size : int = 64,
                         text_field : str = 'text',
+                        config : PretrainedConfig = None,
                         device : Union[str, torch.device] = None,
                         verbose : bool = False,
                         **kwargs
                         ):
-        model = cls.cls_architecture.from_pretrained(model_name_or_path, **kwargs)
+        config = AutoConfig.from_pretrained(model_name_or_path) if config is None else config
+        model = cls.cls_architecture.from_pretrained(model_name_or_path, config=config, **kwargs)
         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-        config = AutoConfig.from_pretrained(model_name_or_path)
         return cls(model, tokenizer, config, batch_size, text_field, device, verbose)

     @classmethod
@@ -75,7 +76,7 @@ class PairTransformer(pt.Transformer):
     def __init__(self,
                  model : PreTrainedModel,
                  tokenizer : PreTrainedTokenizer,
-                 config : AutoConfig,
+                 config : PretrainedConfig,
                  batch_size : int,
                  text_field : str = 'text',
                  device : Union[str, torch.device] = None,
@@ -106,13 +107,14 @@ def from_pretrained(cls,
                         model_name_or_path : str,
                         batch_size : int = 64,
                         text_field : str = 'text',
+                        config : PretrainedConfig = None,
                         device : Union[str, torch.device] = None,
                         verbose : bool = False,
                         **kwargs
                         ):
-        model = cls.cls_architecture.from_pretrained(model_name_or_path, **kwargs).cuda().eval()
+        config = AutoConfig.from_pretrained(model_name_or_path) if config is None else config
+        model = cls.cls_architecture.from_pretrained(model_name_or_path, config=config, **kwargs).cuda().eval()
         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-        config = AutoConfig.from_pretrained(model_name_or_path)
         return cls(model, tokenizer, config, batch_size, text_field, device, verbose)

     def transform(self, inp : pd.DataFrame) -> pd.DataFrame:
@@ -181,9 +183,9 @@ def to_pyterrier(self) -> "pt.Transformer":
         return self.transformer_architecture.from_model(self.model, self.tokenizer, text_field='text')

     @classmethod
-    def from_pretrained(cls, model_dir_or_name : str, num_labels=2, **kwargs) -> "Cat":
+    def from_pretrained(cls, model_dir_or_name : str, num_labels=2, config=None, **kwargs) -> "Cat":
         """Load model from a directory"""
-        config = AutoConfig.from_pretrained(model_dir_or_name)
-        model = cls.cls_architecture.from_pretrained(model_dir_or_name, num_labels=num_labels, **kwargs)
+        config = AutoConfig.from_pretrained(model_dir_or_name) if config is None else config
+        model = cls.cls_architecture.from_pretrained(model_dir_or_name, num_labels=num_labels, config=config, **kwargs)
         tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name)
         return cls(model, tokenizer, config)
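The net effect in cat.py: each from_pretrained now accepts an optional config, falls back to AutoConfig.from_pretrained only when none is supplied, and forwards the config to the underlying model instead of loading it a second time. A minimal usage sketch, not part of the commit; the import path follows the file layout above and the checkpoint name is only an example:

    from transformers import AutoConfig
    from rankers.modelling.cat import CatTransformer

    # Previous behaviour still works: the config is read from the checkpoint.
    ranker = CatTransformer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

    # New behaviour: build the config up front and pass it through.
    config = AutoConfig.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2", num_labels=2)
    ranker = CatTransformer.from_pretrained(
        "cross-encoder/ms-marco-MiniLM-L-6-v2",
        config=config,      # skips the internal AutoConfig.from_pretrained call
        batch_size=32,
    )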
3 changes: 2 additions & 1 deletion rankers/modelling/dot.py
@@ -105,12 +105,13 @@ def from_pretrained(cls,
                        model_name_or_path : str,
                        batch_size : int = 64,
                        pooling : str = 'cls',
+                       config : PretrainedConfig = None,
                        text_field : str = 'text',
                        device : Union[str, torch.device] = None,
                        verbose : bool = False,
                        **kwargs
                        ):
-        config = DotConfig.from_pretrained(model_name_or_path)
+        config = DotConfig.from_pretrained(model_name_or_path) if config is None else config
         config.mode = pooling
         pooler = None if not config.use_pooler else Pooler.from_pretrained(model_name_or_path+"/pooler")
         model_d = None if config.model_tied else cls.cls_architecture.from_pretrained(model_name_or_path + "/model_d", **kwargs)
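dot.py follows the same pattern: DotConfig.from_pretrained runs only when the caller passes no config, and config.mode is still overwritten by the pooling argument afterwards. A sketch under the same assumptions; the class owning this from_pretrained is not visible in the excerpt, so Dot below is a placeholder name, and the checkpoint and pooling mode are again examples:

    from rankers.modelling.dot import Dot, DotConfig   # `Dot` is a placeholder name

    # Build and adjust the config once, then hand it to from_pretrained;
    # after this commit the passed-in config is used instead of being re-loaded.
    config = DotConfig.from_pretrained("bert-base-uncased")
    model = Dot.from_pretrained(
        "bert-base-uncased",
        config=config,
        pooling="mean",     # still applied: from_pretrained sets config.mode = pooling
        batch_size=32,
    )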
