Skip to content

Commit

Permalink
Add validate method to ModelCBOW
Browse files Browse the repository at this point in the history
  • Loading branch information
sindre0830 committed Oct 15, 2023
1 parent c2255bd commit 5c09413
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions source/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,7 @@ def forward(self, context_words_idx):
context_vector = torch.sum(self.input_embeddings(context_words_idx), dim=1)
output = torch.matmul(context_vector, self.output_embeddings.weight.t())
return torch.nn.functional.log_softmax(output, dim=1)

def validate(self, validation_dataloader: datahandler.loaders.ValidationLoader) -> float:
validation_dataloader.evaluate_analogies(self.get_embeddings())
return validation_dataloader.analogies_accuracy()

0 comments on commit 5c09413

Please sign in to comment.