Skip to content

Commit

Permalink
Keep the dataset in a single object rather than a bunch of lists. Mak…
Browse files Browse the repository at this point in the history
…es it easier to shuffle, keeps everything in one place
  • Loading branch information
AngledLuffa committed Jan 16, 2024
1 parent 1281dd2 commit b84425f
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 103 deletions.
10 changes: 5 additions & 5 deletions stanza/models/lemma_classifier/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
if not train_file:
raise ValueError("Cannot train model - no train_file supplied!")

text_batches, position_batches, upos_batches, label_batches, counts, label_decoder, upos_to_id = utils.load_dataset(train_file, get_counts=self.weighted_loss, batch_size=args.get("batch_size", DEFAULT_BATCH_SIZE))
dataset = utils.Dataset(train_file, get_counts=self.weighted_loss, batch_size=args.get("batch_size", DEFAULT_BATCH_SIZE))
label_decoder = dataset.label_decoder
upos_to_id = dataset.upos_to_id
self.output_dim = len(label_decoder)
logging.info(f"Loaded dataset successfully from {train_file}")
logging.info(f"Using label decoder: {label_decoder} Output dimension: {self.output_dim}")

assert len(text_batches) == len(position_batches) == len(label_batches), f"Input batch sizes did not match ({len(text_batches)}, {len(position_batches)}, {len(label_batches)})."

self.model = self.build_model(label_decoder, upos_to_id)
self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

Expand All @@ -70,7 +70,7 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
raise FileExistsError(f"Save name {save_name} already exists; training would overwrite previous file contents. Aborting...")

if self.weighted_loss:
self.configure_weighted_loss(label_decoder, counts)
self.configure_weighted_loss(label_decoder, dataset.counts)

# Put the criterion on GPU too
logging.debug(f"Criterion on {next(self.model.parameters()).device}")
Expand All @@ -79,7 +79,7 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
best_model, best_f1 = None, float("-inf") # Used for saving checkpoints of the model
for epoch in range(num_epochs):
# go over entire dataset with each epoch
for sentences, positions, upos_tags, labels in tqdm(zip(text_batches, position_batches, upos_batches, label_batches), total=len(text_batches)):
for sentences, positions, upos_tags, labels in tqdm(dataset):
assert len(sentences) == len(positions) == len(labels), f"Input sentences, positions, and labels are of unequal length ({len(sentences), len(positions), len(labels)})"

self.optimizer.zero_grad()
Expand Down
23 changes: 10 additions & 13 deletions stanza/models/lemma_classifier/evaluate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_weighted_f1(mcc_results: Mapping[int, Mapping[str, float]], confusion: M
return weighted_f1 / num_total_examples


def evaluate_sequences(gold_tag_sequences: List[List[Any]], pred_tag_sequences: List[List[Any]], label_decoder: Mapping, verbose=True):
def evaluate_sequences(gold_tag_sequences: List[Any], pred_tag_sequences: List[Any], label_decoder: Mapping, verbose=True):
"""
Evaluates a model's predicted tags against a set of gold tags. Computes precision, recall, and f1 for all classes.
Expand All @@ -69,11 +69,8 @@ def evaluate_sequences(gold_tag_sequences: List[List[Any]], pred_tag_sequences:
confusion = defaultdict(lambda: defaultdict(int))

reverse_label_decoder = {y: x for x, y in label_decoder.items()}
for gold_tags, pred_tags in tqdm(zip(gold_tag_sequences, pred_tag_sequences), "Evaluating sequences", total=len(gold_tag_sequences)):

assert len(gold_tags) == len(pred_tags), f"Number of gold tags doesn't match number of predicted tags ({len(gold_tags)}, {len(pred_tags)})"
for gold, pred in zip(gold_tags, pred_tags):
confusion[reverse_label_decoder[gold.item()]][reverse_label_decoder[pred]] += 1
for gold, pred in zip(gold_tag_sequences, pred_tag_sequences):
confusion[reverse_label_decoder[gold]][reverse_label_decoder[pred]] += 1

multi_class_result = defaultdict(lambda: defaultdict(float))
# compute precision, recall and f1 for each class and store inside of `multi_class_result`
Expand Down Expand Up @@ -154,29 +151,29 @@ def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_tr
model.eval() # set to eval mode

# load in eval data
text_batches, index_batches, upos_batches, label_batches, _, label_decoder, upos_to_id = utils.load_dataset(eval_path, label_decoder=model.label_decoder)
dataset = utils.Dataset(eval_path, label_decoder=model.label_decoder, shuffle=False)

logging.info(f"Evaluating on evaluation file {eval_path}")

correct, total = 0, 0
gold_tags, pred_tags = label_batches, []
gold_tags, pred_tags = dataset.labels, []

# run eval on each example from dataset
for sentences, pos_indices, upos_tags, labels in tqdm(zip(text_batches, index_batches, upos_batches, label_batches), "Evaluating examples from data file", total=len(text_batches)):
for sentences, pos_indices, upos_tags, labels in tqdm(dataset, "Evaluating examples from data file"):
pred = model_predict(model, pos_indices, sentences, upos_tags) # Pred should be size (batch_size, )
correct_preds = pred == labels.to(device)
correct += torch.sum(correct_preds)
total += len(correct_preds)
pred_tags += [pred.tolist()]
pred_tags += pred.tolist()

logging.info("Finished evaluating on dataset. Computing scores...")
accuracy = correct / total

mc_results, confusion, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, label_decoder, verbose=verbose)
mc_results, confusion, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, dataset.label_decoder, verbose=verbose)
# add brackets around batches of gold and pred tags because each batch is an element within the sequences in this helper
if verbose:
logging.info(f"Accuracy: {accuracy} ({correct}/{total})")
logging.info(f"Label decoder: {label_decoder}")
logging.info(f"Label decoder: {dataset.label_decoder}")

return mc_results, confusion, accuracy, weighted_f1

Expand Down
168 changes: 97 additions & 71 deletions stanza/models/lemma_classifier/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import Counter, defaultdict
import logging
import os
import random
from typing import List, Tuple, Any, Mapping

import stanza
Expand All @@ -9,77 +10,102 @@
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
import stanza.models.lemma_classifier.prepare_dataset as prepare_dataset


def load_dataset(data_path: str, batch_size=DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None) -> Tuple[List[List[str]], List[torch.Tensor], List[torch.Tensor], Mapping[int, int], Mapping[str, int]]:
"""
Loads a data file into data batches for tokenized text sentences, token indices, and true labels for each sentence.
Args:
data_path (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.
batch_size (int): Size of each batch of examples
get_counts (optional, bool): Whether there should be a map of the label index to counts
Returns:
1. List[List[List[str]]]: Batches of sentences, where each token is a separate entry in each sentence
2. List[torch.tensor[int]]: A batch of indexes for the target token corresponding to its sentence
3. List[torch.tensor[int]]: A batch of labels for the target token's lemma
4. List[List[int]]: A batch of UPOS IDs for the target token (this is a List of Lists, not a tensor. It should be padded later.)
5 (Optional): A mapping of label ID to counts in the dataset.
6. Mapping[str, int]: A map between the labels and their indexes
7. Mapping[str, int]: A map between the UPOS tags and their corresponding IDs found in the UPOS batches
"""

if data_path is None or not os.path.exists(data_path):
raise FileNotFoundError(f"Data file {data_path} could not be found.")

if label_decoder is None:
label_decoder = {}
else:
# if labels in the test set aren't in the original model,
# the model will never predict those labels,
# but we can still use those labels in a confusion matrix
label_decoder = dict(label_decoder)

logging.debug("Final label decoder: %s Should be strings to ints", label_decoder)

with open(data_path, "r+", encoding="utf-8") as f:
sentences, indices, labels, upos_ids, counts, upos_to_id = [], [], [], [], Counter(), defaultdict(str)

data_processor = prepare_dataset.DataProcessor("", [], "")
sentences_data = data_processor.read_processed_data(data_path)

for idx, sentence in enumerate(sentences_data):
# TODO Could replace this with sentence.values(), but need to know if Stanza requires Python 3.7 or later for backward compatability reasons
words, target_idx, upos_tags, label = sentence.get("words"), sentence.get("index"), sentence.get("upos_tags"), sentence.get("lemma")
if None in [words, target_idx, upos_tags, label]:
raise ValueError(f"Expected data to be complete but found a null value in sentence {idx}: {sentence}")

label_id = label_decoder.get(label, None)
if label_id is None:
label_decoder[label] = len(label_decoder) # create a new ID for the unknown label

converted_upos_tags = [] # convert upos tags to upos IDs
for upos_tag in upos_tags:
upos_id = upos_to_id.get(upos_tag, None)
if upos_id is None:
upos_to_id[upos_tag] = len(upos_to_id) # create a new ID for the unknown UPOS tag
converted_upos_tags.append(upos_to_id[upos_tag])

sentences.append(words)
indices.append(target_idx)
upos_ids.append(converted_upos_tags)
labels.append(label_decoder[label])

if get_counts:
counts[label_decoder[label]] += 1

sentence_batches = [sentences[i: i + batch_size] for i in range(0, len(sentences), batch_size)]
indices_batches = [torch.tensor(indices[i: i + batch_size]) for i in range(0, len(indices), batch_size)]
upos_batches = [upos_ids[i: i + batch_size] for i in range(0, len(upos_ids), batch_size)]
labels_batches = [torch.tensor(labels[i: i + batch_size]) for i in range(0, len(labels), batch_size)]
# TODO consider making the return object a JSON or a custom object for cleaner access instead of a big tuple of stuff
return sentence_batches, indices_batches, upos_batches, labels_batches, counts, label_decoder, upos_to_id

class Dataset:
def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None, shuffle: bool = True):
"""
Loads a data file into data batches for tokenized text sentences, token indices, and true labels for each sentence.
Args:
data_path (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.
batch_size (int): Size of each batch of examples
get_counts (optional, bool): Whether there should be a map of the label index to counts
Returns:
1. List[List[List[str]]]: Batches of sentences, where each token is a separate entry in each sentence
2. List[torch.tensor[int]]: A batch of indexes for the target token corresponding to its sentence
3. List[torch.tensor[int]]: A batch of labels for the target token's lemma
4. List[List[int]]: A batch of UPOS IDs for the target token (this is a List of Lists, not a tensor. It should be padded later.)
5 (Optional): A mapping of label ID to counts in the dataset.
6. Mapping[str, int]: A map between the labels and their indexes
7. Mapping[str, int]: A map between the UPOS tags and their corresponding IDs found in the UPOS batches
"""

if data_path is None or not os.path.exists(data_path):
raise FileNotFoundError(f"Data file {data_path} could not be found.")

if label_decoder is None:
label_decoder = {}
else:
# if labels in the test set aren't in the original model,
# the model will never predict those labels,
# but we can still use those labels in a confusion matrix
label_decoder = dict(label_decoder)

logging.debug("Final label decoder: %s Should be strings to ints", label_decoder)

with open(data_path, "r+", encoding="utf-8") as f:
sentences, indices, labels, upos_ids, counts, upos_to_id = [], [], [], [], Counter(), defaultdict(str)

data_processor = prepare_dataset.DataProcessor("", [], "")
sentences_data = data_processor.read_processed_data(data_path)

for idx, sentence in enumerate(sentences_data):
# TODO Could replace this with sentence.values(), but need to know if Stanza requires Python 3.7 or later for backward compatability reasons
words, target_idx, upos_tags, label = sentence.get("words"), sentence.get("index"), sentence.get("upos_tags"), sentence.get("lemma")
if None in [words, target_idx, upos_tags, label]:
raise ValueError(f"Expected data to be complete but found a null value in sentence {idx}: {sentence}")

label_id = label_decoder.get(label, None)
if label_id is None:
label_decoder[label] = len(label_decoder) # create a new ID for the unknown label

converted_upos_tags = [] # convert upos tags to upos IDs
for upos_tag in upos_tags:
upos_id = upos_to_id.get(upos_tag, None)
if upos_id is None:
upos_to_id[upos_tag] = len(upos_to_id) # create a new ID for the unknown UPOS tag
converted_upos_tags.append(upos_to_id[upos_tag])

sentences.append(words)
indices.append(target_idx)
upos_ids.append(converted_upos_tags)
labels.append(label_decoder[label])

if get_counts:
counts[label_decoder[label]] += 1

self.sentences = sentences
self.indices = indices
self.upos_ids = upos_ids
self.labels = labels

self.counts = counts
self.label_decoder = label_decoder
self.upos_to_id = upos_to_id

self.batch_size = batch_size
self.shuffle = shuffle

def __len__(self):
"""
Number of batches, rounded up to nearest batch
"""
return len(self.sentences) // self.batch_size + (len(self.sentences) % self.batch_size > 0)

def __iter__(self):
num_sentences = len(self.sentences)
indices = list(range(num_sentences))
if self.shuffle:
random.shuffle(indices)
for i in range(self.__len__()):
batch_start = self.batch_size * i
batch_end = min(batch_start + self.batch_size, num_sentences)

batch_sentences = [self.sentences[x] for x in indices[batch_start:batch_end]]
batch_indices = torch.tensor([self.indices[x] for x in indices[batch_start:batch_end]])
batch_upos_ids = [self.upos_ids[x] for x in indices[batch_start:batch_end]]
batch_labels = torch.tensor([self.labels[x] for x in indices[batch_start:batch_end]])
yield batch_sentences, batch_indices, batch_upos_ids, batch_labels

def extract_unknown_token_indices(tokenized_indices: torch.tensor, unknown_token_idx: int) -> List[int]:
"""
Expand Down
39 changes: 25 additions & 14 deletions stanza/tests/lemma_classifier/test_data_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,26 +220,37 @@ def test_convert_one_sentence(tmp_path):
converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have", ["train"])
assert len(converted_files) == 1

text_batches, idx_batches, upos_batches, label_batches, counts, label_decoder, upos_to_id = utils.load_dataset(converted_files[0], get_counts=True, batch_size=10)
assert text_batches == [[['Here', "'s", 'a', 'Miami', 'Herald', 'interview']]]
assert label_decoder == {'be': 0}
id_to_upos = {y: x for x, y in upos_to_id.items()}
upos = [id_to_upos[x] for x in upos_batches[0][0]]
assert upos == ['ADV', 'AUX', 'DET', 'PROPN', 'PROPN', 'NOUN']
dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False)

assert len(dataset) == 1
assert dataset.label_decoder == {'be': 0}
id_to_upos = {y: x for x, y in dataset.upos_to_id.items()}

for text_batches, _, upos_batches, _ in dataset:
assert text_batches == [['Here', "'s", 'a', 'Miami', 'Herald', 'interview']]
upos = [id_to_upos[x] for x in upos_batches[0]]
assert upos == ['ADV', 'AUX', 'DET', 'PROPN', 'PROPN', 'NOUN']

def test_convert_dataset(tmp_path):
converted_files = convert_english_dataset(tmp_path)

text_batches, idx_batches, upos_batches, label_batches, counts, label_decoder, upos_to_id = utils.load_dataset(converted_files[0], get_counts=True, batch_size=10)
dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False)

assert len(text_batches[0]) == 9
assert len(dataset) == 1
label_decoder = dataset.label_decoder
assert len(label_decoder) == 2
assert "be" in label_decoder
assert "have" in label_decoder

text_batches, idx_batches, upos_batches, label_batches, counts, label_decoder, upos_to_id = utils.load_dataset(converted_files[1], get_counts=True, batch_size=10)
assert len(text_batches[0]) == 2

text_batches, idx_batches, upos_batches, label_batches, counts, label_decoder, upos_to_id = utils.load_dataset(converted_files[2], get_counts=True, batch_size=10)
assert len(text_batches[0]) == 2
for text_batches, _, _, _ in dataset:
assert len(text_batches) == 9

dataset = utils.Dataset(converted_files[1], get_counts=True, batch_size=10, shuffle=False)
assert len(dataset) == 1
for text_batches, _, _, _ in dataset:
assert len(text_batches) == 2

dataset = utils.Dataset(converted_files[2], get_counts=True, batch_size=10, shuffle=False)
assert len(dataset) == 1
for text_batches, _, _, _ in dataset:
assert len(text_batches) == 2

0 comments on commit b84425f

Please sign in to comment.