Skip to content

Commit

Permalink
all functional
Browse files Browse the repository at this point in the history
  • Loading branch information
Parry-Parry committed Nov 13, 2024
1 parent 003c811 commit 713fb4e
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions rankers/modelling/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,26 @@ def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, confi

from .pyterrier.sparse import SparseTransformer
self.transformer_class = SparseTransformer

def forward(self,
loss = None,
queries = None,
docs_batch = None,
labels=None):
"""Compute the loss given (queries, docs, labels)"""
queries = {k: v.to(self.model.device) for k, v in queries.items()} if queries is not None else None
docs_batch = {k: v.to(self.model_d.device) for k, v in docs_batch.items()} if docs_batch is not None else None
labels = labels.to(self.model_d.device) if labels is not None else None

query_reps = self._encode_q(**queries) if queries is not None else None
docs_batch_reps = self._encode_d(**docs_batch) if docs_batch is not None else None

pred, labels, inbatch_pred = self.prepare_outputs(query_reps, docs_batch_reps, labels)
inbatch_loss = self.inbatch_loss_fn(inbatch_pred, torch.eye(inbatch_pred.shape[0]).to(inbatch_pred.device)) if inbatch_pred is not None else 0.

loss_value = loss(pred, labels, query_reps, docs_batch_reps) if labels is not None else loss(pred, None, query_reps, docs_batch_reps)
loss_value += inbatch_loss
return (loss_value, pred)

AutoConfig.register("Sparse", SparseConfig)
AutoModel.register(SparseConfig, Sparse)

0 comments on commit 713fb4e

Please sign in to comment.