Skip to content

Commit

Permalink
hmhm
Browse files Browse the repository at this point in the history
  • Loading branch information
Parry-Parry committed Nov 19, 2024
1 parent a3e553e commit e376bd9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
6 changes: 3 additions & 3 deletions rankers/modelling/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(
else: self.model_d = self.model if config.model_tied else deepcopy(self.model)
self.pooling = {
'mean': lambda x: x.mean(dim=1),
'cls' : lambda x: x[:, 0],
'cls' : lambda x: x[:, 0, :],
'late_interaction': lambda x: x,
'none': lambda x: x,
}[config.pooling_type]
Expand Down Expand Up @@ -162,7 +162,7 @@ def prepare_outputs(self, query_reps, docs_batch_reps, labels=None):
return pred, labels, inbatch_pred

def _cls(self, x : torch.Tensor) -> torch.Tensor:
return self.pooler(x[:, 0])
return self.pooler(x[:, 0, :])

def _mean(self, x : torch.Tensor) -> torch.Tensor:
return self.pooler(x.mean(dim=1))
Expand All @@ -185,7 +185,7 @@ def forward(self,

query_reps = self._encode_q(**queries) if queries is not None else None
docs_batch_reps = self._encode_d(**docs_batch) if docs_batch is not None else None

breakpoint()
pred, labels, inbatch_pred = self.prepare_outputs(query_reps, docs_batch_reps, labels)
inbatch_loss = self.inbatch_loss_fn(inbatch_pred, torch.eye(inbatch_pred.shape[0]).to(inbatch_pred.device)) if inbatch_pred is not None else 0.

Expand Down
1 change: 0 additions & 1 deletion rankers/train/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,6 @@ def batched_dot_product(a: Tensor, b: Tensor):
"""
if len(b.shape) == 2:
return torch.matmul(a, b.transpose(0, 1))
breakpoint()
# Ensure `a` is of shape (batch_size, 1, vector_dim)
if len(a.shape) == 2:
a = a.unsqueeze(1)
Expand Down

0 comments on commit e376bd9

Please sign in to comment.