Edit to make batch processing work and get rid of debug statements
SecroLoL committed Jan 5, 2024
1 parent a544c17 commit 492c650
Showing 4 changed files with 21 additions and 26 deletions.
20 changes: 9 additions & 11 deletions stanza/models/lemma_classifier/evaluate_models.py
@@ -107,7 +107,7 @@ def evaluate_sequences(gold_tag_sequences: List[List[Any]], pred_tag_sequences:
return multi_class_result, confusion, weighted_f1


def model_predict(model: nn.Module, position_indices: torch.tensor[int], sentences: List[List[str]]) -> torch.tensor[int]:
def model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: List[List[str]]) -> torch.Tensor:
"""
A LemmaClassifierLSTM or LemmaClassifierWithTransformer is used to predict on a single text example, given the position index of the target token.
@@ -121,9 +121,7 @@ def model_predict(model: nn.Module, position_indices: torch.tensor[int], sentenc
"""
with torch.no_grad():
logits = model(position_indices, sentences) # should be size (batch_size, output_size)
logging.info(f"Logits shape: {logits.shape} (should be size (batch_size, output_size))")
predicted_class = torch.argmax(logits, dim=1) # should be size (batch_size, 1)
logging.info(f"Predicted class shape: {predicted_class.shape}, (should be size (batch_size, 1))")

return predicted_class
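
One detail the deleted debug line had backwards: torch.argmax with dim=1 reduces that dimension away, so the predictions come back with shape (batch_size,), not (batch_size, 1). A minimal check, independent of the model:

import torch

logits = torch.randn(4, 2)          # (batch_size, output_size)
pred = torch.argmax(logits, dim=1)  # (batch_size,) -- dim 1 is reduced away
assert pred.shape == (4,)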

@@ -155,9 +153,13 @@ def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_tr

# load in eval data
text_batches, index_batches, label_batches, _, label_decoder = utils.load_dataset(eval_path, label_decoder=model.label_decoder)


# TODO fix this in the future
text_batches, index_batches, label_batches = text_batches[: -1], index_batches[: -1], label_batches[: -1]

index_batches = torch.tensor(index_batches, device=device)
label_batches = torch.tensor(label_batches, device=device)
index_batches = torch.stack(index_batches).to(device)
label_batches = torch.stack(label_batches).to(device)

logging.info(f"Evaluating on evaluation file {eval_path}")

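
Why torch.stack instead of torch.tensor: load_dataset returns a list of batch tensors, and torch.tensor cannot be built from tensor elements (it typically raises "ValueError: only one element tensors can be converted to Python scalars"), while torch.stack joins them along a new leading dimension. Stacking also requires every batch to have the same size, which is what the slice dropping the ragged final batch works around. A minimal sketch with hypothetical index batches:

import torch

batches = [torch.tensor([0, 2, 1]), torch.tensor([3, 1, 0])]  # equal-sized batches

# torch.tensor(batches) would fail here; torch.stack adds a new leading dim.
stacked = torch.stack(batches)  # shape (num_batches, batch_size)
stacked = stacked.to("cuda" if torch.cuda.is_available() else "cpu")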
@@ -168,21 +170,17 @@
for sentences, pos_indices, labels in tqdm(zip(text_batches, index_batches, label_batches), "Evaluating examples from data file", total=len(text_batches)):
pred = model_predict(model, pos_indices, sentences) # Pred should be size (batch_size, )
correct_preds = pred == labels
logging.info(f"Correct preds shape: {correct_preds.shape} (should be size (batch_size, 1))")
correct += torch.sum(correct_preds)
total += len(correct_preds)
pred_tags += pred.tolist()
pred_tags += [pred.tolist()]

logging.info("Finished evaluating on dataset. Computing scores...")
accuracy = correct / total

logging.info(f"Gold Tags: {gold_tags}")
logging.info(f"Pred Tags: {pred_tags}")

mc_results, confusion, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, verbose=verbose)
# add brackets around batches of gold and pred tags because each batch is an element within the sequences in this helper
if verbose:
logging.info(f"Accuracy: {accuracy} ({correct}/{len(label_batches)})")
logging.info(f"Accuracy: {accuracy} ({correct}/{total})")

return mc_results, confusion, accuracy, weighted_f1
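
The added brackets around pred.tolist() matter because evaluate_sequences expects one inner list per batch: += some_list extends element by element, while += [some_list] appends the batch as a single element. For example:

import torch

pred = torch.tensor([1, 0, 1])
tags = []
tags += pred.tolist()    # extends: [1, 0, 1] -- batch structure lost
tags = []
tags += [pred.tolist()]  # appends: [[1, 0, 1]] -- one sub-list per batch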

16 changes: 7 additions & 9 deletions stanza/models/lemma_classifier/model.py
@@ -95,34 +95,32 @@ def forward(self, pos_indices: List[int], sentences: List[List[str]]):
sentence_token_ids = [self.vocab_map.get(word.lower(), UNK_ID) for word in words]
sentence_token_ids = torch.tensor(sentence_token_ids, device=next(self.parameters()).device)
token_ids.append(sentence_token_ids)

embedded = self.embedding(torch.tensor(token_ids))

token_ids = pad_sequence(token_ids, batch_first=True)
embedded = self.embedding(token_ids)

if self.use_charlm:
char_reps_forward = self.charmodel_forward.build_char_representation(sentences) # takes [[str]]
char_reps_backward = self.charmodel_backward.build_char_representation(sentences)

embedded = torch.cat((embedded, char_reps_forward, char_reps_backward), 1)
char_reps_forward = pad_sequence(char_reps_forward, batch_first=True)
char_reps_backward = pad_sequence(char_reps_backward, batch_first=True)

embedded = torch.cat((embedded, char_reps_forward, char_reps_backward), 2)

print(f"Embedding shape: {embedded.shape}. Should be size (batch_size, T, input_size)") # Should be size (batch_size, T, input_size)
padded_sequences = pad_sequence(embedded, batch_first=True)
lengths = torch.tensor([len(seq) for seq in embedded])

packed_sequences = pack_padded_sequence(padded_sequences, lengths, batch_first=True)

print(f"Packed Sequences shape: {packed_sequences.shape}. Should be size (batch_size, input_size)") # should be size (batch_size, input_size)

lstm_out, (hidden, _) = self.lstm(packed_sequences)

# Extract the hidden state at the index of the token to classify
unpacked_lstm_outputs, _ = pad_packed_sequence(lstm_out, batch_first=True)
lstm_out = unpacked_lstm_outputs[torch.arange(unpacked_lstm_outputs.size(0)), pos_indices]

print(f"LSTM OUT Shape: {lstm_out.shape}. Should be size (batch_size, input_size)") # Should be size (batch_size, input_size)

# MLP forward pass
output = self.mlp(lstm_out)
print(f"Output shape: {output.shape}. Should be size (batch_size, output_size)") # should be size (batch_size, output_size)
return output

def model_type(self):
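The reworked forward pass pads token ids and charlm representations to a common length before concatenating along the feature dimension (dim 2, since the tensors are (batch, T, features)), then packs the batch for the LSTM. A minimal sketch of the pad -> pack -> unpack -> index flow, with illustrative sizes rather than Stanza's real ones, and lengths computed before padding:

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

embed_dim, hidden_dim = 8, 16
lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

seqs = [torch.randn(5, embed_dim), torch.randn(3, embed_dim)]  # variable-length sentences
lengths = torch.tensor([seq.shape[0] for seq in seqs])         # true lengths, pre-padding

padded = pad_sequence(seqs, batch_first=True)                  # (batch, T_max, embed_dim)
packed = pack_padded_sequence(padded, lengths, batch_first=True)

lstm_out, _ = lstm(packed)
unpacked, _ = pad_packed_sequence(lstm_out, batch_first=True)  # (batch, T_max, hidden_dim)

# Hidden state at each sentence's target-token position, as in the diff above.
pos_indices = torch.tensor([2, 1])
target = unpacked[torch.arange(unpacked.size(0)), pos_indices]  # (batch, hidden_dim)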
9 changes: 4 additions & 5 deletions stanza/models/lemma_classifier/train_model.py
@@ -124,8 +124,9 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
logging.info(f"Loaded dataset successfully from {train_path}")
logging.info(f"Using label decoder: {label_decoder} Output dimension: {self.output_dim}")

idx_batches, label_batches = torch.tensor(idx_batches, device=device), torch.tensor(label_batches, device=device)
logging.info(f"idx batches size: {idx_batches.shape}. label_batches shape {label_batches.shape}")
text_batches, idx_batches, label_batches = text_batches[:-1], idx_batches[:-1], label_batches[:-1] # TODO come up with a fix for this

idx_batches, label_batches = torch.stack(idx_batches).to(device), torch.stack(label_batches).to(device)

self.model = LemmaClassifierLSTM(self.vocab_size, self.embedding_dim, self.hidden_dim, self.output_dim, self.vocab_map, self.embeddings, label_decoder,
charlm=self.use_charlm, charlm_forward_file=self.forward_charlm_file, charlm_backward_file=self.backward_charlm_file)
@@ -157,14 +158,12 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
# Compute loss, which is different if using CE or BCEWithLogitsLoss
if self.weighted_loss: # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others.
# TODO: three classes?
targets = torch.tensor([torch.tensor([1, 0] if label == 0 else [0, 1]) for label in labels], dtype=torch.float32, device=device)
targets = torch.stack([torch.tensor([1, 0]) if label == 0 else torch.tensor([0, 1]) for label in labels]).to(dtype=torch.float32).to(device)
# should be shape size (batch_size, 2)

else: # CELoss accepts target as just raw label
targets = labels

logging.info(f"targets shape {targets.shape}. Should be shape (batch_size, ) or (batch_size, output_dim)") # should be shape (batch_size, ) or (batch_size, 2)

loss = self.criterion(output, targets)

loss.backward()
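
On the weighted-loss branch: the stacked targets hard-code two classes, and the TODO about three classes could be handled with torch.nn.functional.one_hot, which builds the same (batch_size, num_classes) float target for any output dimension. A sketch under that assumption, not what the diff does:

import torch
import torch.nn.functional as F

labels = torch.tensor([0, 1, 2, 1])                  # hypothetical three-class batch
targets = F.one_hot(labels, num_classes=3).float()   # (batch_size, 3)

logits = torch.randn(labels.size(0), 3)              # stand-in model output
loss = F.binary_cross_entropy_with_logits(logits, targets)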
2 changes: 1 addition & 1 deletion stanza/models/lemma_classifier/utils.py
@@ -12,7 +12,7 @@ def load_doc_from_conll_file(path: str):
return stanza.utils.conll.CoNLL.conll2doc(path)


def load_dataset(data_path: str, batch_size=DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None) -> Tuple[List[List[str]], List[torch.Tensor[int]], List[torch.Tensor[int]], Mapping[int, int], Mapping[str, int]]:
def load_dataset(data_path: str, batch_size=DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None) -> Tuple[List[List[str]], List[torch.Tensor], List[torch.Tensor], Mapping[int, int], Mapping[str, int]]:

"""
Loads a data file into data batches for tokenized text sentences, token indices, and true labels for each sentence.
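
The annotation fix here (and the matching one in model_predict above) is needed because a subscript like torch.tensor[int] is evaluated when the def statement runs, and torch.tensor is a plain function, so importing the module raises a TypeError; depending on the PyTorch version, torch.Tensor[int] can fail the same way, and bare torch.Tensor is the conventional annotation. For instance:

import torch

# def f(x: torch.tensor[int]): ...   raises at import time:
# TypeError: 'builtin_function_or_method' object is not subscriptable
def f(x: torch.Tensor) -> torch.Tensor:
    return x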
