diff --git a/rankers/train/training_arguments.py b/rankers/train/training_arguments.py index 21a13c6..31e45ab 100644 --- a/rankers/train/training_arguments.py +++ b/rankers/train/training_arguments.py @@ -34,7 +34,7 @@ class RankerTrainingArguments(TrainingArguments): default=None, metadata={"help": "Wandb project name"} ) - loss_fn : Optional[Union[str, callable]] = field( + loss_fn : Optional[str] = field( default='lce', metadata={"help": "Loss function to use"} )