Now the forward pass looks exactly the same between the two versions of the model
AngledLuffa committed Jan 15, 2024
1 parent cc20b74 commit 5139b61
Showing 4 changed files with 14 additions and 14 deletions.
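The point of the change: both model variants now expose the same three-argument forward, with the transformer baseline simply accepting and ignoring the upos tags. A minimal sketch of the pattern (class bodies elided; LemmaClassifierLSTM is a real name from this commit, the baseline class name here is illustrative):

import torch.nn as nn

class LemmaClassifierLSTM(nn.Module):
    def forward(self, idx_positions, sentences, upos_tags):
        ...  # the LSTM variant actually consumes the upos tags

class TransformerBaseline(nn.Module):
    def forward(self, idx_positions, sentences, upos_tags):
        ...  # upos_tags accepted but unused, kept only for signature parity

With the signatures aligned, callers such as model_predict below no longer need to branch on the model type.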
6 changes: 1 addition & 5 deletions stanza/models/lemma_classifier/evaluate_models.py
@@ -21,7 +21,6 @@
 import stanza

 from stanza.models.common.utils import default_device
-from stanza.models.lemma_classifier.constants import ModelType
 from stanza.models.lemma_classifier import utils
 from stanza.models.lemma_classifier.base_model import LemmaClassifier
 from stanza.models.lemma_classifier.model import LemmaClassifierLSTM
@@ -123,10 +122,7 @@ def model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: L
         (int): The index of the predicted class in `model`'s output.
     """
     with torch.no_grad():
-        if model.model_type() == ModelType.LSTM:
-            logits = model(position_indices, sentences, upos_tags)
-        else:
-            logits = model(position_indices, sentences) # should be size (batch_size, output_size)
+        logits = model(position_indices, sentences, upos_tags) # should be size (batch_size, output_size)
         predicted_class = torch.argmax(logits, dim=1) # should be size (batch_size, 1)

     return predicted_class
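One note on the shape comments above: torch.argmax(logits, dim=1) on a (batch_size, output_size) tensor returns a 1-D tensor of shape (batch_size,) rather than (batch_size, 1), since the reduced dimension is dropped. A quick runnable check:

import torch

logits = torch.randn(4, 2)               # (batch_size=4, output_size=2)
predicted = torch.argmax(logits, dim=1)  # dim 1 is reduced away
print(predicted.shape)                   # torch.Size([4]), i.e. (batch_size,)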
5 changes: 3 additions & 2 deletions stanza/models/lemma_classifier/train_model.py
@@ -144,10 +144,11 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
         best_model, best_f1 = None, float("-inf") # Used for saving checkpoints of the model
         for epoch in range(num_epochs):
             # go over entire dataset with each epoch
-            for texts, positions, upos_tags, labels in tqdm(zip(text_batches, position_batches, upos_batches, label_batches), total=len(text_batches)):
+            for sentences, positions, upos_tags, labels in tqdm(zip(text_batches, position_batches, upos_batches, label_batches), total=len(text_batches)):
+                assert len(sentences) == len(positions) == len(labels), f"Input sentences, positions, and labels are of unequal length ({len(sentences), len(positions), len(labels)})"

                 self.optimizer.zero_grad()
-                outputs = self.model(positions, texts, upos_tags)
+                outputs = self.model(positions, sentences, upos_tags)

                 # Compute loss, which is different if using CE or BCEWithLogitsLoss
                 if self.weighted_loss: # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others.
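The weighted-loss comment in this hunk is worth unpacking: unlike CrossEntropyLoss, which takes integer class indices, BCEWithLogitsLoss expects a float target of the same shape as the logits, with 1.0 at the true class and 0.0 elsewhere. A minimal sketch of that target construction (illustrative values, not the repository's exact code):

import torch
import torch.nn as nn
import torch.nn.functional as F

outputs = torch.randn(4, 2)          # (batch_size, output_size) raw logits
labels = torch.tensor([0, 1, 1, 0])  # gold class indices

# 1.0 on the true label class, 0.0 on the others, as the comment says
targets = F.one_hot(labels, num_classes=2).float()
loss = nn.BCEWithLogitsLoss()(outputs, targets)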
6 changes: 3 additions & 3 deletions stanza/models/lemma_classifier/transformer_baseline/baseline_trainer.py
@@ -125,14 +125,14 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
         logging.debug(f"Criterion on {next(self.model.parameters()).device}")
         self.criterion = self.criterion.to(next(self.model.parameters()).device)

-        best_model, best_f1 = None, float("-inf")
+        best_model, best_f1 = None, float("-inf") # Used for saving checkpoints of the model
         for epoch in range(num_epochs):
             # go over entire dataset with each epoch
-            for sentences, positions, labels in tqdm(zip(text_batches, position_batches, label_batches), total=len(text_batches)):
+            for sentences, positions, upos_tags, labels in tqdm(zip(text_batches, position_batches, upos_batches, label_batches), total=len(text_batches)):
                 assert len(sentences) == len(positions) == len(labels), f"Input sentences, positions, and labels are of unequal length ({len(sentences), len(positions), len(labels)})"

                 self.optimizer.zero_grad()
-                outputs = self.model(positions, sentences)
+                outputs = self.model(positions, sentences, upos_tags)

                 # Compute loss, which is different if using CE or BCEWithLogitsLoss
                 if self.weighted_loss: # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others.
11 changes: 7 additions & 4 deletions stanza/models/lemma_classifier/transformer_baseline/model.py
@@ -55,14 +55,17 @@ def get_save_dict(self, args):
         }
         return save_dict

-    def forward(self, idx_positions: List[int], sentences: List[List[str]]):
+    def forward(self, idx_positions: List[int], sentences: List[List[str]], upos_tags: List[List[int]]):
         """
         Computes the forward pass of the transformer baselines
         Args:
-            text (List[str]): A single sentence with each token as an entry in the list.
-            pos_index (int): The index of the token to classify on.
+            idx_positions (List[int]): A list of the position index of the target token for lemmatization classification in each sentence.
+            sentences (List[List[str]]): A list of the token-split sentences of the input data.
+            upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence - not used in this model, here for compatibility

-        Returns the logits of the MLP
+        Returns:
+            torch.tensor: Output logits of the neural network, where the shape is (n, output_size) where n is the number of sentences.
         """
         device = next(self.transformer.parameters()).device
         bert_embeddings = extract_bert_embeddings(self.transformer_name, self.tokenizer, self.transformer, sentences, device,
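To make the docstring's contract concrete, here is a toy stand-in with the same three-argument forward, returning (n, output_size) logits while ignoring upos_tags just as the baseline does (names and shapes are illustrative, not the real model):

import torch
import torch.nn as nn

class DummyBaseline(nn.Module):
    def __init__(self, output_size=2):
        super().__init__()
        self.output_size = output_size

    def forward(self, idx_positions, sentences, upos_tags):
        # upos_tags is accepted but ignored, mirroring the baseline
        return torch.randn(len(sentences), self.output_size)

model = DummyBaseline()
sentences = [["he", "saw", "the", "cat"], ["dogs", "run"]]
idx_positions = [1, 1]              # index of the target token per sentence
upos_tags = [[0, 1, 2, 3], [4, 5]]  # dummy tag ids, unused here

logits = model(idx_positions, sentences, upos_tags)
print(logits.shape)                 # torch.Size([2, 2]), i.e. (n, output_size)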
