Keep Last N Checkpoints (#718)
* Add checkpoint management feature

- Introduced a `keep_last_n_checkpoints` parameter in the configuration files and training scripts to control how many recent checkpoints are retained.
- Updated `finetune_cli.py`, `finetune_gradio.py`, and `trainer.py` to support this new parameter.
- Implemented logic to remove older checkpoints beyond the specified limit during training.
- Adjusted settings loading and saving to include the new checkpoint management option.

This enhancement improves the training process by preventing excessive storage usage from old checkpoints.
hcsolakoglu authored Jan 15, 2025
1 parent 83efc3f commit 76b1b03
Showing 8 changed files with 64 additions and 1 deletion.
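
As a rough sketch of the retention policy this commit introduces (illustrative only; the committed logic lives in Trainer.save_checkpoint, shown in the trainer.py diff below), the three settings behave as follows:

# Illustrative sketch, not the committed implementation (see trainer.py below).
import os


def prune_checkpoints(checkpoint_dir: str, keep_last_n: int) -> None:
    """Keep only the newest N intermediate checkpoints in checkpoint_dir.

    -1 keeps everything, 0 means intermediate checkpoints are never written
    in the first place, and a positive N deletes all but the N most recent
    model_<update>.pt files. model_last.pt is always left alone.
    """
    if keep_last_n <= 0:  # -1 (keep all) or 0 (nothing intermediate to prune)
        return
    ckpts = [
        f
        for f in os.listdir(checkpoint_dir)
        if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
    ]
    ckpts.sort(key=lambda name: int(name.split("_")[1].split(".")[0]))
    for old in ckpts[:-keep_last_n]:
        os.remove(os.path.join(checkpoint_dir, old))
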
1 change: 1 addition & 0 deletions src/f5_tts/configs/E2TTS_Base_train.yaml
@@ -41,4 +41,5 @@ ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per updates
last_per_updates: 5000 # save last checkpoint per updates
keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
1 change: 1 addition & 0 deletions src/f5_tts/configs/E2TTS_Small_train.yaml
@@ -41,4 +41,5 @@ ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per updates
last_per_updates: 5000 # save last checkpoint per updates
keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
1 change: 1 addition & 0 deletions src/f5_tts/configs/F5TTS_Base_train.yaml
@@ -44,4 +44,5 @@ ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per updates
last_per_updates: 5000 # save last checkpoint per updates
keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
1 change: 1 addition & 0 deletions src/f5_tts/configs/F5TTS_Small_train.yaml
@@ -44,4 +44,5 @@ ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per updates
last_per_updates: 5000 # save last checkpoint per updates
keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
31 changes: 31 additions & 0 deletions src/f5_tts/model/trainer.py
@@ -50,7 +50,17 @@ def __init__(
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
is_local_vocoder: bool = False, # use local path vocoder
local_vocoder_path: str = "", # local vocoder path
keep_last_n_checkpoints: int
| None = -1, # -1 (default) to keep all, 0 to not save intermediate ckpts, positive N to keep last N checkpoints
):
# Validate keep_last_n_checkpoints
if not isinstance(keep_last_n_checkpoints, int):
raise ValueError("keep_last_n_checkpoints must be an integer")
if keep_last_n_checkpoints < -1:
raise ValueError(
"keep_last_n_checkpoints must be -1 (keep all), 0 (no intermediate checkpoints), or positive integer"
)

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

if logger == "wandb" and not wandb.api.api_key:
@@ -134,6 +144,8 @@ def __init__(
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

self.keep_last_n_checkpoints = keep_last_n_checkpoints if keep_last_n_checkpoints is not None else None

@property
def is_main(self):
return self.accelerator.is_main_process
@@ -154,7 +166,26 @@ def save_checkpoint(self, update, last=False):
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
print(f"Saved last checkpoint at update {update}")
else:
# Skip saving intermediate checkpoints if keep_last_n_checkpoints is 0
if self.keep_last_n_checkpoints == 0:
return

self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
# Implement rolling checkpoint system - only if keep_last_n_checkpoints is positive
if self.keep_last_n_checkpoints > 0:
# Get all checkpoint files except model_last.pt
checkpoints = [
f
for f in os.listdir(self.checkpoint_path)
if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
]
# Sort by step number
checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
# Remove old checkpoints if we have more than keep_last_n_checkpoints
while len(checkpoints) > self.keep_last_n_checkpoints:
oldest_checkpoint = checkpoints.pop(0)
os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint))
print(f"Removed old checkpoint: {oldest_checkpoint}")

def load_checkpoint(self):
if (
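
A quick sanity check of the numeric sort key used in save_checkpoint above (hypothetical file names; update counts sort numerically rather than lexicographically):

names = ["model_9000.pt", "model_50000.pt", "model_5000.pt", "model_10000.pt"]
names.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
print(names)  # ['model_5000.pt', 'model_9000.pt', 'model_10000.pt', 'model_50000.pt']
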
7 changes: 7 additions & 0 deletions src/f5_tts/train/finetune_cli.py
@@ -69,6 +69,12 @@ def parse_args():
action="store_true",
help="Use 8-bit Adam optimizer from bitsandbytes",
)
parser.add_argument(
"--keep_last_n_checkpoints",
type=int,
default=-1,
help="-1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints",
)

return parser.parse_args()

@@ -158,6 +164,7 @@ def main():
log_samples=args.log_samples,
last_per_updates=args.last_per_updates,
bnb_optimizer=args.bnb_optimizer,
keep_last_n_checkpoints=args.keep_last_n_checkpoints,
)

train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
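
For reference, a hypothetical use of the new flag (the script path is the one in the diff header above; the dataset name and the elided arguments are placeholders):

# python src/f5_tts/train/finetune_cli.py --dataset_name my_dataset --keep_last_n_checkpoints 3 ...
args = parse_args()
print(args.keep_last_n_checkpoints)  # 3 when passed as above, -1 (keep all) by default
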
22 changes: 21 additions & 1 deletion src/f5_tts/train/finetune_gradio.py
@@ -70,6 +70,7 @@ def save_settings(
mixed_precision,
logger,
ch_8bit_adam,
keep_last_n_checkpoints,
):
path_project = os.path.join(path_project_ckpts, project_name)
os.makedirs(path_project, exist_ok=True)
@@ -94,6 +95,7 @@ def save_settings(
"mixed_precision": mixed_precision,
"logger": logger,
"bnb_optimizer": ch_8bit_adam,
"keep_last_n_checkpoints": keep_last_n_checkpoints,
}
with open(file_setting, "w") as f:
json.dump(settings, f, indent=4)
@@ -126,6 +128,7 @@ def load_settings(project_name):
"mixed_precision": "none",
"logger": "wandb",
"bnb_optimizer": False,
"keep_last_n_checkpoints": -1, # Default to keep all checkpoints
}
return (
settings["exp_name"],
@@ -146,6 +149,7 @@ def load_settings(project_name):
settings["mixed_precision"],
settings["logger"],
settings["bnb_optimizer"],
settings["keep_last_n_checkpoints"],
)

with open(file_setting, "r") as f:
@@ -154,6 +158,8 @@ def load_settings(project_name):
settings["logger"] = "wandb"
if "bnb_optimizer" not in settings:
settings["bnb_optimizer"] = False
if "keep_last_n_checkpoints" not in settings:
settings["keep_last_n_checkpoints"] = -1 # Default to keep all checkpoints
if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e
settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
return (
Expand All @@ -175,6 +181,7 @@ def load_settings(project_name):
settings["mixed_precision"],
settings["logger"],
settings["bnb_optimizer"],
settings["keep_last_n_checkpoints"],
)


@@ -390,6 +397,7 @@ def start_training(
stream=False,
logger="wandb",
ch_8bit_adam=False,
keep_last_n_checkpoints=-1,
):
global training_process, tts_api, stop_signal

@@ -451,7 +459,8 @@
f"--num_warmup_updates {num_warmup_updates} "
f"--save_per_updates {save_per_updates} "
f"--last_per_updates {last_per_updates} "
f"--dataset_name {dataset_name}"
f"--dataset_name {dataset_name} "
f"--keep_last_n_checkpoints {keep_last_n_checkpoints}"
)

if finetune:
@@ -492,6 +501,7 @@ def start_training(
mixed_precision,
logger,
ch_8bit_adam,
keep_last_n_checkpoints,
)

try:
@@ -1564,6 +1574,13 @@ def get_audio_select(file_sample):
with gr.Row():
save_per_updates = gr.Number(label="Save per Updates", value=300)
last_per_updates = gr.Number(label="Last per Updates", value=100)
keep_last_n_checkpoints = gr.Number(
label="Keep Last N Checkpoints",
value=-1,
step=1,
precision=0,
info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints",
)

with gr.Row():
ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
@@ -1592,6 +1609,7 @@ def get_audio_select(file_sample):
mixed_precisionv,
cd_loggerv,
ch_8bit_adamv,
keep_last_n_checkpointsv,
) = load_settings(projects_selelect)
exp_name.value = exp_namev
learning_rate.value = learning_ratev
@@ -1611,6 +1629,7 @@ def get_audio_select(file_sample):
mixed_precision.value = mixed_precisionv
cd_logger.value = cd_loggerv
ch_8bit_adam.value = ch_8bit_adamv
keep_last_n_checkpoints.value = keep_last_n_checkpointsv

ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
txt_info_train = gr.Text(label="Info", value="")
@@ -1670,6 +1689,7 @@ def get_audio_select(file_sample):
ch_stream,
cd_logger,
ch_8bit_adam,
keep_last_n_checkpoints,
],
outputs=[txt_info_train, start_button, stop_button],
)
1 change: 1 addition & 0 deletions src/f5_tts/train/train.py
@@ -61,6 +61,7 @@ def main(cfg):
mel_spec_type=mel_spec_type,
is_local_vocoder=cfg.model.vocoder.is_local,
local_vocoder_path=cfg.model.vocoder.local_path,
keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", None),
)

train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
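
Note that getattr(cfg.ckpts, "keep_last_n_checkpoints", None) passes None for configs written before this commit, while the validation added in Trainer.__init__ above only accepts integers. A caller-side fallback one might use instead (a sketch, not part of this commit):

keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", -1),  # missing key -> -1 (keep all)
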

1 comment on commit 76b1b03

@ILG2021
This PR will cause a Gradio error in load_settings: it returns one value, but the outputs expect many components.
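
One way to check this report (a hypothetical snippet, using a placeholder project name) is to compare how many values load_settings actually returns with how many the UI unpacks from it in the diff above:

values = load_settings("my_project")  # "my_project" is a placeholder project name
print(len(values))  # should match the number of variables unpacked from load_settings above
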
