From 713fb4ebbf703c5086c898920d136506b339939b Mon Sep 17 00:00:00 2001
From: Andrew
Date: Wed, 13 Nov 2024 17:31:29 +0000
Subject: [PATCH] all functional

---
 rankers/modelling/sparse.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/rankers/modelling/sparse.py b/rankers/modelling/sparse.py
index 27f46f4..c95eb72 100644
--- a/rankers/modelling/sparse.py
+++ b/rankers/modelling/sparse.py
@@ -21,6 +21,26 @@ def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, confi
         from .pyterrier.sparse import SparseTransformer

         self.transformer_class = SparseTransformer
+
+    def forward(self,
+                loss=None,
+                queries=None,
+                docs_batch=None,
+                labels=None):
+        """Compute the loss given (queries, docs, labels)."""
+        queries = {k: v.to(self.model.device) for k, v in queries.items()} if queries is not None else None
+        docs_batch = {k: v.to(self.model_d.device) for k, v in docs_batch.items()} if docs_batch is not None else None
+        labels = labels.to(self.model_d.device) if labels is not None else None
+
+        query_reps = self._encode_q(**queries) if queries is not None else None
+        docs_batch_reps = self._encode_d(**docs_batch) if docs_batch is not None else None
+
+        pred, labels, inbatch_pred = self.prepare_outputs(query_reps, docs_batch_reps, labels)
+        inbatch_loss = self.inbatch_loss_fn(inbatch_pred, torch.eye(inbatch_pred.shape[0]).to(inbatch_pred.device)) if inbatch_pred is not None else 0.
+
+        loss_value = loss(pred, labels, query_reps, docs_batch_reps)
+        loss_value += inbatch_loss
+        return (loss_value, pred)

 AutoConfig.register("Sparse", SparseConfig)
 AutoModel.register(SparseConfig, Sparse)
\ No newline at end of file
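
For illustration, a minimal, self-contained sketch of the loss pattern the new
forward() implements: a caller-supplied loss with the
(pred, labels, query_reps, docs_reps) signature, plus an in-batch term that
scores every query against every document in the batch and trains toward the
identity matrix. The loss functions below (an MSE placeholder,
binary_cross_entropy_with_logits) and the tensor shapes are illustrative
stand-ins, not the library's actual loss or inbatch_loss_fn.

    import torch
    import torch.nn.functional as F

    # Stand-in for the caller-supplied `loss`; same signature forward() passes.
    def example_loss(pred, labels, query_reps=None, docs_reps=None):
        if labels is None:
            return pred.sum() * 0.0  # no supervised signal without labels
        return F.mse_loss(pred, labels.float())

    batch = 4
    pred = torch.randn(batch, 1)              # per (query, doc) pair scores
    labels = torch.ones(batch, 1)             # relevance targets
    inbatch_pred = torch.randn(batch, batch)  # every query vs. every in-batch doc

    # Mirrors the patch: the identity matrix marks each query's own document
    # as the positive; BCE-with-logits stands in for self.inbatch_loss_fn.
    inbatch_loss = F.binary_cross_entropy_with_logits(
        inbatch_pred, torch.eye(batch)
    )
    loss_value = example_loss(pred, labels) + inbatch_loss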