diff --git a/pyproject.toml b/pyproject.toml
index de04c3add..79120696e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.3.4"
+version = "0.4.0"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
diff --git a/src/f5_tts/configs/E2TTS_Base_train.yaml b/src/f5_tts/configs/E2TTS_Base_train.yaml
index c42514ef9..5e239df24 100644
--- a/src/f5_tts/configs/E2TTS_Base_train.yaml
+++ b/src/f5_tts/configs/E2TTS_Base_train.yaml
@@ -12,7 +12,7 @@ datasets:
 optim:
   epochs: 15
   learning_rate: 7.5e-5
-  num_warmup_updates: 20000  # warmup steps
+  num_warmup_updates: 20000  # warmup updates
   grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0  # gradient clipping
   bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not
@@ -39,7 +39,7 @@ model:
 
 ckpts:
   logger: wandb  # wandb | tensorboard | None
-  save_per_updates: 50000  # save checkpoint per steps
-  last_per_steps: 5000  # save last checkpoint per steps
+  save_per_updates: 50000  # save checkpoint per updates
+  last_per_updates: 5000  # save last checkpoint per updates
   keep_last_n_checkpoints: 0  # number of recent checkpoints to keep (excluding model_last.pt). Set to 0 to disable (keep all checkpoints). Positive numbers limit the number of checkpoints kept
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
\ No newline at end of file
diff --git a/src/f5_tts/configs/E2TTS_Small_train.yaml b/src/f5_tts/configs/E2TTS_Small_train.yaml
index f99c1ae48..5e6b35673 100644
--- a/src/f5_tts/configs/E2TTS_Small_train.yaml
+++ b/src/f5_tts/configs/E2TTS_Small_train.yaml
@@ -12,7 +12,7 @@ datasets:
 optim:
   epochs: 15
   learning_rate: 7.5e-5
-  num_warmup_updates: 20000  # warmup steps
+  num_warmup_updates: 20000  # warmup updates
   grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0
   bnb_optimizer: False
@@ -39,7 +39,7 @@ model:
 
 ckpts:
   logger: wandb  # wandb | tensorboard | None
-  save_per_updates: 50000  # save checkpoint per steps
-  last_per_steps: 5000  # save last checkpoint per steps
+  save_per_updates: 50000  # save checkpoint per updates
+  last_per_updates: 5000  # save last checkpoint per updates
   keep_last_n_checkpoints: 0  # number of recent checkpoints to keep (excluding model_last.pt). Set to 0 to disable (keep all checkpoints). Positive numbers limit the number of checkpoints kept
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
\ No newline at end of file
diff --git a/src/f5_tts/configs/F5TTS_Base_train.yaml b/src/f5_tts/configs/F5TTS_Base_train.yaml
index 8ac3ffa45..88c677257 100644
--- a/src/f5_tts/configs/F5TTS_Base_train.yaml
+++ b/src/f5_tts/configs/F5TTS_Base_train.yaml
@@ -12,7 +12,7 @@ datasets:
 optim:
   epochs: 15
   learning_rate: 7.5e-5
-  num_warmup_updates: 20000  # warmup steps
+  num_warmup_updates: 20000  # warmup updates
   grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0  # gradient clipping
   bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not
@@ -42,7 +42,7 @@ model:
 
 ckpts:
   logger: wandb  # wandb | tensorboard | None
-  save_per_updates: 50000  # save checkpoint per steps
-  last_per_steps: 5000  # save last checkpoint per steps
+  save_per_updates: 50000  # save checkpoint per updates
+  last_per_updates: 5000  # save last checkpoint per updates
   keep_last_n_checkpoints: 0  # number of recent checkpoints to keep (excluding model_last.pt). Set to 0 to disable (keep all checkpoints). Positive numbers limit the number of checkpoints kept
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
\ No newline at end of file
diff --git a/src/f5_tts/configs/F5TTS_Small_train.yaml b/src/f5_tts/configs/F5TTS_Small_train.yaml
index dd8370885..b1fa3c68d 100644
--- a/src/f5_tts/configs/F5TTS_Small_train.yaml
+++ b/src/f5_tts/configs/F5TTS_Small_train.yaml
@@ -12,7 +12,7 @@ datasets:
 optim:
   epochs: 15
   learning_rate: 7.5e-5
-  num_warmup_updates: 20000  # warmup steps
+  num_warmup_updates: 20000  # warmup updates
   grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0  # gradient clipping
   bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not
@@ -42,7 +42,7 @@ model:
 
 ckpts:
   logger: wandb  # wandb | tensorboard | None
-  save_per_updates: 50000  # save checkpoint per steps
-  last_per_steps: 5000  # save last checkpoint per steps
+  save_per_updates: 50000  # save checkpoint per updates
+  last_per_updates: 5000  # save last checkpoint per updates
   keep_last_n_checkpoints: 0  # number of recent checkpoints to keep (excluding model_last.pt). Set to 0 to disable (keep all checkpoints). Positive numbers limit the number of checkpoints kept
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
\ No newline at end of file
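A note on the semantics behind these config renames: with gradient accumulation, one optimizer update consumes grad_accumulation_steps dataloader steps, so save_per_updates and last_per_updates now share the same unit. A minimal arithmetic sketch of that relation (all values below are illustrative, not config defaults):

grad_accumulation_steps = 2
dataloader_steps = 100_000  # batches consumed by the training loop
optimizer_updates = dataloader_steps // grad_accumulation_steps  # 50_000 updates

save_per_updates = 50_000  # periodic checkpoint every 50k optimizer updates
last_per_updates = 5_000   # model_last.pt refreshed every 5k optimizer updates
assert optimizer_updates // save_per_updates == 1  # first periodic checkpoint just reached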
diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index 162cb38bc..4784a2b7b 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import gc
+import math
 import os
 
 import torch
@@ -42,7 +43,7 @@ def __init__(
         wandb_run_name="test_run",
         wandb_resume_id: str = None,
         log_samples: bool = False,
-        last_per_steps=None,
+        last_per_updates=None,
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         bnb_optimizer: bool = False,
@@ -62,6 +63,11 @@ def __init__(
             print(f"Using logger: {logger}")
         self.log_samples = log_samples
 
+        if grad_accumulation_steps > 1 and os.environ.get("LOCAL_RANK", "0") == "0":
+            print(
+                "Note: the trainer now saves checkpoints per optimizer update (per_updates); the old per_steps logic applies only to runs from before f992c4e"
+            )
+
         self.accelerator = Accelerator(
             log_with=logger if logger == "wandb" else None,
             kwargs_handlers=[ddp_kwargs],
@@ -107,7 +113,7 @@ def __init__(
         self.epochs = epochs
         self.num_warmup_updates = num_warmup_updates
         self.save_per_updates = save_per_updates
-        self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
+        self.last_per_updates = default(last_per_updates, save_per_updates)
         self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
 
         self.batch_size = batch_size
@@ -139,7 +145,7 @@ def __init__(
     def is_main(self):
        return self.accelerator.is_main_process
 
-    def save_checkpoint(self, step, last=False):
+    def save_checkpoint(self, update, last=False):
         self.accelerator.wait_for_everyone()
         if self.is_main:
             checkpoint = dict(
@@ -147,15 +153,15 @@ def __init__(
                 optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
                 ema_model_state_dict=self.ema_model.state_dict(),
                 scheduler_state_dict=self.scheduler.state_dict(),
-                step=step,
+                update=update,
             )
             if not os.path.exists(self.checkpoint_path):
                 os.makedirs(self.checkpoint_path)
             if last:
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
-                print(f"Saved last checkpoint at step {step}")
+                print(f"Saved last checkpoint at update {update}")
             else:
-                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
+                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
                 # Implement rolling checkpoint system - only if keep_last_n_checkpoints is positive
                 if self.keep_last_n_checkpoints is not None and self.keep_last_n_checkpoints > 0:
                     # Get all checkpoint files except model_last.pt
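The rolling-checkpoint branch above keeps only the most recent model_{update}.pt files. A self-contained sketch of that pruning idea (the prune_checkpoints helper is hypothetical; only the file-naming scheme and the keep_last_n_checkpoints semantics come from this patch):

import os

def prune_checkpoints(checkpoint_path: str, keep_last_n_checkpoints: int) -> None:
    # Hypothetical helper: keep the N newest model_<update>.pt files,
    # never touching model_last.pt; 0 disables pruning (keep all).
    if keep_last_n_checkpoints <= 0:
        return
    ckpts = [
        f
        for f in os.listdir(checkpoint_path)
        if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
    ]
    ckpts.sort(key=lambda f: int(f[len("model_") : -len(".pt")]))  # ascending update number
    for stale in ckpts[: -keep_last_n_checkpoints]:
        os.remove(os.path.join(checkpoint_path, stale))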
@@ -197,7 +203,14 @@ def load_checkpoint(self):
         if self.is_main:
             self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
 
-        if "step" in checkpoint:
+        if "update" in checkpoint or "step" in checkpoint:
+            # patch for backward compatibility with checkpoints saved before f992c4e
+            if "step" in checkpoint:
+                checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps
+                if self.grad_accumulation_steps > 1 and self.is_main:
+                    print(
+                        "F5-TTS WARNING: Loading a checkpoint saved with the old per_steps logic (before f992c4e); converting to per_updates using the current grad_accumulation_steps setting. This may cause unexpected behaviour."
+                    )
             # patch for backward compatibility, 305e3ea
             for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
                 if key in checkpoint["model_state_dict"]:
@@ -207,19 +220,19 @@ def load_checkpoint(self):
             self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
             if self.scheduler:
                 self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
-            step = checkpoint["step"]
+            update = checkpoint["update"]
         else:
             checkpoint["model_state_dict"] = {
                 k.replace("ema_model.", ""): v
                 for k, v in checkpoint["ema_model_state_dict"].items()
-                if k not in ["initted", "step"]
+                if k not in ["initted", "update", "step"]
             }
             self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
-            step = 0
+            update = 0
 
         del checkpoint
         gc.collect()
-        return step
+        return update
 
     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
         if self.log_samples:
@@ -268,25 +281,26 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
 
         # accelerator.prepare() dispatches batches to devices;
         # which means the length of dataloader calculated before, should consider the number of devices
-        warmup_steps = (
+        warmup_updates = (
             self.num_warmup_updates * self.accelerator.num_processes
         )  # consider a fixed warmup steps while using accelerate multi-gpu ddp
         # otherwise by default with split_batches=False, warmup steps change with num_processes
-        total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
-        decay_steps = total_steps - warmup_steps
-        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
-        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
+        total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs
+        decay_updates = total_updates - warmup_updates
+        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
+        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
         self.scheduler = SequentialLR(
-            self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
+            self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_updates]
         )
         train_dataloader, self.scheduler = self.accelerator.prepare(
             train_dataloader, self.scheduler
-        )  # actual steps = 1 gpu steps / gpus
-        start_step = self.load_checkpoint()
-        global_step = start_step
+        )  # actual multi-GPU updates = single-GPU updates / number of GPUs
+        start_update = self.load_checkpoint()
+        global_update = start_update
 
         if exists(resumable_with_seed):
             orig_epoch_step = len(train_dataloader)
+            start_step = start_update * self.grad_accumulation_steps
             skipped_epoch = int(start_step // orig_epoch_step)
             skipped_batch = start_step % orig_epoch_step
             skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
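The scheduler arithmetic in the hunk above is easier to see in isolation. A minimal runnable sketch of the same warmup-then-decay schedule, with invented sizes (note how math.ceil counts a trailing partial accumulation window as a full update):

import math

import torch
from torch.optim.lr_scheduler import LinearLR, SequentialLR

model = torch.nn.Linear(4, 4)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=7.5e-5)

epochs, batches_per_epoch, grad_accumulation_steps = 15, 1001, 2
warmup_updates = 200  # illustrative, not the config default

total_updates = math.ceil(batches_per_epoch / grad_accumulation_steps) * epochs  # 501 * 15
decay_updates = total_updates - warmup_updates

warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_updates])

for _ in range(total_updates):  # one scheduler.step() per optimizer update
    optimizer.step()
    scheduler.step()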
@@ -296,23 +310,21 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
         for epoch in range(skipped_epoch, self.epochs):
             self.model.train()
             if exists(resumable_with_seed) and epoch == skipped_epoch:
-                progress_bar = tqdm(
-                    skipped_dataloader,
-                    desc=f"Epoch {epoch+1}/{self.epochs}",
-                    unit="step",
-                    disable=not self.accelerator.is_local_main_process,
-                    initial=skipped_batch,
-                    total=orig_epoch_step,
-                )
+                progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps)
+                current_dataloader = skipped_dataloader
             else:
-                progress_bar = tqdm(
-                    train_dataloader,
-                    desc=f"Epoch {epoch+1}/{self.epochs}",
-                    unit="step",
-                    disable=not self.accelerator.is_local_main_process,
-                )
-
-            for batch in progress_bar:
+                progress_bar_initial = 0
+                current_dataloader = train_dataloader
+
+            progress_bar = tqdm(
+                range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
+                desc=f"Epoch {epoch+1}/{self.epochs}",
+                unit="update",
+                disable=not self.accelerator.is_local_main_process,
+                initial=progress_bar_initial,
+            )
+
+            for batch in current_dataloader:
                 with self.accelerator.accumulate(self.model):
                     text_inputs = batch["text"]
                     mel_spec = batch["mel"].permute(0, 2, 1)
@@ -321,7 +333,7 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
                     # TODO. add duration predictor training
                     if self.duration_predictor is not None and self.accelerator.is_local_main_process:
                         dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
-                        self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
+                        self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update)
 
                     loss, cond, pred = self.model(
                         mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
@@ -338,18 +350,20 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
                 if self.is_main and self.accelerator.sync_gradients:
                     self.ema_model.update()
 
-                global_step += 1
+                global_update += 1
+                progress_bar.update(1)
+                progress_bar.set_postfix(update=str(global_update), loss=loss.item())
 
                 if self.accelerator.is_local_main_process:
-                    self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
+                    self.accelerator.log(
+                        {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update
+                    )
                     if self.logger == "tensorboard":
-                        self.writer.add_scalar("loss", loss.item(), global_step)
-                        self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
-
-                progress_bar.set_postfix(step=str(global_step), loss=loss.item())
+                        self.writer.add_scalar("loss", loss.item(), global_update)
+                        self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
 
-                if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
-                    self.save_checkpoint(global_step)
+                if global_update % self.save_per_updates == 0:
+                    self.save_checkpoint(global_update)
 
                 if self.log_samples and self.accelerator.is_local_main_process:
                     ref_audio_len = mel_lengths[0]
@@ -375,12 +389,16 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
                         gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
                         ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
 
-                    torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
-                    torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
+                    torchaudio.save(
+                        f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate
+                    )
+                    torchaudio.save(
+                        f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
+                    )
 
-                if global_step % self.last_per_steps == 0:
-                    self.save_checkpoint(global_step, last=True)
+                if global_update % self.last_per_updates == 0:
+                    self.save_checkpoint(global_update, last=True)
 
-        self.save_checkpoint(global_step, last=True)
+        self.save_checkpoint(global_update, last=True)
 
         self.accelerator.end_training()
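Since the loop above now counts optimizer updates rather than dataloader steps, it may help to see how accelerate exposes the accumulation boundary. A minimal sketch, separate from the trainer itself (the model, data, and sizes are invented):

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=2)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
data = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(6)]
model, optimizer = accelerator.prepare(model, optimizer)

global_update = 0
for x, y in data:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()  # the prepared optimizer skips non-sync steps
        optimizer.zero_grad()
        if accelerator.sync_gradients:  # True only on the batch completing a window
            global_update += 1  # 6 batches / 2 accumulation steps -> 3 updates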
diff --git a/src/f5_tts/scripts/count_max_epoch.py b/src/f5_tts/scripts/count_max_epoch.py
index 7cd7332df..18d36df33 100644
--- a/src/f5_tts/scripts/count_max_epoch.py
+++ b/src/f5_tts/scripts/count_max_epoch.py
@@ -20,13 +20,13 @@
 mini_batch_frames = frames_per_gpu * grad_accum * gpus
 mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
 updates_per_epoch = total_hours / mini_batch_hours
-steps_per_epoch = updates_per_epoch * grad_accum
+# steps_per_epoch = updates_per_epoch * grad_accum
 
 # result
 epochs = wanted_max_updates / updates_per_epoch
 print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
 print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
-print(f"                          or approx. 0/{steps_per_epoch:.0f} steps")
+# print(f"                          or approx. 0/{steps_per_epoch:.0f} steps")
 
 # others
 print(f"total {total_hours:.0f} hours")
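A worked instance of this script's arithmetic, to make the update/step distinction concrete (every number here is illustrative, not a project default):

total_hours = 95_000  # dataset size, illustrative
mel_hop_length, mel_sampling_rate = 256, 24_000
frames_per_gpu, grad_accum, gpus = 38_400, 1, 8
wanted_max_updates = 1_000_000

mini_batch_frames = frames_per_gpu * grad_accum * gpus  # 307_200 frames per update
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600  # ~0.91 h
updates_per_epoch = total_hours / mini_batch_hours  # ~104_000
epochs = wanted_max_updates / updates_per_epoch  # ~9.6, so set epochs to about 10
print(f"epochs should be set to: {epochs:.0f}")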
"epochs": epochs, "num_warmup_updates": num_warmup_updates, "save_per_updates": save_per_updates, - "last_per_steps": last_per_steps, + "last_per_updates": last_per_updates, "finetune": finetune, "file_checkpoint_train": file_checkpoint_train, "tokenizer_type": tokenizer_type, @@ -120,7 +120,7 @@ def load_settings(project_name): "epochs": 100, "num_warmup_updates": 2, "save_per_updates": 300, - "last_per_steps": 100, + "last_per_updates": 100, "finetune": True, "file_checkpoint_train": "", "tokenizer_type": "pinyin", @@ -141,7 +141,7 @@ def load_settings(project_name): settings["epochs"], settings["num_warmup_updates"], settings["save_per_updates"], - settings["last_per_steps"], + settings["last_per_updates"], settings["finetune"], settings["file_checkpoint_train"], settings["tokenizer_type"], @@ -160,6 +160,8 @@ def load_settings(project_name): settings["bnb_optimizer"] = False if "keep_last_n_checkpoints" not in settings: settings["keep_last_n_checkpoints"] = 0 + if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e + settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"] return ( settings["exp_name"], settings["learning_rate"], @@ -171,7 +173,7 @@ def load_settings(project_name): settings["epochs"], settings["num_warmup_updates"], settings["save_per_updates"], - settings["last_per_steps"], + settings["last_per_updates"], settings["finetune"], settings["file_checkpoint_train"], settings["tokenizer_type"], @@ -386,7 +388,7 @@ def start_training( epochs=11, num_warmup_updates=200, save_per_updates=400, - last_per_steps=800, + last_per_updates=800, finetune=True, file_checkpoint_train="", tokenizer_type="pinyin", @@ -456,7 +458,7 @@ def start_training( f"--epochs {epochs} " f"--num_warmup_updates {num_warmup_updates} " f"--save_per_updates {save_per_updates} " - f"--last_per_steps {last_per_steps} " + f"--last_per_updates {last_per_updates} " f"--dataset_name {dataset_name} " f"--keep_last_n_checkpoints {keep_last_n_checkpoints}" ) @@ -491,7 +493,7 @@ def start_training( epochs, num_warmup_updates, save_per_updates, - last_per_steps, + last_per_updates, finetune, file_checkpoint_train, tokenizer_type, @@ -890,7 +892,7 @@ def calculate_train( learning_rate, num_warmup_updates, save_per_updates, - last_per_steps, + last_per_updates, finetune, ): path_project = os.path.join(path_data, name_project) @@ -902,7 +904,7 @@ def calculate_train( max_samples, num_warmup_updates, save_per_updates, - last_per_steps, + last_per_updates, "project not found !", learning_rate, ) @@ -950,14 +952,14 @@ def calculate_train( num_warmup_updates = int(samples * 0.05) save_per_updates = int(samples * 0.10) - last_per_steps = int(save_per_updates * 0.25) + last_per_updates = int(save_per_updates * 0.25) max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples) num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates) save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates) - last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps) - if last_per_steps <= 0: - last_per_steps = 2 + last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates) + if last_per_updates <= 0: + last_per_updates = 2 total_hours = hours mel_hop_length = 256 @@ -988,7 +990,7 @@ def calculate_train( max_samples, num_warmup_updates, save_per_updates, - last_per_steps, + last_per_updates, samples, learning_rate, int(epochs), @@ -1540,7 +1542,7 @@ 
@@ -171,7 +173,7 @@ def load_settings(project_name):
         settings["epochs"],
         settings["num_warmup_updates"],
         settings["save_per_updates"],
-        settings["last_per_steps"],
+        settings["last_per_updates"],
         settings["finetune"],
         settings["file_checkpoint_train"],
         settings["tokenizer_type"],
@@ -386,7 +388,7 @@ def start_training(
     epochs=11,
     num_warmup_updates=200,
     save_per_updates=400,
-    last_per_steps=800,
+    last_per_updates=800,
     finetune=True,
     file_checkpoint_train="",
     tokenizer_type="pinyin",
@@ -456,7 +458,7 @@ def start_training(
         f"--epochs {epochs} "
         f"--num_warmup_updates {num_warmup_updates} "
         f"--save_per_updates {save_per_updates} "
-        f"--last_per_steps {last_per_steps} "
+        f"--last_per_updates {last_per_updates} "
         f"--dataset_name {dataset_name} "
         f"--keep_last_n_checkpoints {keep_last_n_checkpoints}"
     )
@@ -491,7 +493,7 @@ def start_training(
         epochs,
         num_warmup_updates,
         save_per_updates,
-        last_per_steps,
+        last_per_updates,
         finetune,
         file_checkpoint_train,
         tokenizer_type,
@@ -890,7 +892,7 @@ def calculate_train(
     learning_rate,
     num_warmup_updates,
     save_per_updates,
-    last_per_steps,
+    last_per_updates,
     finetune,
 ):
     path_project = os.path.join(path_data, name_project)
@@ -902,7 +904,7 @@ def calculate_train(
             max_samples,
             num_warmup_updates,
             save_per_updates,
-            last_per_steps,
+            last_per_updates,
             "project not found !",
             learning_rate,
         )
@@ -950,14 +952,14 @@ def calculate_train(
 
     num_warmup_updates = int(samples * 0.05)
     save_per_updates = int(samples * 0.10)
-    last_per_steps = int(save_per_updates * 0.25)
+    last_per_updates = int(save_per_updates * 0.25)
 
     max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
     num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
     save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
-    last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
-    if last_per_steps <= 0:
-        last_per_steps = 2
+    last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates)
+    if last_per_updates <= 0:
+        last_per_updates = 2
 
     total_hours = hours
     mel_hop_length = 256
@@ -988,7 +990,7 @@ def calculate_train(
         max_samples,
         num_warmup_updates,
         save_per_updates,
-        last_per_steps,
+        last_per_updates,
         samples,
         learning_rate,
         int(epochs),
@@ -1540,7 +1542,7 @@ def get_audio_select(file_sample):
     with gr.TabItem("Train Data"):
         gr.Markdown("""```plaintext
-The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per steps are set correctly, or change them manually as needed.
+The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per updates are set correctly, or change them manually as needed.
 If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
 ```""")
 
         with gr.Row():
@@ -1571,7 +1573,7 @@ def get_audio_select(file_sample):
 
         with gr.Row():
             save_per_updates = gr.Number(label="Save per Updates", value=300)
-            last_per_steps = gr.Number(label="Last per Steps", value=100)
+            last_per_updates = gr.Number(label="Last per Updates", value=100)
             keep_last_n_checkpoints = gr.Number(
                 label="Keep Last N Checkpoints",
                 value=0,
@@ -1581,6 +1583,7 @@ def get_audio_select(file_sample):
                 info="Set to 0 to disable (keep all checkpoints). Positive numbers limit the number of checkpoints kept."
             )
 
+        with gr.Row():
             ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
             mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
 
@@ -1600,7 +1603,7 @@ def get_audio_select(file_sample):
             epochsv,
             num_warmupv_updatesv,
             save_per_updatesv,
-            last_per_stepsv,
+            last_per_updatesv,
             finetunev,
             file_checkpoint_trainv,
             tokenizer_typev,
@@ -1620,7 +1623,7 @@ def get_audio_select(file_sample):
         epochs.value = epochsv
         num_warmup_updates.value = num_warmupv_updatesv
         save_per_updates.value = save_per_updatesv
-        last_per_steps.value = last_per_stepsv
+        last_per_updates.value = last_per_updatesv
         ch_finetune.value = finetunev
         file_checkpoint_train.value = file_checkpoint_trainv
         tokenizer_type.value = tokenizer_typev
@@ -1679,7 +1682,7 @@ def get_audio_select(file_sample):
             epochs,
             num_warmup_updates,
             save_per_updates,
-            last_per_steps,
+            last_per_updates,
             ch_finetune,
             file_checkpoint_train,
             tokenizer_type,
@@ -1703,7 +1706,7 @@ def get_audio_select(file_sample):
             learning_rate,
             num_warmup_updates,
             save_per_updates,
-            last_per_steps,
+            last_per_updates,
             ch_finetune,
         ],
         outputs=[
@@ -1711,7 +1714,7 @@ def get_audio_select(file_sample):
             max_samples,
             num_warmup_updates,
             save_per_updates,
-            last_per_steps,
+            last_per_updates,
             lb_samples,
             learning_rate,
             epochs,
@@ -1734,7 +1737,7 @@ def setup_load_settings():
         epochs,
         num_warmup_updates,
         save_per_updates,
-        last_per_steps,
+        last_per_updates,
         ch_finetune,
         file_checkpoint_train,
         tokenizer_type,
diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py
index 9605b885b..eaaaa2452 100644
--- a/src/f5_tts/train/train.py
+++ b/src/f5_tts/train/train.py
@@ -55,7 +55,7 @@ def main(cfg):
         wandb_project="CFM-TTS",
         wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
-        last_per_steps=cfg.ckpts.last_per_steps,
+        last_per_updates=cfg.ckpts.last_per_updates,
         log_samples=True,
         bnb_optimizer=cfg.optim.bnb_optimizer,
         mel_spec_type=mel_spec_type,
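Finally, on the train.py change: the renamed key is read straight off the OmegaConf config, matching the ckpts section of the YAML files at the top of this patch. A minimal sketch (the nested keys mirror those configs; the values are the Base defaults):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"ckpts": {"save_per_updates": 50000, "last_per_updates": 5000}})
assert cfg.ckpts.last_per_updates == 5000  # accessed as cfg.ckpts.last_per_updates in train.py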