From c972332b416424cce34ea12548f4457008c0a235 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Wed, 23 Oct 2024 23:48:01 -0700
Subject: [PATCH] For the simple use cases where no Enum or unexpected config
 objects are in the save files, use weights_only=True. Significantly cuts
 down on the number of torch warnings.

https://github.com/stanfordnlp/stanza/issues/1429
---
 stanza/models/charlm.py                    | 2 +-
 stanza/models/classifiers/trainer.py       | 2 ++
 stanza/models/common/char_model.py         | 4 ++--
 stanza/models/common/pretrain.py           | 1 +
 stanza/models/common/trainer.py            | 2 +-
 stanza/models/constituency/base_trainer.py | 3 +++
 stanza/models/coref/model.py               | 1 +
 stanza/models/depparse/trainer.py          | 2 +-
 stanza/models/langid/model.py              | 2 +-
 stanza/models/lemma/trainer.py             | 2 +-
 stanza/models/mwt/trainer.py               | 2 +-
 stanza/models/ner/data.py                  | 2 +-
 stanza/models/ner/trainer.py               | 2 +-
 stanza/models/pos/trainer.py               | 2 +-
 stanza/models/tokenization/trainer.py      | 1 +
 stanza/tests/depparse/test_parser.py       | 4 ++--
 stanza/tests/lemma/test_lemma_trainer.py   | 2 +-
 stanza/tests/ner/test_ner_training.py      | 4 ++--
 18 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/stanza/models/charlm.py b/stanza/models/charlm.py
index cea51a0e22..f394dcd822 100644
--- a/stanza/models/charlm.py
+++ b/stanza/models/charlm.py
@@ -206,7 +206,7 @@ def get_current_lr(trainer, args):
     return trainer.scheduler.state_dict().get('_last_lr', [args['lr0']])[0]
 
 def load_char_vocab(vocab_file):
-    return {'char': CharVocab.load_state_dict(torch.load(vocab_file, lambda storage, loc: storage))}
+    return {'char': CharVocab.load_state_dict(torch.load(vocab_file, lambda storage, loc: storage, weights_only=True))}
 
 def train(args):
     utils.log_training_args(args, logger)
diff --git a/stanza/models/classifiers/trainer.py b/stanza/models/classifiers/trainer.py
index 6c989a49f9..df73a7bb73 100644
--- a/stanza/models/classifiers/trainer.py
+++ b/stanza/models/classifiers/trainer.py
@@ -69,6 +69,8 @@ def load(filename, args, foundation_cache=None, load_optimizer=False):
         else:
             raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args.save_dir, filename)))
     try:
+        # TODO: switch to weights_only=True
+        # need to convert enums to int first
         checkpoint = torch.load(filename, lambda storage, loc: storage)
     except BaseException:
         logger.exception("Cannot load model from {}".format(filename))
diff --git a/stanza/models/common/char_model.py b/stanza/models/common/char_model.py
index 6caa42dd24..5deee5e0ab 100644
--- a/stanza/models/common/char_model.py
+++ b/stanza/models/common/char_model.py
@@ -268,7 +268,7 @@ def from_full_state(cls, state, finetune=False):
 
     @classmethod
     def load(cls, filename, finetune=False):
-        state = torch.load(filename, lambda storage, loc: storage)
+        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         # allow saving just the Model object,
         # and allow for old charlms to still work
         if 'state_dict' in state:
@@ -342,7 +342,7 @@ def load(cls, args, filename, finetune=False):
         Note that you MUST set finetune=True if planning to continue training
         Otherwise the only benefit you will get will be a warm GPU
         """
-        state = torch.load(filename, lambda storage, loc: storage)
+        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         model = CharacterLanguageModel.from_full_state(state['model'], finetune)
         model = model.to(args['device'])
 
diff --git a/stanza/models/common/pretrain.py b/stanza/models/common/pretrain.py
index 0253203c2c..7c7bafaccf 100644
--- a/stanza/models/common/pretrain.py
+++ b/stanza/models/common/pretrain.py
@@ -53,6 +53,7 @@ def emb(self):
     def load(self):
         if self.filename is not None and os.path.exists(self.filename):
             try:
+                # TODO: update all pretrains to satisfy weights_only=True
                 data = torch.load(self.filename, lambda storage, loc: storage)
                 logger.debug("Loaded pretrain from {}".format(self.filename))
                 if not isinstance(data, dict):
diff --git a/stanza/models/common/trainer.py b/stanza/models/common/trainer.py
index a4ebd1b017..92d0ad83bc 100644
--- a/stanza/models/common/trainer.py
+++ b/stanza/models/common/trainer.py
@@ -13,7 +13,7 @@ def save(self, filename):
         torch.save(savedict, filename)
 
     def load(self, filename):
-        savedict = torch.load(filename, lambda storage, loc: storage)
+        savedict = torch.load(filename, lambda storage, loc: storage, weights_only=True)
 
         self.model.load_state_dict(savedict['model'])
         if self.args['mode'] == 'train':
diff --git a/stanza/models/constituency/base_trainer.py b/stanza/models/constituency/base_trainer.py
index ff0f3a2415..d5e10c2ad8 100644
--- a/stanza/models/constituency/base_trainer.py
+++ b/stanza/models/constituency/base_trainer.py
@@ -84,6 +84,9 @@ def load(filename, args=None, load_optimizer=False, foundation_cache=None, peft_
         else:
             raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args['save_dir'], filename)))
     try:
+        # TODO: currently cannot switch this to weights_only=True
+        # without in some way changing the model to save enums in
+        # a safe manner, probably by converting to int
         checkpoint = torch.load(filename, lambda storage, loc: storage)
     except BaseException:
         logger.exception("Cannot load model from %s", filename)
diff --git a/stanza/models/coref/model.py b/stanza/models/coref/model.py
index dcf835c0c0..73dc700b81 100644
--- a/stanza/models/coref/model.py
+++ b/stanza/models/coref/model.py
@@ -224,6 +224,7 @@ def load_weights(self,
         if map_location is None:
             map_location = self.config.device
         logger.debug(f"Loading from {path}...")
+        # TODO: the config is preventing us from using weights_only=True
         state_dicts = torch.load(path, map_location=map_location)
         self.epochs_trained = state_dicts.pop("epochs_trained", 0)
         # just ignore a config in the model, since we should already have one
diff --git a/stanza/models/depparse/trainer.py b/stanza/models/depparse/trainer.py
index fb4e4ad580..46e70befb9 100644
--- a/stanza/models/depparse/trainer.py
+++ b/stanza/models/depparse/trainer.py
@@ -191,7 +191,7 @@ def load(self, filename, pretrain, args=None, foundation_cache=None, device=None
         and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.
         """
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise
diff --git a/stanza/models/langid/model.py b/stanza/models/langid/model.py
index e33d7546c9..f6219b3879 100644
--- a/stanza/models/langid/model.py
+++ b/stanza/models/langid/model.py
@@ -114,7 +114,7 @@ def load(cls, path, device=None, batch_size=64, lang_subset=None):
             raise FileNotFoundError("Trying to load langid model, but path not specified! Try --load_name")
         if not os.path.exists(path):
             raise FileNotFoundError("Trying to load langid model from path which does not exist: %s" % path)
-        checkpoint = torch.load(path, map_location=torch.device("cpu"))
+        checkpoint = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
         weights = checkpoint["model_state_dict"]["loss_train.weight"]
         model = cls(checkpoint["char_to_idx"], checkpoint["tag_to_idx"], checkpoint["num_layers"],
                     checkpoint["embedding_dim"], checkpoint["hidden_dim"], batch_size=batch_size, weights=weights,
diff --git a/stanza/models/lemma/trainer.py b/stanza/models/lemma/trainer.py
index 4b5e4a0b74..22a0ae1b48 100644
--- a/stanza/models/lemma/trainer.py
+++ b/stanza/models/lemma/trainer.py
@@ -236,7 +236,7 @@ def save(self, filename, skip_modules=True):
 
     def load(self, filename, args, foundation_cache):
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise
diff --git a/stanza/models/mwt/trainer.py b/stanza/models/mwt/trainer.py
index 58f01ec10f..090df806da 100644
--- a/stanza/models/mwt/trainer.py
+++ b/stanza/models/mwt/trainer.py
@@ -198,7 +198,7 @@ def save(self, filename):
 
     def load(self, filename):
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise
diff --git a/stanza/models/ner/data.py b/stanza/models/ner/data.py
index 242df1f19e..37cf7c9e7f 100644
--- a/stanza/models/ner/data.py
+++ b/stanza/models/ner/data.py
@@ -55,7 +55,7 @@ def __init__(self, doc, batch_size, args, pretrain=None, vocab=None, evaluation=
     def init_vocab(self, data):
         def from_model(model_filename):
             """ Try loading vocab from charLM model file. """
-            state_dict = torch.load(model_filename, lambda storage, loc: storage)
+            state_dict = torch.load(model_filename, lambda storage, loc: storage, weights_only=True)
             if 'vocab' in state_dict:
                 return state_dict['vocab']
             if 'model' in state_dict and 'vocab' in state_dict['model']:
diff --git a/stanza/models/ner/trainer.py b/stanza/models/ner/trainer.py
index 108f91227e..8137336b97 100644
--- a/stanza/models/ner/trainer.py
+++ b/stanza/models/ner/trainer.py
@@ -194,7 +194,7 @@ def save(self, filename, skip_modules=True):
 
     def load(self, filename, pretrain=None, args=None, foundation_cache=None):
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise
diff --git a/stanza/models/pos/trainer.py b/stanza/models/pos/trainer.py
index c7c04e7aa5..6ec8fbf161 100644
--- a/stanza/models/pos/trainer.py
+++ b/stanza/models/pos/trainer.py
@@ -136,7 +136,7 @@ def load(self, filename, pretrain, args=None, foundation_cache=None):
         and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.
         """
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise
diff --git a/stanza/models/tokenization/trainer.py b/stanza/models/tokenization/trainer.py
index 254f419f92..b06dd4e4b2 100644
--- a/stanza/models/tokenization/trainer.py
+++ b/stanza/models/tokenization/trainer.py
@@ -79,6 +79,7 @@ def save(self, filename):
 
     def load(self, filename):
         try:
+            # the tokenizers with dictionaries won't properly load weights_only=True because they have a set
             checkpoint = torch.load(filename, lambda storage, loc: storage)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
diff --git a/stanza/tests/depparse/test_parser.py b/stanza/tests/depparse/test_parser.py
index 4d0a6ad1f9..8dcc9cf5b5 100644
--- a/stanza/tests/depparse/test_parser.py
+++ b/stanza/tests/depparse/test_parser.py
@@ -145,7 +145,7 @@ def test_with_bert_finetuning_resaved(self, tmp_path, wordvec_pretrain_file):
         save_name = trainer.args['save_name']
         filename = tmp_path / save_name
         assert os.path.exists(filename)
-        checkpoint = torch.load(filename, lambda storage, loc: storage)
+        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())
 
         # Test loading the saved model, saving it, and still having bert in it
@@ -157,7 +157,7 @@ def test_with_bert_finetuning_resaved(self, tmp_path, wordvec_pretrain_file):
         saved_model.save(filename)
 
         # This is the part that would fail if the force_bert_saved option did not exist
-        checkpoint = torch.load(filename, lambda storage, loc: storage)
+        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())
 
     def test_with_peft(self, tmp_path, wordvec_pretrain_file):
diff --git a/stanza/tests/lemma/test_lemma_trainer.py b/stanza/tests/lemma/test_lemma_trainer.py
index ef0cef0595..cca7f0d21f 100644
--- a/stanza/tests/lemma/test_lemma_trainer.py
+++ b/stanza/tests/lemma/test_lemma_trainer.py
@@ -150,5 +150,5 @@ def test_charlm_train(self, tmp_path, charlm_args):
         # check that the charlm wasn't saved in here
         args = saved_model.args
         save_name = os.path.join(args['save_dir'], args['save_name'])
-        checkpoint = torch.load(save_name, lambda storage, loc: storage)
+        checkpoint = torch.load(save_name, lambda storage, loc: storage, weights_only=True)
         assert not any(x.startswith("contextual_embedding") for x in checkpoint['model'].keys())
diff --git a/stanza/tests/ner/test_ner_training.py b/stanza/tests/ner/test_ner_training.py
index f262c32ea3..e9b4d69ae3 100644
--- a/stanza/tests/ner/test_ner_training.py
+++ b/stanza/tests/ner/test_ner_training.py
@@ -226,7 +226,7 @@ def test_train_model_cpu(pretrain_file, tmp_path):
     assert str(device).startswith("cpu")
 
 def model_file_has_bert(filename):
-    checkpoint = torch.load(filename, lambda storage, loc: storage)
+    checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
     return any(x.startswith("bert_model.") for x in checkpoint['model'].keys())
 
 def test_with_bert(pretrain_file, tmp_path):
@@ -253,7 +253,7 @@ def test_with_peft_finetune(pretrain_file, tmp_path):
     # TODO: check that the peft tensors are moving when training?
    trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--use_peft')
     model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name'])
-    checkpoint = torch.load(model_file, lambda storage, loc: storage)
+    checkpoint = torch.load(model_file, lambda storage, loc: storage, weights_only=True)
     assert 'bert_lora' in checkpoint
     assert not any(x.startswith("bert_model.") for x in checkpoint['model'].keys())
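
A note on the TODO comments left in the classifier, constituency, coref, pretrain, and tokenization loaders: a checkpoint containing an Enum, a config object, or a set makes the restricted weights_only unpickler refuse to load. The sketch below shows two routes that should get those loaders to weights_only=True, under stated assumptions: ModelType is a hypothetical stand-in rather than an actual stanza class, and torch.serialization.add_safe_globals requires PyTorch 2.4 or newer.

    from enum import Enum

    import torch
    import torch.serialization

    class ModelType(Enum):
        # hypothetical stand-in for an enum stored in a checkpoint
        TRANSFORMER = 1
        LSTM = 2

    # Route 1 (what the TODOs suggest): save the enum as its int value so the
    # checkpoint holds only primitive types, then rebuild the enum on load.
    torch.save({'model': {}, 'model_type': ModelType.LSTM.value}, 'ckpt_int.pt')
    checkpoint = torch.load('ckpt_int.pt', map_location='cpu', weights_only=True)
    model_type = ModelType(checkpoint['model_type'])

    # Route 2 (PyTorch >= 2.4): allowlist the enum class so the weights_only
    # unpickler accepts it even when the enum is pickled directly.
    torch.save({'model': {}, 'model_type': ModelType.LSTM}, 'ckpt_enum.pt')
    torch.serialization.add_safe_globals([ModelType])
    checkpoint = torch.load('ckpt_enum.pt', map_location='cpu', weights_only=True)
    assert checkpoint['model_type'] is ModelType.LSTM

Route 1 changes the on-disk format, so older checkpoints would still need the enum rebuilt at load time; route 2 keeps existing save files loadable but ties the fix to a newer torch version.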