From 9fcc7a79124eeb964e5b5b60abac696c93548e91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CAndrew?= Date: Wed, 13 Nov 2024 13:07:05 +0000 Subject: [PATCH] hmm need to sort the flow --- rankers/datasets/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rankers/datasets/dataset.py b/rankers/datasets/dataset.py index 302a074..1f5cf4c 100644 --- a/rankers/datasets/dataset.py +++ b/rankers/datasets/dataset.py @@ -134,7 +134,8 @@ def __getitem__(self, idx): texts, scores = texts[:self.group_size], scores[:self.group_size] else: texts, scores = zip(*random.sample(list(zip(doc_id_a_text + doc_id_b_text, doc_id_a_scores + doc_id_b_scores)), self.group_size)) - return (query, texts, scores) + return (query, texts, scores) + return (query, doc_id_a_text + doc_id_b_text, doc_id_a_scores + doc_id_b_scores) else: if len(doc_id_b_text) > (self.n_neg): doc_id_b_text = random.sample(doc_id_b_text, self.group_size) return (query, doc_id_a_text + doc_id_b_text)