
Commit

strange
Parry-Parry committed Nov 6, 2024
1 parent f456eee · commit c7dbfe4
Showing 3 changed files with 17 additions and 17 deletions.
rankers/modelling/cat.py (8 changes: 4 additions & 4 deletions)
@@ -190,11 +190,11 @@ def to_pyterrier(self) -> "pt.Transformer":
         return self.transformer_class.from_model(self.model, self.tokenizer, text_field='text')

     @classmethod
-    def from_pretrained(cls, model_dir_or_name : str, num_labels=2, config=None, **kwargs) -> "Cat":
+    def from_pretrained(cls, model_name_or_path : str, num_labels=2, config=None, **kwargs) -> "Cat":
         """Load model from a directory"""
-        config = cls.config_class.from_pretrained(model_dir_or_name, num_labels=num_labels) if config is None else config
-        model = cls.architecture_class.from_pretrained(model_dir_or_name, config=config, **kwargs)
-        tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name)
+        config = cls.config_class.from_pretrained(model_name_or_path, num_labels=num_labels) if config is None else config
+        model = cls.architecture_class.from_pretrained(model_name_or_path, config=config, **kwargs)
+        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         return cls(model, tokenizer, config)

 AutoConfig.register("Cat", CatConfig)
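For context, a minimal usage sketch of the renamed argument. The import path is inferred from the file location, and the model names are placeholders, not examples from the repository:

# Sketch only: `Cat.from_pretrained` now takes `model_name_or_path`, so a hub
# name and a local checkpoint directory both read naturally.
from rankers.modelling.cat import Cat  # import path assumed from the file location

ranker = Cat.from_pretrained("bert-base-uncased", num_labels=2)      # hub model name
ranker = Cat.from_pretrained("checkpoints/my-cat-model", num_labels=2)  # local directory
transformer = ranker.to_pyterrier()  # wraps the ranker as a PyTerrier transformer, per the context line above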
rankers/modelling/dot.py (20 changes: 10 additions & 10 deletions)
@@ -400,19 +400,19 @@ def load_state_dict(self, model_dir):
         if self.config.use_pooler: self.pooler.load_state_dict(self.architecture_class.from_pretrained(model_dir + "/pooler").state_dict())

     @classmethod
-    def from_pretrained(cls, model_dir_or_name, config = None, **kwargs) -> "Dot":
+    def from_pretrained(cls, model_name_or_path, config = None, **kwargs) -> "Dot":
         """Load model"""
-        if os.path.isdir(model_dir_or_name):
-            config = cls.config_class.from_pretrained(model_dir_or_name) if config is None else config
-            model = cls.architecture_class.from_pretrained(model_dir_or_name, **kwargs)
-            tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name)
-            model_d = None if config.model_tied else cls.architecture_class.from_pretrained(model_dir_or_name + "/model_d", **kwargs)
-            pooler = None if not config.use_pooler else Pooler.from_pretrained(model_dir_or_name + "/pooler")
+        if os.path.isdir(model_name_or_path):
+            config = cls.config_class.from_pretrained(model_name_or_path) if config is None else config
+            model = cls.architecture_class.from_pretrained(model_name_or_path, **kwargs)
+            tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+            model_d = None if config.model_tied else cls.architecture_class.from_pretrained(model_name_or_path + "/model_d", **kwargs)
+            pooler = None if not config.use_pooler else Pooler.from_pretrained(model_name_or_path + "/pooler")

             return cls(model, tokenizer, config, model_d, pooler)
-        config = cls.config_class(model_dir_or_name, **kwargs) if config is None else config
-        tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name)
-        model = cls.architecture_class.from_pretrained(model_dir_or_name)
+        config = cls.config_class(model_name_or_path, **kwargs) if config is None else config
+        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        model = cls.architecture_class.from_pretrained(model_name_or_path)
         return cls(model, tokenizer, config)

     def to_pyterrier(self) -> "DotTransformer":
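The rename also makes the two load paths in `Dot.from_pretrained` read consistently: a local directory restores the saved config plus an optional untied document encoder and pooler, while any other string is treated as a model name and loaded fresh. A hedged sketch, with the import path and model names assumed:

from rankers.modelling.dot import Dot  # import path assumed from the file location

# Hub name: a fresh config is built and a single (tied) encoder is loaded.
bi_encoder = Dot.from_pretrained("bert-base-uncased")

# Local directory: the saved config decides whether an untied document encoder
# ("<dir>/model_d") and a pooler ("<dir>/pooler") are loaded alongside the query encoder.
bi_encoder = Dot.from_pretrained("checkpoints/my-dot-model")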
rankers/modelling/seq2seq.py (6 changes: 3 additions & 3 deletions)
@@ -176,12 +176,12 @@ def to_pyterrier(self) -> "Seq2SeqTransformer":

     @classmethod
     def from_pretrained(cls,
-                        model_dir_or_name : str,
+                        model_name_or_path : str,
                         config : PreTrainedConfig = None,
                         **kwargs):
         """Load model from a directory"""
-        config = cls.config_class.from_pretrained(model_dir_or_name)
-        model = cls.architecture_class.from_pretrained(model_dir_or_name, **kwargs)
+        config = cls.config_class.from_pretrained(model_name_or_path)
+        model = cls.architecture_class.from_pretrained(model_name_or_path, **kwargs)
         return cls(model, config)

 class CausalLMConfig(PreTrainedConfig):
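The same rename applies to the seq2seq ranker. A sketch, assuming the class in this module is named `Seq2Seq` (the diff shows only its `from_pretrained` and `to_pyterrier`, not the class header) and using a placeholder model name:

from rankers.modelling.seq2seq import Seq2Seq  # class name assumed; not shown in the diff

reranker = Seq2Seq.from_pretrained("t5-base")  # hub name or local directory
transformer = reranker.to_pyterrier()  # returns a "Seq2SeqTransformer", per the hunk header above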
