
Commit

Remove a couple arguments which can be inferred from the embedding file. Otherwise turn various kwargs into named arguments
AngledLuffa committed Dec 28, 2023
1 parent 91cb24f commit b5e9f8d
Showing 1 changed file with 5 additions and 17 deletions.
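
Note: the two removed constructor arguments can be read straight off the embedding file, which is what the code already does (self.embedding_dim = emb_matrix.shape[1] in the second hunk below). A minimal sketch of the idea in isolation, where emb_matrix is a hypothetical placeholder for the matrix loaded from the Stanza pretrain .pt file:

    import torch

    # Placeholder for the matrix that __init__ loads from the pretrain file;
    # the real loading code sits in context that is collapsed out of the diff.
    emb_matrix = torch.randn(10000, 100)

    vocab_size = emb_matrix.shape[0]      # one row per word in the vocab
    embedding_dim = emb_matrix.shape[1]   # width of each word vector

    print(vocab_size, embedding_dim)      # 10000 100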
22 changes: 5 additions & 17 deletions stanza/models/lemma_classifier/train_model.py
@@ -29,14 +29,12 @@ class LemmaClassifierTrainer():
     Class to assist with training a LemmaClassifierLSTM
     """
 
-    def __init__(self, vocab_size: int, embedding_file: str, embedding_dim: int, hidden_dim: int, output_dim: int = 2, use_charlm: bool = False, eval_file: str = None, **kwargs):
+    def __init__(self, embedding_file: str, hidden_dim: int, output_dim: int = 2, use_charlm: bool = False, forward_charlm_file: str = None, backward_charlm_file: str = None, lr: float = 0.001, loss_func: str = None, eval_file: str = None):
         """
         Initializes the LemmaClassifierTrainer class.
         Args:
-            vocab_size (int): Size of the vocab being used (if custom vocab)
             embedding_file (str): What word embeddings file to use. Use a Stanza pretrain .pt
-            embedding_dim (int): Size of embedding dimension to use on the aforementioned word embeddings
             hidden_dim (int): Size of hidden vectors in LSTM layers
             output_dim (int, optional): Size of output vector from MLP layer. Defaults to 2.
             use_charlm (bool, optional): Whether to use charlm embeddings as well. Defaults to False.
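
For reference, constructing the trainer against the new signature might look like the sketch below. This is illustrative rather than code from the repository: the file paths are hypothetical, and only parameters visible in this diff are used.

    from stanza.models.lemma_classifier.train_model import LemmaClassifierTrainer

    # Hypothetical paths -- substitute a real pretrain file and charlm files.
    trainer = LemmaClassifierTrainer(
        embedding_file="saved_models/pretrain/glove.pt",
        hidden_dim=256,
        output_dim=2,
        use_charlm=True,
        forward_charlm_file="saved_models/charlm/en_forward.pt",
        backward_charlm_file="saved_models/charlm/en_backward.pt",
        lr=0.001,
        loss_func="ce",
    )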
@@ -66,31 +64,27 @@ def __init__(self, vocab_size: int, embedding_file: str, embedding_dim: int, hid
         self.embedding_dim = emb_matrix.shape[1]
 
         # Load CharLM embeddings
-        forward_charlm_file = kwargs.get("forward_charlm_file")
-        backward_charlm_file = kwargs.get("backward_charlm_file")
         if use_charlm and forward_charlm_file is not None and not os.path.exists(forward_charlm_file):
             raise FileNotFoundError(f"Could not find forward charlm file: {forward_charlm_file}")
         if use_charlm and backward_charlm_file is not None and not os.path.exists(backward_charlm_file):
             raise FileNotFoundError(f"Could not find backward charlm file: {backward_charlm_file}")
 
-        # TODO: embedding_dim and vocab_size are read off the embeddings file
         self.model = LemmaClassifierLSTM(self.vocab_size, self.embedding_dim, hidden_dim, output_dim, self.vocab_map, self.embeddings, charlm=use_charlm,
                                          charlm_forward_file=forward_charlm_file, charlm_backward_file=backward_charlm_file)
 
         # Find loss function
-        loss_fn = kwargs.get("loss_func", "ce").lower()
-        if loss_fn == "ce":
+        if loss_func == "ce":
             self.criterion = nn.CrossEntropyLoss()
             self.weighted_loss = False
             logging.debug("Using CE loss")
-        elif loss_fn == "weighted_bce":
+        elif loss_func == "weighted_bce":
             self.criterion = nn.BCEWithLogitsLoss()
             self.weighted_loss = True # used to add weights during train time.
             logging.debug("Using Weighted BCE loss")
         else:
             raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")
 
-        self.optimizer = optim.Adam(self.model.parameters(), lr=kwargs.get("lr", 0.001))
+        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
 
     def save_checkpoint(self, save_name: str, state_dict: Mapping, label_decoder: Mapping, args: Mapping) -> Mapping:
         """
@@ -203,8 +197,6 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
 
 def build_argparse():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--vocab_size", type=int, default=10000, help="Number of tokens in vocab")
-    parser.add_argument("--embedding_dim", type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)")
     parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
     parser.add_argument("--output_dim", type=int, default=2, help="Size of output layer (number of classes)")
     parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), "pretrain", "glove.pt"), help='Exact name of the pretrain file to read')
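
With the two flags removed, the embedding-related information comes entirely from --wordvec_pretrain_file. As a hedged sketch, the script could be driven programmatically through main(), which forwards its argument list to parse_args (see the next hunk); only flags visible in this diff are shown, and a real run presumably needs additional arguments (training data, save path) that live in the collapsed context:

    from stanza.models.lemma_classifier import train_model

    # Illustrative only: extra required flags are omitted here because they
    # do not appear in this diff.
    train_model.main([
        "--wordvec_pretrain_file", "saved_models/pretrain/glove.pt",
        "--hidden_dim", "256",
        "--output_dim", "2",
    ])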
@@ -224,8 +216,6 @@ def main(args=None):
     parser = build_argparse()
     args = parser.parse_args(args)
 
-    vocab_size = args.vocab_size
-    embedding_dim = args.embedding_dim
     hidden_dim = args.hidden_dim
     output_dim = args.output_dim
     wordvec_pretrain_file = args.wordvec_pretrain_file
@@ -251,9 +241,7 @@
         logging.info(f"{arg}: {args[arg]}")
     logging.info("------------------------------------------------------------")
 
-    trainer = LemmaClassifierTrainer(vocab_size=vocab_size,
-                                     embedding_file=wordvec_pretrain_file,
-                                     embedding_dim=embedding_dim,
+    trainer = LemmaClassifierTrainer(embedding_file=wordvec_pretrain_file,
                                      hidden_dim=hidden_dim,
                                      output_dim=output_dim,
                                      use_charlm=use_charlm,
