From 0f429a72d0a4e592951a7dec2962500fee5760ed Mon Sep 17 00:00:00 2001
From: Sindre Eiklid
Date: Mon, 16 Oct 2023 00:53:48 +0200
Subject: [PATCH] Add `fit` method to `ModelSkipGram` class

---
 source/architechtures/skipgram.py |  20 ++++
 source/config_skipgram.yml        |   4 +-
 source/models.py                  | 111 +++++++++++++++++++++++++++++-
 3 files changed, 131 insertions(+), 4 deletions(-)

diff --git a/source/architechtures/skipgram.py b/source/architechtures/skipgram.py
index d76df3f..cb51f61 100644
--- a/source/architechtures/skipgram.py
+++ b/source/architechtures/skipgram.py
@@ -33,3 +33,23 @@ def run() -> None:
     # get validation data
     validation_dataloader = datahandler.loaders.ValidationLoader(data_directory="skipgram")
     validation_dataloader.build(vocabulary)
+    # fit model and get embeddings
+    model = models.ModelSkipGram(config.device, config.vocabulary_size, config.embedding_size, vocabulary.padding_index)
+    model.fit(
+        training_dataloader,
+        validation_dataloader,
+        config.learning_rate,
+        config.max_epochs,
+        config.min_loss_improvement,
+        config.patience,
+        config.validation_interval
+    )
+    embeddings = model.get_embeddings()
+    # evaluate embeddings
+    validation_dataloader.evaluate_analogies(embeddings)
+    validation_dataloader.evaluate_word_pair_similarity(embeddings)
+    utils.print_divider()
+    validation_dataloader.plot_analogies_rank(k=20)
+    validation_dataloader.plot_word_pair_similarity()
+    print(f"Analogy accuracy: {(validation_dataloader.analogies_accuracy() * 100):.2f}%")
+    print(f"Spearman correlation coefficient: {validation_dataloader.word_pair_spearman_correlation():.5f}")
diff --git a/source/config_skipgram.yml b/source/config_skipgram.yml
index 041f405..86ee0a4 100644
--- a/source/config_skipgram.yml
+++ b/source/config_skipgram.yml
@@ -4,9 +4,9 @@ device: "cuda:0"
 vocabulary_size: 10000
 batch_size: 128
 window_size: 2
-embedding_size: 200
+embedding_size: 100
 
-learning_rate: 0.05
+learning_rate: 0.001
 max_epochs: 20
 min_loss_improvement: 0.00001
 patience: 10
diff --git a/source/models.py b/source/models.py
index b281ecc..63285c4 100644
--- a/source/models.py
+++ b/source/models.py
@@ -111,10 +111,10 @@ def fit(
             optimizer.zero_grad()
             # unpack batch data and send to device
             context_words: torch.Tensor = batch[0]
-            target_word: torch.Tensor = batch[1]
+            target_words: torch.Tensor = batch[1]
             # compute gradients
             outputs: torch.Tensor = self(context_words)
-            loss: torch.Tensor = criterion(outputs, target_word)
+            loss: torch.Tensor = criterion(outputs, target_words)
             total_loss += loss.item()
             # apply gradients
             loss.backward()
@@ -191,3 +191,110 @@ def __init__(
         self.output_embeddings.weight.data[self.padding_idx, :] = 0
         # send model to device
         self.to(self.device)
+
+    def save(self):
+        os.makedirs(os.path.dirname(self.filepath), exist_ok=True)
+        torch.save(self.state_dict(), self.filepath)
+
+    def load(self):
+        self.load_state_dict(torch.load(self.filepath, map_location=self.device))
+
+    def get_embeddings(self) -> np.ndarray:
+        embeddings: np.ndarray = self.input_embeddings.weight.cpu().detach().numpy()
+        # if padding is used, set its embedding to ones to avoid division by zero when normalizing
+        if self.padding_idx is not None:
+            embeddings[self.padding_idx] = 1
+        embeddings = utils.normalize(embeddings, axis=1, keepdims=True)
+        return embeddings
+
+    def forward(self, target_words):
+        target_vector = self.input_embeddings(target_words)
+        output = torch.matmul(target_vector, self.output_embeddings.weight.t())
+        return torch.nn.functional.log_softmax(output, dim=1)
+
+    def validate(self, validation_dataloader: datahandler.loaders.ValidationLoader) -> float:
+        validation_dataloader.evaluate_analogies(self.get_embeddings(), quiet=True)
+        return validation_dataloader.analogies_accuracy()
+
+    def fit(
+        self,
+        training_dataloader: datahandler.loaders.DataLoaderCBOW,
+        validation_dataloader: datahandler.loaders.ValidationLoader,
+        learning_rate: float,
+        max_epochs: int,
+        min_loss_improvement: float,
+        patience: int,
+        validation_interval: int
+    ):
+        # check if cache exists
+        if os.path.exists(self.filepath):
+            progress_bar = tqdm.tqdm(desc="Loading cached model", total=1)
+            self.load()
+            progress_bar.update(1)
+            return
+        print("Training model:")
+        loss_history = []
+        acc_history = []
+        dataset_size = len(training_dataloader)
+        last_batch_index = dataset_size - 1
+        # set optimizer and criterion
+        optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
+        criterion = torch.nn.NLLLoss()
+        # loop through each epoch
+        best_loss = float("inf")
+        best_acc = -float("inf")
+        epochs_without_improvement = 0
+        for epoch in range(max_epochs):
+            total_loss = 0.0
+            # define the progress bar
+            progressbar = utils.get_model_progressbar(training_dataloader, epoch, max_epochs)
+            # set model to training mode
+            self.train()
+            # loop through the dataset
+            for idx, batch in enumerate(progressbar):
+                # clear gradients
+                optimizer.zero_grad()
+                # unpack batch data and send to device
+                context_words: torch.Tensor = batch[0]
+                target_words: torch.Tensor = batch[1]
+                # compute gradients
+                outputs: torch.Tensor = self(target_words)
+                loss: torch.Tensor = criterion(outputs, context_words)
+                total_loss += loss.item()
+                # apply gradients
+                loss.backward()
+                optimizer.step()
+                # branch if on last iteration
+                if idx == last_batch_index:
+                    # update early stopping and save model
+                    train_loss = total_loss / (idx + 1)
+                    if train_loss <= (best_loss - min_loss_improvement):
+                        best_loss = train_loss
+                        epochs_without_improvement = 0
+                        # save best model
+                        self.save()
+                    else:
+                        epochs_without_improvement += 1
+                    # validate model every n epochs
+                    if epoch == 0 or (epoch + 1) == max_epochs or (epoch + 1) % validation_interval == 0:
+                        self.eval()
+                        train_acc = self.validate(validation_dataloader)
+                        if train_acc >= best_acc:
+                            best_acc = train_acc
+                        self.train()
+                    else:
+                        train_acc = acc_history[-1]
+                    # add to history and plot
+                    loss_history.append(train_loss)
+                    acc_history.append(train_acc)
+                    utils.plot_loss_and_accuracy(loss_history, acc_history, data_directory="skipgram")
+                    # update information with current values
+                    utils.set_model_progressbar_prefix(progressbar, train_loss, best_loss, train_acc, best_acc)
+            # check for early stopping
+            if epochs_without_improvement >= patience:
+                break
+        # load the best model from training
+        self.load()
+        # empty GPU cache
+        if "cuda" in self.device:
+            torch.cuda.empty_cache()
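
Note for reviewers: the loss wired up in the new `fit` is the full-softmax skip-gram objective. `forward` maps a batch of target-word indices to log p(context | target) over the whole vocabulary, and `NLLLoss` then takes the negative log-likelihood of the observed context word. Below is a minimal self-contained sketch of that computation; the toy sizes and indices are made up for illustration and are not part of the patch.

    import torch

    # toy sizes, chosen only for illustration
    vocab_size, embedding_size = 10, 4
    input_embeddings = torch.nn.Embedding(vocab_size, embedding_size)   # rows become the word vectors
    output_embeddings = torch.nn.Embedding(vocab_size, embedding_size)  # "context" side of the factorization

    target_words = torch.tensor([1, 5])   # batch of target-word indices
    context_words = torch.tensor([2, 7])  # observed context-word indices for the same batch

    # same computation as ModelSkipGram.forward: score every vocabulary word,
    # then log-softmax so that NLLLoss yields -log p(context | target)
    scores = input_embeddings(target_words) @ output_embeddings.weight.t()  # shape (2, 10)
    log_probs = torch.nn.functional.log_softmax(scores, dim=1)
    loss = torch.nn.NLLLoss()(log_probs, context_words)
    loss.backward()  # gradients flow into both embedding tables, as in fit()

A full softmax scores every vocabulary word per example, so each step costs O(vocabulary_size * embedding_size); that stays manageable at the configured vocabulary_size of 10000, while negative sampling or a hierarchical softmax would be the usual swap-in for larger vocabularies.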