
Commit

gradient checkpointing
Jemoka committed Jan 16, 2024
1 parent 145eedf commit 6be26ba
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions stanza/models/ner/trainer.py
@@ -72,6 +72,9 @@ def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device
self.vocab = vocab
self.model = NERTagger(args, vocab, emb_matrix=pretrain.emb, foundation_cache=foundation_cache)

if self.args.get("gradient_checkpointing", False) and self.args.get("bert_finetune", False):
self.bert_model.gradient_checkpointing_enable()

# if this wasn't set anywhere, we use a default of the 0th tagset
# we don't set this as a default in the options so that
# we can distinguish "intentionally set to 0" and "not set at all"
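For context, the trainer change turns on gradient checkpointing for the underlying transformer only when that transformer is actually being finetuned. A minimal sketch of the same pattern with a plain HuggingFace model (the args dict and bert_model variable below are illustrative stand-ins, not Stanza's exact attributes):

from transformers import AutoModel

# Illustrative configuration mirroring the trainer's args dict.
args = {"bert_model": "bert-base-cased", "bert_finetune": True, "gradient_checkpointing": True}
bert_model = AutoModel.from_pretrained(args["bert_model"])

# Checkpointing only matters when the transformer's weights receive gradients,
# so it is gated on bert_finetune, mirroring the trainer change above.
if args.get("gradient_checkpointing", False) and args.get("bert_finetune", False):
    # Recompute intermediate activations during the backward pass instead of
    # storing them, trading extra forward compute for lower peak memory.
    bert_model.gradient_checkpointing_enable()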
1 change: 1 addition & 0 deletions stanza/models/ner_tagger.py
@@ -80,6 +80,7 @@ def build_argparse():
parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
parser.add_argument('--bert_hidden_layers', type=int, default=None, help="How many layers of hidden state to use from the transformer")
parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
parser.add_argument('--gradient_checkpointing', default=False, action='store_true', help='Checkpoint intermediate gradients between layers to save memory at the cost of training steps')
parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')

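The new flag is opt-in (store_true, default False) and only takes effect together with --bert_finetune, since the trainer checks both before enabling checkpointing. A hypothetical invocation, assuming the tagger is launched directly as a module, with the remaining arguments elided:

python -m stanza.models.ner_tagger --bert_finetune --gradient_checkpointing ...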
