Skip to content

Commit

Permalink
Add method forward to ModelGloVe
Browse files Browse the repository at this point in the history
  • Loading branch information
sindre0830 committed Oct 16, 2023
1 parent 679e668 commit db06bb6
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions source/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ def __init__(
device: str,
vocabulary_size: int,
embedding_size: int,
x_max: float,
alpha: float,
padding_idx: int = None
):
super().__init__()
self.device = device
self.filepath = os.path.join(PROJECT_DIRECTORY_PATH, "data", "cbow", "model.pt")
self.x_max = x_max
self.alpha = alpha
self.padding_idx = padding_idx
# init embedding layers
self.main_embeddings = torch.nn.Embedding(
Expand Down Expand Up @@ -49,6 +53,35 @@ def __init__(
# send model to device
self.to(self.device)

def save(self):
os.makedirs(os.path.dirname(self.filepath), exist_ok=True)
torch.save(self.state_dict(), self.filepath)

def load(self):
self.load_state_dict(torch.load(self.filepath, map_location=self.device))

def get_embeddings(self) -> np.ndarray:
embeddings: np.ndarray = (self.main_embeddings.weight.cpu().detach() + self.context_embeddings.weight.cpu().detach()).numpy()
# if padding is used, set it to 1 to avoid division by zero
if self.padding_idx is not None:
embeddings[self.padding_idx] = 1
embeddings = utils.normalize(embeddings, axis=1, keepdims=True)
return embeddings

def forward(self, word_index: torch.Tensor, context_index: torch.Tensor, cooccurrence_count: torch.Tensor):
# dot product calculation
dot_product: torch.Tensor = (self.main_embeddings(word_index) * self.context_embeddings(context_index)).sum(dim=1)
# prediction
prediction = dot_product + self.main_bias[word_index] + self.context_bias[context_index]
# weighted loss calculation
x_scaled = (cooccurrence_count / self.x_max).pow(self.alpha)
weighted_error = (x_scaled.clamp(0, 1) * (prediction - cooccurrence_count.log()) ** 2).mean()
return weighted_error

def validate(self, validation_dataloader: datahandler.loaders.ValidationLoader) -> float:
validation_dataloader.evaluate_analogies(self.get_embeddings(), quiet=True)
return validation_dataloader.analogies_accuracy()


class ModelCBOW(torch.nn.Module):
def __init__(
Expand Down

0 comments on commit db06bb6

Please sign in to comment.