From 5c09413c20fec950429b29ef9173dbeff18aa26a Mon Sep 17 00:00:00 2001 From: Sindre Eiklid Date: Sun, 15 Oct 2023 13:21:08 +0200 Subject: [PATCH] Add `validate` method to `ModelCBOW` --- source/models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/models.py b/source/models.py index f48099c..bbd30d0 100644 --- a/source/models.py +++ b/source/models.py @@ -55,3 +55,7 @@ def forward(self, context_words_idx): context_vector = torch.sum(self.input_embeddings(context_words_idx), dim=1) output = torch.matmul(context_vector, self.output_embeddings.weight.t()) return torch.nn.functional.log_softmax(output, dim=1) + + def validate(self, validation_dataloader: datahandler.loaders.ValidationLoader) -> float: + validation_dataloader.evaluate_analogies(self.get_embeddings()) + return validation_dataloader.analogies_accuracy()