Enhance checkpoint management clarity and validation
- Updated `keep_last_n_checkpoints` parameter descriptions in `E2TTS` and `F5TTS` YAML files to clarify that setting it to 0 disables retention of recent checkpoints.
- Modified `trainer.py` to validate `keep_last_n_checkpoints`, ensuring it must be 0 or positive.
- Adjusted help text in `finetune_cli.py` to reflect the new validation rules.
- Enhanced user interface in `finetune_gradio.py` to enforce minimum value for checkpoint retention.

These changes make checkpoint retention settings easier to understand and guard against invalid values.
hcsolakoglu committed Jan 13, 2025
1 parent 2bda945 commit 5ae750f
Showing 7 changed files with 14 additions and 7 deletions.
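For context, here is a minimal sketch of the retention policy these settings describe: keep only the N most recent step checkpoints while never touching model_last.pt. The helper name and the model_<step>.pt file pattern are assumptions for illustration, not the repository's actual implementation.

import os
import re

def prune_checkpoints(checkpoint_dir: str, keep_last_n: int | None) -> None:
    # Sketch of the retention policy: keep the newest `keep_last_n` step
    # checkpoints; None or 0 means retention is disabled (keep everything).
    # Assumes step checkpoints are named model_<step>.pt (hypothetical).
    if keep_last_n is None or keep_last_n == 0:
        return  # keep all checkpoints
    if keep_last_n < 0:
        raise ValueError("keep_last_n_checkpoints must be 0 or positive")

    pattern = re.compile(r"model_(\d+)\.pt")
    # Collect (step, filename) pairs; model_last.pt does not match the pattern.
    steps = sorted(
        (int(m.group(1)), f)
        for f in os.listdir(checkpoint_dir)
        if (m := pattern.fullmatch(f))
    )
    # Remove everything older than the newest keep_last_n checkpoints.
    for _, fname in steps[:-keep_last_n]:
        os.remove(os.path.join(checkpoint_dir, fname))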
2 changes: 1 addition & 1 deletion src/f5_tts/configs/E2TTS_Base_train.yaml
@@ -41,5 +41,5 @@ ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per steps
last_per_steps: 5000 # save last checkpoint per steps
- keep_last_n_checkpoints: 0 # number of recent checkpoints to keep (excluding model_last.pt). Set to 0 or negative to disable (keep all checkpoints)
+ keep_last_n_checkpoints: 0 # number of recent checkpoints to keep (excluding model_last.pt). Set to 0 to disable (keep all checkpoints). Positive numbers limit the number of checkpoints kept
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
2 changes: 1 addition & 1 deletion src/f5_tts/configs/E2TTS_Small_train.yaml
@@ -41,5 +41,5 @@ ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per steps
last_per_steps: 5000 # save last checkpoint per steps
- keep_last_n_checkpoints: 0 # number of recent checkpoints to keep (excluding model_last.pt). Set to 0 or negative to disable (keep all checkpoints)
+ keep_last_n_checkpoints: 0 # number of recent checkpoints to keep (excluding model_last.pt). Set to 0 to disable (keep all checkpoints). Positive numbers limit the number of checkpoints kept
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
2 changes: 1 addition & 1 deletion src/f5_tts/configs/F5TTS_Base_train.yaml
@@ -44,5 +44,5 @@ ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per steps
last_per_steps: 5000 # save last checkpoint per steps
- keep_last_n_checkpoints: 0 # number of recent checkpoints to keep (excluding model_last.pt). Set to 0 or negative to disable (keep all checkpoints)
+ keep_last_n_checkpoints: 0 # number of recent checkpoints to keep (excluding model_last.pt). Set to 0 to disable (keep all checkpoints). Positive numbers limit the number of checkpoints kept
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
2 changes: 1 addition & 1 deletion src/f5_tts/configs/F5TTS_Small_train.yaml
@@ -44,5 +44,5 @@ ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per steps
last_per_steps: 5000 # save last checkpoint per steps
- keep_last_n_checkpoints: 0 # number of recent checkpoints to keep (excluding model_last.pt). Set to 0 or negative to disable (keep all checkpoints)
+ keep_last_n_checkpoints: 0 # number of recent checkpoints to keep (excluding model_last.pt). Set to 0 to disable (keep all checkpoints). Positive numbers limit the number of checkpoints kept
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
6 changes: 5 additions & 1 deletion src/f5_tts/model/trainer.py
@@ -49,8 +49,12 @@ def __init__(
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
is_local_vocoder: bool = False, # use local path vocoder
local_vocoder_path: str = "", # local vocoder path
- keep_last_n_checkpoints: int | None = None, # number of recent checkpoints to keep (None or <=0 to keep all)
+ keep_last_n_checkpoints: int | None = None, # number of recent checkpoints to keep (None or 0 to keep all)
):
+ # Validate keep_last_n_checkpoints
+ if keep_last_n_checkpoints is not None and keep_last_n_checkpoints < 0:
+     raise ValueError("keep_last_n_checkpoints must be 0 or positive")
+
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

if logger == "wandb" and not wandb.api.api_key:
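Because the check runs in Trainer.__init__, an invalid value now fails at construction time instead of surfacing mid-training. A standalone sketch of the accepted and rejected values (hypothetical helper mirroring the check above):

def validate_keep_last_n(keep_last_n_checkpoints: int | None) -> int | None:
    # Mirrors the validation added to Trainer.__init__ above.
    if keep_last_n_checkpoints is not None and keep_last_n_checkpoints < 0:
        raise ValueError("keep_last_n_checkpoints must be 0 or positive")
    return keep_last_n_checkpoints

validate_keep_last_n(None)  # OK: default, keep all checkpoints
validate_keep_last_n(0)     # OK: retention disabled, keep all checkpoints
validate_keep_last_n(3)     # OK: keep only the 3 most recent step checkpoints
validate_keep_last_n(-1)    # raises ValueError immediately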
2 changes: 1 addition & 1 deletion src/f5_tts/train/finetune_cli.py
@@ -73,7 +73,7 @@ def parse_args():
"--keep_last_n_checkpoints",
type=int,
default=None,
help="Number of recent checkpoints to keep (excluding model_last.pt). Set to 0 or negative to disable (keep all checkpoints)",
help="Number of recent checkpoints to keep (excluding model_last.pt). Set to 0 to disable (keep all checkpoints). Must be 0 or positive.",
)

return parser.parse_args()
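The CLI currently relies on the trainer-side check; argparse could also enforce the constraint at parse time with a custom type, so the error surfaces before any setup work. A sketch of that alternative (not the repository's code):

import argparse

def non_negative_int(value: str) -> int:
    # argparse type callable that rejects negative values with a clear message.
    ivalue = int(value)
    if ivalue < 0:
        raise argparse.ArgumentTypeError(
            f"keep_last_n_checkpoints must be 0 or positive, got {ivalue}"
        )
    return ivalue

parser = argparse.ArgumentParser()
parser.add_argument(
    "--keep_last_n_checkpoints",
    type=non_negative_int,  # hypothetical alternative to validating in Trainer
    default=None,
    help="Number of recent checkpoints to keep (excluding model_last.pt). "
    "Set to 0 to disable (keep all checkpoints). Must be 0 or positive.",
)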
5 changes: 4 additions & 1 deletion src/f5_tts/train/finetune_gradio.py
@@ -1575,7 +1575,10 @@ def get_audio_select(file_sample):
keep_last_n_checkpoints = gr.Number(
label="Keep Last N Checkpoints",
value=0,
info="Set to 0 or negative to disable (keep all checkpoints). Positive numbers limit the number of checkpoints kept."
minimum=0,
step=1,
precision=0,
info="Set to 0 to disable (keep all checkpoints). Positive numbers limit the number of checkpoints kept."
)

with gr.Row():
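With minimum=0, step=1, and precision=0, the Gradio number field constrains input to non-negative integers in the UI itself, so invalid retention values are caught before they ever reach the trainer-side validation.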
