Skip to content

Commit

Permalink
fix: distribute properly models
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Dec 7, 2023
1 parent 2c0e23d commit fbc9b5c
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/xpmir/neural/cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
DuoTwoStageRetriever,
Retriever,
)
from xpmir.utils.utils import easylog

logger = easylog()


class CrossScorer(LearnableScorer, DistributableModel):
Expand Down Expand Up @@ -51,7 +54,13 @@ def forward(self, inputs: BaseRecords, info: TrainerContext = None):
return self.classifier(pairs).squeeze(1)

def distribute_models(self, update):
self.encoder.model = update(self.encoder.model)
if isinstance(self.encoder, DistributableModel):
self.encoder = self.distribute_models(self.encoder, update)
else:
logger.warning(
"Cross-encoder encoder is not distributable: "
"keeping it on one device"
)


class DuoCrossScorer(DuoLearnableScorer, DistributableModel):
Expand Down

0 comments on commit fbc9b5c

Please sign in to comment.