diff --git a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py index fc4e3e24..b47e941d 100644 --- a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py +++ b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py @@ -63,7 +63,7 @@ def __init__( learning_rate: float = 0.0003, l2_reg_lambda_linear: float = 1.0, state_features_only: bool = False, - loss_type: str = "mse", # one of the LOSS_TYPES names, e.g., mse, mae, xentropy + loss_type: str = "mse", # one of the LOSS_TYPES names: [mse, mae, cross_entropy] output_activation_name: str = "linear", use_batch_norm: bool = False, use_layer_norm: bool = False,