diff --git a/paxml/tasks_lib.py b/paxml/tasks_lib.py index e475099ec..43e090cde 100644 --- a/paxml/tasks_lib.py +++ b/paxml/tasks_lib.py @@ -1787,7 +1787,7 @@ def _apply_init_checkpoint_rule( ) # Initialize with a dummy seed var_weight_hparams = ckpt_task.model.abstract_init_with_metadata( - inputs_shape_dtype, mutable=DEFAULT_INIT_MUTABLE_LIST) + inputs_shape_dtype, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST) ckpt_train_state = ckpt_task.create_train_state_padded_shapes( var_weight_hparams) train_state_pspecs = ckpt_task.create_train_state_partition_specs(