From c7dbfe4dd3b423682ab249d32114900c7e871fe2 Mon Sep 17 00:00:00 2001 From: Andrew Parry Date: Wed, 6 Nov 2024 14:59:35 +0000 Subject: [PATCH] strange --- rankers/modelling/cat.py | 8 ++++---- rankers/modelling/dot.py | 20 ++++++++++---------- rankers/modelling/seq2seq.py | 6 +++--- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/rankers/modelling/cat.py b/rankers/modelling/cat.py index bbc04c7..a3c19b7 100644 --- a/rankers/modelling/cat.py +++ b/rankers/modelling/cat.py @@ -190,11 +190,11 @@ def to_pyterrier(self) -> "pt.Transformer": return self.transformer_class.from_model(self.model, self.tokenizer, text_field='text') @classmethod - def from_pretrained(cls, model_dir_or_name : str, num_labels=2, config=None, **kwargs) -> "Cat": + def from_pretrained(cls, model_name_or_path : str, num_labels=2, config=None, **kwargs) -> "Cat": """Load model from a directory""" - config = cls.config_class.from_pretrained(model_dir_or_name, num_labels=num_labels) if config is None else config - model = cls.architecture_class.from_pretrained(model_dir_or_name, config=config, **kwargs) - tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name) + config = cls.config_class.from_pretrained(model_name_or_path, num_labels=num_labels) if config is None else config + model = cls.architecture_class.from_pretrained(model_name_or_path, config=config, **kwargs) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) return cls(model, tokenizer, config) AutoConfig.register("Cat", CatConfig) diff --git a/rankers/modelling/dot.py b/rankers/modelling/dot.py index 58f5c32..00c99e3 100644 --- a/rankers/modelling/dot.py +++ b/rankers/modelling/dot.py @@ -400,19 +400,19 @@ def load_state_dict(self, model_dir): if self.config.use_pooler: self.pooler.load_state_dict(self.architecture_class.from_pretrained(model_dir + "/pooler").state_dict()) @classmethod - def from_pretrained(cls, model_dir_or_name, config = None, **kwargs) -> "Dot": + def from_pretrained(cls, model_name_or_path, config = None, **kwargs) -> "Dot": """Load model""" - if os.path.isdir(model_dir_or_name): - config = cls.config_class.from_pretrained(model_dir_or_name) if config is None else config - model = cls.architecture_class.from_pretrained(model_dir_or_name, **kwargs) - tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name) - model_d = None if config.model_tied else cls.architecture_class.from_pretrained(model_dir_or_name + "/model_d", **kwargs) - pooler = None if not config.use_pooler else Pooler.from_pretrained(model_dir_or_name + "/pooler") + if os.path.isdir(model_name_or_path): + config = cls.config_class.from_pretrained(model_name_or_path) if config is None else config + model = cls.architecture_class.from_pretrained(model_name_or_path, **kwargs) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + model_d = None if config.model_tied else cls.architecture_class.from_pretrained(model_name_or_path + "/model_d", **kwargs) + pooler = None if not config.use_pooler else Pooler.from_pretrained(model_name_or_path + "/pooler") return cls(model, tokenizer, config, model_d, pooler) - config = cls.config_class(model_dir_or_name, **kwargs) if config is None else config - tokenizer = AutoTokenizer.from_pretrained(model_dir_or_name) - model = cls.architecture_class.from_pretrained(model_dir_or_name) + config = cls.config_class(model_name_or_path, **kwargs) if config is None else config + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + model = cls.architecture_class.from_pretrained(model_name_or_path) return cls(model, tokenizer, config) def to_pyterrier(self) -> "DotTransformer": diff --git a/rankers/modelling/seq2seq.py b/rankers/modelling/seq2seq.py index c9e0ca5..8dbbf5e 100644 --- a/rankers/modelling/seq2seq.py +++ b/rankers/modelling/seq2seq.py @@ -176,12 +176,12 @@ def to_pyterrier(self) -> "Seq2SeqTransformer": @classmethod def from_pretrained(cls, - model_dir_or_name : str, + model_name_or_path : str, config : PreTrainedConfig = None, **kwargs): """Load model from a directory""" - config = cls.config_class.from_pretrained(model_dir_or_name) - model = cls.architecture_class.from_pretrained(model_dir_or_name, **kwargs) + config = cls.config_class.from_pretrained(model_name_or_path) + model = cls.architecture_class.from_pretrained(model_name_or_path, **kwargs) return cls(model, config) class CausalLMConfig(PreTrainedConfig):