From ab4dd62748cd3239aec61cc9a2f084d334608d2f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9CAndrew?=
Date: Tue, 12 Nov 2024 11:50:47 +0000
Subject: [PATCH] should fix?

---
 rankers/datasets/loader.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/rankers/datasets/loader.py b/rankers/datasets/loader.py
index 108154e..0c3317c 100644
--- a/rankers/datasets/loader.py
+++ b/rankers/datasets/loader.py
@@ -17,12 +17,14 @@ def __call__(self, batch) -> dict:
         batch_queries = []
         batch_docs = []
         batch_scores = []
-        for (q, dx, *args) in batch:
+        for elt in batch:
+            q = elt[0]
+            dx = elt[1]
             batch_queries.append(q)
             batch_docs.extend(dx)
-            if len(args) == 0:
+            if len(elt) < 3:
                 continue
-            batch_scores.extend(args[0])
+            batch_scores.extend(elt[2])
 
         tokenized_queries = self.tokenizer(
             batch_queries,
@@ -61,12 +63,14 @@ def __call__(self, batch) -> dict:
         batch_queries = []
         batch_docs = []
         batch_scores = []
-        for (q, dx, *args) in batch:
+        for elt in batch:
+            q = elt[0]
+            dx = elt[1]
             batch_queries.extend([q]*len(dx))
             batch_docs.extend(dx)
-            if len(args) == 0:
+            if len(elt) < 3:
                 continue
-            batch_scores.extend(args[0])
+            batch_scores.extend(elt[2])
 
         tokenized_sequences = self.tokenizer(
             batch_queries,