From db06bb61be31750e77d7ea8c3cd47c42651895ff Mon Sep 17 00:00:00 2001
From: Sindre Eiklid
Date: Mon, 16 Oct 2023 09:09:24 +0200
Subject: [PATCH] Add method `forward` to `ModelGloVe`

---
 source/models.py | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/source/models.py b/source/models.py
index 9587934..7405605 100644
--- a/source/models.py
+++ b/source/models.py
@@ -16,11 +16,15 @@ def __init__(
         device: str,
         vocabulary_size: int,
         embedding_size: int,
+        x_max: float,
+        alpha: float,
         padding_idx: int = None
     ):
         super().__init__()
         self.device = device
         self.filepath = os.path.join(PROJECT_DIRECTORY_PATH, "data", "cbow", "model.pt")
+        self.x_max = x_max
+        self.alpha = alpha
         self.padding_idx = padding_idx
         # init embedding layers
         self.main_embeddings = torch.nn.Embedding(
@@ -49,6 +53,35 @@ def __init__(
         # send model to device
         self.to(self.device)
 
+    def save(self):
+        os.makedirs(os.path.dirname(self.filepath), exist_ok=True)
+        torch.save(self.state_dict(), self.filepath)
+
+    def load(self):
+        self.load_state_dict(torch.load(self.filepath, map_location=self.device))
+
+    def get_embeddings(self) -> np.ndarray:
+        embeddings: np.ndarray = (self.main_embeddings.weight.cpu().detach() + self.context_embeddings.weight.cpu().detach()).numpy()
+        # if padding is used, set its row to ones so normalization does not divide by zero
+        if self.padding_idx is not None:
+            embeddings[self.padding_idx] = 1
+        embeddings = utils.normalize(embeddings, axis=1, keepdims=True)
+        return embeddings
+
+    def forward(self, word_index: torch.Tensor, context_index: torch.Tensor, cooccurrence_count: torch.Tensor) -> torch.Tensor:
+        # dot product of main and context embedding vectors
+        dot_product: torch.Tensor = (self.main_embeddings(word_index) * self.context_embeddings(context_index)).sum(dim=1)
+        # predicted log co-occurrence count
+        prediction = dot_product + self.main_bias[word_index] + self.context_bias[context_index]
+        # weighting f(x) = min((x / x_max)^alpha, 1) applied to the squared error
+        x_scaled = (cooccurrence_count / self.x_max).pow(self.alpha)
+        weighted_error = (x_scaled.clamp(0, 1) * (prediction - cooccurrence_count.log()) ** 2).mean()
+        return weighted_error
+
+    def validate(self, validation_dataloader: datahandler.loaders.ValidationLoader) -> float:
+        validation_dataloader.evaluate_analogies(self.get_embeddings(), quiet=True)
+        return validation_dataloader.analogies_accuracy()
+
 
 class ModelCBOW(torch.nn.Module):
     def __init__(
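
Note for reviewers: the new `forward` returns the GloVe weighted least-squares loss from Pennington et al. (2014), f(X_ij) * (w_i . w~_j + b_i + b~_j - log X_ij)^2, averaged over the batch. The `clamp(0, 1)` reproduces the paper's piecewise weighting f(x) = (x / x_max)^alpha for x < x_max and 1 otherwise, since (x / x_max)^alpha is non-negative and only exceeds 1 once x > x_max. A quick check (the values below are illustrative, not part of the patch):

import torch

x_max, alpha = 100.0, 0.75
x = torch.tensor([1.0, 50.0, 100.0, 1000.0])
f = (x / x_max).pow(alpha).clamp(0, 1)
# tensor([0.0316, 0.5946, 1.0000, 1.0000]) -- counts above x_max are capped at 1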
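
For context, here is a minimal single-step training sketch against the new interface. The vocabulary size, hyperparameter values, and synthetic co-occurrence triples are assumptions for illustration only, and it is assumed that the part of `__init__` elided from this hunk initializes `main_bias` and `context_bias`; Adagrad is the optimizer used in the original GloVe paper.

import torch
from source.models import ModelGloVe

model = ModelGloVe(
    device="cuda" if torch.cuda.is_available() else "cpu",
    vocabulary_size=10_000,  # assumed; use the real vocabulary size
    embedding_size=300,
    x_max=100.0,             # common values from the GloVe paper
    alpha=0.75,
)
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.05)

# synthetic (word, context, count) triples, purely for illustration
word_index = torch.randint(0, 10_000, (512,), device=model.device)
context_index = torch.randint(0, 10_000, (512,), device=model.device)
cooccurrence_count = torch.rand(512, device=model.device) * 200 + 1

optimizer.zero_grad()
loss = model(word_index, context_index, cooccurrence_count)
loss.backward()
optimizer.step()

model.save()  # writes the state_dict to model.filepath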