diff --git a/source/models.py b/source/models.py index f48099c..bbd30d0 100644 --- a/source/models.py +++ b/source/models.py @@ -55,3 +55,7 @@ def forward(self, context_words_idx): context_vector = torch.sum(self.input_embeddings(context_words_idx), dim=1) output = torch.matmul(context_vector, self.output_embeddings.weight.t()) return torch.nn.functional.log_softmax(output, dim=1) + + def validate(self, validation_dataloader: datahandler.loaders.ValidationLoader) -> float: + validation_dataloader.evaluate_analogies(self.get_embeddings()) + return validation_dataloader.analogies_accuracy()