diff --git a/dominoes/datasets/base.py b/dominoes/datasets/base.py index 73e8265..4a4c8f9 100644 --- a/dominoes/datasets/base.py +++ b/dominoes/datasets/base.py @@ -332,7 +332,7 @@ def reward_function(self, choices, batch, **kwargs): if self.task == "sequencer": return self._measurereward_sequencer(choices, batch, **kwargs) elif self.task == "sorting": - return self._measurereward_sorter(choices, batch, **kwargs): + return self._measurereward_sorter(choices, batch, **kwargs) else: raise ValueError(f"task {self.task} not recognized")