diff --git a/rankers/train/data_arguments.py b/rankers/train/data_arguments.py index 7676e71..29a2524 100644 --- a/rankers/train/data_arguments.py +++ b/rankers/train/data_arguments.py @@ -45,7 +45,7 @@ def __post_init__(self): except Exception as e: raise ValueError(f"Unable to load ir_dataset: {e}") assert self.training_dataset_file.endswith('jsonl') or self.training_dataset_file.endswith('jsonl.gz'), "Training dataset should be a JSONL file" - self.training_data = load_json(self.training_dataset_file) + self.training_data = pd.read_json(self.training_dataset_file, lines=True, orient='records') if self.teacher_file: assert self.teacher_file.endswith('json') or self.teacher_file.endswith('json.gz'), "Teacher file should be a JSON file" self.teacher_data = load_json(self.teacher_file)