
Commit

Unify args somewhat between the training versions - make documentation more accurate, use args['batch_size'] directly instead of passing it in
AngledLuffa committed Jan 15, 2024
1 parent 865d05d commit 91843d7
Showing 2 changed files with 14 additions and 22 deletions.
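
In short, both training entry points now read the batch size from the args mapping inside train(), falling back to DEFAULT_BATCH_SIZE, instead of accepting it as a separate keyword argument. A minimal sketch of the pattern, using hypothetical values rather than the literal stanza code:

    # Sketch only: train() pulls batch_size out of the args mapping itself,
    # so callers no longer pass batch_size as a dedicated kwarg.
    from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE

    def train(num_epochs, save_name, args, eval_file, **kwargs):
        batch_size = args.get("batch_size", DEFAULT_BATCH_SIZE)
        ...

    # Caller side: batch_size travels inside args (assumed dict-like here,
    # since train() type-hints args as a Mapping and calls args.get on it).
    args = {"batch_size": 16}
    train(num_epochs=10, save_name="model.pt", args=args, eval_file="dev.txt")
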
21 changes: 8 additions & 13 deletions stanza/models/lemma_classifier/train_model.py
@@ -102,28 +102,24 @@ def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
self.criterion = nn.BCEWithLogitsLoss(weight=weights)

def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, **kwargs) -> None:

"""
Trains a model on batches of texts, position indices of the target token, and labels (lemma annotation) for the target token.
Args:
texts_batch (List[List[str]]): Batches of tokenized texts, one per sentence. Expected to contain at least one instance of the target token.
positions_batch (List[int]): Batches of position indices (zero-indexed) for the target token, one per input sentence.
labels_batch (List[int]): Batches of labels for the target token, one per input sentence.
num_epochs (int): Number of training epochs
save_name (str): Path to file where trained model should be saved.
eval_file (str): Path to the dev set file for evaluating model checkpoints each epoch.
Kwargs:
train_path (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.
batch_size (int): Number of examples to include in each batch.
"""

device = default_device() # Put model on GPU (if possible)
# Put model on GPU (if possible)
device = default_device()

train_path = kwargs.get("train_path")
upos_to_id = {}
if train_path: # use file to train model
text_batches, idx_batches, upos_batches, label_batches, counts, label_decoder, upos_to_id = utils.load_dataset(train_path, get_counts=self.weighted_loss, batch_size=kwargs.get("batch_size", DEFAULT_BATCH_SIZE))
text_batches, position_batches, upos_batches, label_batches, counts, label_decoder, upos_to_id = utils.load_dataset(train_path, get_counts=self.weighted_loss, batch_size=args.get("batch_size", DEFAULT_BATCH_SIZE))
self.output_dim = len(label_decoder)
logging.info(f"Loaded dataset successfully from {train_path}")
logging.info(f"Using label decoder: {label_decoder} Output dimension: {self.output_dim}")
@@ -136,7 +132,7 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
self.model.to(device)
logging.info(f"Device chosen: {device}. {next(self.model.parameters()).device}")

assert len(text_batches) == len(idx_batches) == len(label_batches), f"Input batch sizes did not match ({len(text_batches)}, {len(idx_batches)}, {len(label_batches)})."
assert len(text_batches) == len(position_batches) == len(label_batches), f"Input batch sizes did not match ({len(text_batches)}, {len(position_batches)}, {len(label_batches)})."
if path.exists(save_name):
raise FileExistsError(f"Save name {save_name} already exists; training would overwrite previous file contents. Aborting...")

@@ -151,7 +147,7 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
logging.info("Embedding norm: %s", torch.linalg.norm(self.model.embedding.weight))
for epoch in range(num_epochs):
# go over entire dataset with each epoch
for texts, positions, upos_tags, labels in tqdm(zip(text_batches, idx_batches, upos_batches, label_batches), total=len(text_batches)):
for texts, positions, upos_tags, labels in tqdm(zip(text_batches, position_batches, upos_batches, label_batches), total=len(text_batches)):

self.optimizer.zero_grad()
output = self.model(positions, texts, upos_tags)
@@ -197,7 +193,7 @@ def build_argparse():
parser.add_argument("--save_name", type=str, default=path.join(path.dirname(__file__), "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file")
parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=16, help="Number of examples to include in each batch")
parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file")
parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.")
parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
@@ -217,7 +213,6 @@ def main(args=None):
num_heads = args.num_heads
save_name = args.save_name
lr = args.lr
batch_size = args.batch_size
num_epochs = args.num_epochs
train_file = args.train_file
weighted_loss = args.weighted_loss
@@ -247,7 +242,7 @@
)

trainer.train(
num_epochs=num_epochs, save_name=save_name, args=args, eval_file=eval_file, train_path=train_file, batch_size=batch_size
num_epochs=num_epochs, save_name=save_name, args=args, eval_file=eval_file, train_path=train_file
)

return trainer
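
For context, build_argparse() above defines the flags and main() returns the trainer, so the updated script can be exercised roughly as below. This is an illustrative sketch only: the paths and values are placeholders, and it assumes main() forwards the list to parser.parse_args(), as its args=None signature suggests.

    # Hypothetical invocation; --batch_size now defaults to DEFAULT_BATCH_SIZE.
    from stanza.models.lemma_classifier import train_model

    trainer = train_model.main([
        "--train_file", "data/processed_ud_en/combined_train.txt",
        "--eval_file", "data/processed_ud_en/combined_dev.txt",
        "--batch_size", "32",
        "--num_epochs", "5",
    ])
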
15 changes: 6 additions & 9 deletions
@@ -1,5 +1,5 @@
"""
This file contains code used to train a baseline transformer model to classify on a lemma of a particular token.
This file contains code used to train a baseline transformer model to classify on a lemma of a particular token.
"""

import torch.nn as nn
@@ -12,7 +12,7 @@
import logging

from stanza.models.lemma_classifier import utils
from stanza.models.lemma_classifier.constants import ModelType
from stanza.models.lemma_classifier.constants import ModelType, DEFAULT_BATCH_SIZE
from stanza.models.lemma_classifier.evaluate_models import evaluate_model
from stanza.models.lemma_classifier.transformer_baseline.model import LemmaClassifierWithTransformer
from stanza.utils.get_tqdm import get_tqdm
@@ -85,27 +85,23 @@ def set_layer_learning_rates(self, transformer_lr: float, mlp_lr: float) -> torc
])
return optimizer

def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, **kwargs):

def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, **kwargs) -> None:
"""
Trains a model on batches of texts, position indices of the target token, and labels (lemma annotation) for the target token.
Args:
texts_batch (List[List[str]]): Batches of tokenized texts, one per sentence. Expected to contain at least one instance of the target token.
positions_batch (List[int]): Batches of position indices (zero-indexed) for the target token, one per input sentence.
labels_batch (List[int]): Batches of labels for the target token, one per input sentence.
num_epochs (int): Number of training epochs
save_name (str): Path to file where trained model should be saved.
eval_file (str): Path to the dev set file for evaluating model checkpoints each epoch.
Kwargs:
train_path (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.
eval_file (str): Path to the dev set file for evaluating model checkpoints each epoch.
"""
# Put model on GPU (if possible)
device = default_device()

if kwargs.get("train_path"):
text_batches, position_batches, upos_batches, label_batches, counts, label_decoder, upos_to_id = utils.load_dataset(kwargs.get("train_path"), get_counts=self.weighted_loss)
text_batches, position_batches, upos_batches, label_batches, counts, label_decoder, upos_to_id = utils.load_dataset(kwargs.get("train_path"), get_counts=self.weighted_loss, batch_size=args.get("batch_size", DEFAULT_BATCH_SIZE))
self.output_dim = len(label_decoder)
logging.info(f"Using label decoder : {label_decoder}")

@@ -169,6 +165,7 @@ def main(args=None):
parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')")
parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 'ce' or 'weighted_bce')")
parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the optimizer.")
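
Both trainers now funnel batching through the same utils.load_dataset call, with batch_size taken from the args mapping. A rough sketch of that shared path, mirroring the tuple unpacked in the diffs above (the file path and values are placeholders):

    from stanza.models.lemma_classifier import utils
    from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE

    args = {"batch_size": 32}  # assumed dict-like args, as train(..., args: Mapping) expects
    (text_batches, position_batches, upos_batches, label_batches,
     counts, label_decoder, upos_to_id) = utils.load_dataset(
         "data/processed_ud_en/combined_train.txt",  # placeholder path
         get_counts=True,                            # counts are needed for weighted loss
         batch_size=args.get("batch_size", DEFAULT_BATCH_SIZE),
     )
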

