Commit

For the simple use cases where no Enum or unexpected config objects are in the save files, use weights_only=True. Significantly cuts down on the number of torch warnings. #1429

AngledLuffa committed Oct 24, 2024
1 parent b06eef1 commit 80be642
Showing 19 changed files with 25 additions and 18 deletions.
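
Every hunk below makes the same one-line change to torch.load. As a minimal before/after sketch (the checkpoint path here is made up, not one of the files in this commit):

import torch

filename = "saved_models/example_checkpoint.pt"  # hypothetical path, for illustration only

# Old call: unrestricted unpickling. Recent PyTorch releases also warn that the
# weights_only default will change, which is the warning noise this commit cuts down.
checkpoint = torch.load(filename, lambda storage, loc: storage)

# New call used throughout this commit: the restricted unpickler only accepts
# tensors, primitive types, and standard containers.
checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)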
2 changes: 1 addition & 1 deletion stanza/models/charlm.py
@@ -206,7 +206,7 @@ def get_current_lr(trainer, args):
     return trainer.scheduler.state_dict().get('_last_lr', [args['lr0']])[0]

 def load_char_vocab(vocab_file):
-    return {'char': CharVocab.load_state_dict(torch.load(vocab_file, lambda storage, loc: storage))}
+    return {'char': CharVocab.load_state_dict(torch.load(vocab_file, lambda storage, loc: storage, weights_only=True))}

 def train(args):
     utils.log_training_args(args, logger)
2 changes: 2 additions & 0 deletions stanza/models/classifiers/trainer.py
@@ -69,6 +69,8 @@ def load(filename, args, foundation_cache=None, load_optimizer=False):
     else:
         raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args.save_dir, filename)))
     try:
+        # TODO: switch to weights_only=True
+        # need to convert enums to int first
         checkpoint = torch.load(filename, lambda storage, loc: storage)
     except BaseException:
         logger.exception("Cannot load model from {}".format(filename))
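
The TODO above refers to checkpoints that pickle Enum members directly, which the restricted unpickler refuses. A sketch of the enum-to-int conversion the comment mentions, using a made-up ModelType enum rather than the classifier's real save format:

from enum import Enum

import torch

class ModelType(Enum):  # hypothetical stand-in, not the classifier's actual enum
    CNN = 1
    BERT = 2

def save_checkpoint(model, model_type, filename):
    # Store the underlying int; pickling the Enum member itself would force
    # weights_only=False at load time.
    torch.save({'model': model.state_dict(), 'model_type': model_type.value}, filename)

def load_checkpoint(filename):
    checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
    model_type = ModelType(checkpoint['model_type'])  # rebuild the Enum from the int
    return checkpoint, model_type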
4 changes: 2 additions & 2 deletions stanza/models/common/char_model.py
@@ -268,7 +268,7 @@ def from_full_state(cls, state, finetune=False):

     @classmethod
     def load(cls, filename, finetune=False):
-        state = torch.load(filename, lambda storage, loc: storage)
+        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         # allow saving just the Model object,
         # and allow for old charlms to still work
         if 'state_dict' in state:
@@ -342,7 +342,7 @@ def load(cls, args, filename, finetune=False):
         Note that you MUST set finetune=True if planning to continue training
         Otherwise the only benefit you will get will be a warm GPU
         """
-        state = torch.load(filename, lambda storage, loc: storage)
+        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         model = CharacterLanguageModel.from_full_state(state['model'], finetune)
         model = model.to(args['device'])

1 change: 1 addition & 0 deletions stanza/models/common/pretrain.py
@@ -53,6 +53,7 @@ def emb(self):
     def load(self):
         if self.filename is not None and os.path.exists(self.filename):
             try:
+                # TODO: update all pretrains to satisfy weights_only=True
                 data = torch.load(self.filename, lambda storage, loc: storage)
                 logger.debug("Loaded pretrain from {}".format(self.filename))
                 if not isinstance(data, dict):
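
The pretrain files evidently still contain objects the restricted unpickler rejects. One way to "update" such a file, sketched here with assumed key names ('emb', 'vocab') since the diff does not show the real pretrain layout, is to re-save it with only tensors and built-in containers:

import torch

def resave_pretrain(old_filename, new_filename):
    # The legacy file needs a full-pickle load (trusted input assumed).
    data = torch.load(old_filename, lambda storage, loc: storage)
    # Keep only plain types: an embedding tensor and a list of strings.
    torch.save({'emb': data['emb'], 'vocab': list(data['vocab'])}, new_filename)
    # The re-saved file should now load under the restricted unpickler.
    torch.load(new_filename, lambda storage, loc: storage, weights_only=True)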
2 changes: 1 addition & 1 deletion stanza/models/common/trainer.py
@@ -13,7 +13,7 @@ def save(self, filename):
         torch.save(savedict, filename)

     def load(self, filename):
-        savedict = torch.load(filename, lambda storage, loc: storage)
+        savedict = torch.load(filename, lambda storage, loc: storage, weights_only=True)

         self.model.load_state_dict(savedict['model'])
         if self.args['mode'] == 'train':
3 changes: 3 additions & 0 deletions stanza/models/constituency/base_trainer.py
@@ -84,6 +84,9 @@ def load(filename, args=None, load_optimizer=False, foundation_cache=None, peft_
     else:
         raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args['save_dir'], filename)))
     try:
+        # TODO: currently cannot switch this to weights_only=True
+        # without in some way changing the model to save enums in
+        # a safe manner, probably by converting to int
         checkpoint = torch.load(filename, lambda storage, loc: storage)
     except BaseException:
         logger.exception("Cannot load model from %s", filename)
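
Besides converting enums to int as this TODO suggests, newer PyTorch (2.4+, as far as the torch.serialization docs describe it) can allowlist specific classes for the weights_only unpickler. A sketch, with TransitionScheme standing in for whatever enums the constituency checkpoints actually store:

from enum import Enum

import torch

class TransitionScheme(Enum):  # hypothetical stand-in for the saved enums
    TOP_DOWN = 1
    IN_ORDER = 2

# Allowlist the class so the restricted unpickler may reconstruct it.
torch.serialization.add_safe_globals([TransitionScheme])
checkpoint = torch.load("saved_models/constituency/example.pt",  # hypothetical path
                        lambda storage, loc: storage, weights_only=True)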
1 change: 1 addition & 0 deletions stanza/models/coref/model.py
@@ -224,6 +224,7 @@ def load_weights(self,
         if map_location is None:
             map_location = self.config.device
         logger.debug(f"Loading from {path}...")
+        # TODO: the config is preventing us from using weights_only=True
         state_dicts = torch.load(path, map_location=map_location)
         self.epochs_trained = state_dicts.pop("epochs_trained", 0)
         # just ignore a config in the model, since we should already have one
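
For the coref TODO, the blocker is a config object pickled into the checkpoint even though the loader ignores it. One possible direction (a sketch only, assuming the config is a dataclass, which this diff does not confirm) is to store it as a plain dict at save time:

import dataclasses

import torch

def save_weights(model, config, path):
    # Storing the config as a plain dict keeps the file loadable with weights_only=True.
    torch.save({'model': model.state_dict(),
                'config': dataclasses.asdict(config),
                'epochs_trained': 0},
               path)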
2 changes: 1 addition & 1 deletion stanza/models/depparse/trainer.py
@@ -191,7 +191,7 @@ def load(self, filename, pretrain, args=None, foundation_cache=None, device=None
         and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.
         """
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise
2 changes: 1 addition & 1 deletion stanza/models/langid/model.py
@@ -114,7 +114,7 @@ def load(cls, path, device=None, batch_size=64, lang_subset=None):
             raise FileNotFoundError("Trying to load langid model, but path not specified! Try --load_name")
         if not os.path.exists(path):
             raise FileNotFoundError("Trying to load langid model from path which does not exist: %s" % path)
-        checkpoint = torch.load(path, map_location=torch.device("cpu"))
+        checkpoint = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
         weights = checkpoint["model_state_dict"]["loss_train.weight"]
         model = cls(checkpoint["char_to_idx"], checkpoint["tag_to_idx"], checkpoint["num_layers"],
                     checkpoint["embedding_dim"], checkpoint["hidden_dim"], batch_size=batch_size, weights=weights,
2 changes: 1 addition & 1 deletion stanza/models/lemma/trainer.py
@@ -236,7 +236,7 @@ def save(self, filename, skip_modules=True):

     def load(self, filename, args, foundation_cache):
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise
2 changes: 1 addition & 1 deletion stanza/models/mwt/trainer.py
@@ -198,7 +198,7 @@ def save(self, filename):

     def load(self, filename):
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise
2 changes: 1 addition & 1 deletion stanza/models/ner/data.py
@@ -55,7 +55,7 @@ def __init__(self, doc, batch_size, args, pretrain=None, vocab=None, evaluation=
     def init_vocab(self, data):
         def from_model(model_filename):
             """ Try loading vocab from charLM model file. """
-            state_dict = torch.load(model_filename, lambda storage, loc: storage)
+            state_dict = torch.load(model_filename, lambda storage, loc: storage, weights_only=True)
             if 'vocab' in state_dict:
                 return state_dict['vocab']
             if 'model' in state_dict and 'vocab' in state_dict['model']:
2 changes: 1 addition & 1 deletion stanza/models/ner/trainer.py
@@ -194,7 +194,7 @@ def save(self, filename, skip_modules=True):

     def load(self, filename, pretrain=None, args=None, foundation_cache=None):
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise
2 changes: 1 addition & 1 deletion stanza/models/pos/trainer.py
@@ -136,7 +136,7 @@ def load(self, filename, pretrain, args=None, foundation_cache=None):
         and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.
         """
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise
2 changes: 1 addition & 1 deletion stanza/models/tokenization/trainer.py
@@ -79,7 +79,7 @@ def save(self, filename):

     def load(self, filename):
         try:
-            checkpoint = torch.load(filename, lambda storage, loc: storage)
+            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
             logger.error("Cannot load model from {}".format(filename))
             raise
2 changes: 1 addition & 1 deletion stanza/tests/common/test_pretrain.py
@@ -71,7 +71,7 @@ def test_resave_pretrain():
                        vec_filename=f'unban_mox_opal')
         check_pretrain(pt2)

-        pt3 = torch.load(test_pt_file.name)
+        pt3 = torch.load(test_pt_file.name, weights_only=True)
         check_embedding(pt3['emb'])
     finally:
         os.unlink(test_pt_file.name)
4 changes: 2 additions & 2 deletions stanza/tests/depparse/test_parser.py
@@ -145,7 +145,7 @@ def test_with_bert_finetuning_resaved(self, tmp_path, wordvec_pretrain_file):
         save_name = trainer.args['save_name']
         filename = tmp_path / save_name
         assert os.path.exists(filename)
-        checkpoint = torch.load(filename, lambda storage, loc: storage)
+        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())

         # Test loading the saved model, saving it, and still having bert in it
@@ -157,7 +157,7 @@ def test_with_bert_finetuning_resaved(self, tmp_path, wordvec_pretrain_file):
         saved_model.save(filename)

         # This is the part that would fail if the force_bert_saved option did not exist
-        checkpoint = torch.load(filename, lambda storage, loc: storage)
+        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())

     def test_with_peft(self, tmp_path, wordvec_pretrain_file):
2 changes: 1 addition & 1 deletion stanza/tests/lemma/test_lemma_trainer.py
@@ -150,5 +150,5 @@ def test_charlm_train(self, tmp_path, charlm_args):
         # check that the charlm wasn't saved in here
         args = saved_model.args
         save_name = os.path.join(args['save_dir'], args['save_name'])
-        checkpoint = torch.load(save_name, lambda storage, loc: storage)
+        checkpoint = torch.load(save_name, lambda storage, loc: storage, weights_only=True)
         assert not any(x.startswith("contextual_embedding") for x in checkpoint['model'].keys())
4 changes: 2 additions & 2 deletions stanza/tests/ner/test_ner_training.py
@@ -226,7 +226,7 @@ def test_train_model_cpu(pretrain_file, tmp_path):
     assert str(device).startswith("cpu")

 def model_file_has_bert(filename):
-    checkpoint = torch.load(filename, lambda storage, loc: storage)
+    checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
     return any(x.startswith("bert_model.") for x in checkpoint['model'].keys())

 def test_with_bert(pretrain_file, tmp_path):
@@ -253,7 +253,7 @@ def test_with_peft_finetune(pretrain_file, tmp_path):
     # TODO: check that the peft tensors are moving when training?
     trainer = run_training(pretrain_file, tmp_path, '--bert_model', 'hf-internal-testing/tiny-bert', '--use_peft')
     model_file = os.path.join(trainer.args['save_dir'], trainer.args['save_name'])
-    checkpoint = torch.load(model_file, lambda storage, loc: storage)
+    checkpoint = torch.load(model_file, lambda storage, loc: storage, weights_only=True)
     assert 'bert_lora' in checkpoint
     assert not any(x.startswith("bert_model.") for x in checkpoint['model'].keys())
