Commit 83efc3f
Merge pull request #711 from SWivid/fix_grad_accum
0.4.0 fix gradient accumulation; change checkpointing logic to per_updates
SWivid authored Jan 14, 2025
2 parents f992c4e + 0b11f7e commit 83efc3f
Showing 10 changed files with 115 additions and 95 deletions.
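Summary of the change: warmup and checkpoint cadence are now counted in optimizer updates rather than dataloader steps; the two differ only when gradient accumulation is enabled (updates = steps / grad_accumulation_steps, as the config comments below note). A rough sketch of that relationship, with made-up numbers that do not come from this diff:

# Hypothetical numbers, for illustration only.
grad_accumulation_steps = 4        # micro-batches accumulated per optimizer update
steps_per_epoch = 1000             # batches drawn from the dataloader each epoch

# One optimizer update happens every grad_accumulation_steps dataloader steps,
# so: updates = steps / grad_accumulation_steps.
updates_per_epoch = steps_per_epoch // grad_accumulation_steps   # 250

# Under the new convention, cadence settings such as save_per_updates and
# last_per_updates are compared against the update counter directly, with no
# extra multiplication by grad_accumulation_steps.
print(f"{steps_per_epoch} steps/epoch -> {updates_per_epoch} updates/epoch")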
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "f5-tts"
version = "0.3.4"
version = "0.4.0"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
6 changes: 3 additions & 3 deletions src/f5_tts/configs/E2TTS_Base_train.yaml
@@ -12,7 +12,7 @@ datasets:
optim:
epochs: 15
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup steps
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
@@ -39,6 +39,6 @@ model:

ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per steps
last_per_steps: 5000 # save last checkpoint per steps
save_per_updates: 50000 # save checkpoint per updates
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
6 changes: 3 additions & 3 deletions src/f5_tts/configs/E2TTS_Small_train.yaml
@@ -12,7 +12,7 @@ datasets:
optim:
epochs: 15
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup steps
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0
bnb_optimizer: False
@@ -39,6 +39,6 @@ model:

ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per steps
last_per_steps: 5000 # save last checkpoint per steps
save_per_updates: 50000 # save checkpoint per updates
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
6 changes: 3 additions & 3 deletions src/f5_tts/configs/F5TTS_Base_train.yaml
@@ -12,7 +12,7 @@ datasets:
optim:
epochs: 15
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup steps
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
@@ -42,6 +42,6 @@ model:

ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per steps
last_per_steps: 5000 # save last checkpoint per steps
save_per_updates: 50000 # save checkpoint per updates
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
6 changes: 3 additions & 3 deletions src/f5_tts/configs/F5TTS_Small_train.yaml
@@ -12,7 +12,7 @@ datasets:
optim:
epochs: 15
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup steps
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
@@ -42,6 +42,6 @@ model:

ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per steps
last_per_steps: 5000 # save last checkpoint per steps
save_per_updates: 50000 # save checkpoint per updates
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
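Note on the four config diffs above: save_per_updates keeps its meaning, while last_per_steps is renamed to last_per_updates and is now counted in optimizer updates. Assuming the pre-f992c4e trainer compared last_per_steps against the per-batch step counter (as the old code suggests), an old config value migrates roughly as follows (placeholder values):

grad_accumulation_steps = 2        # placeholder
old_last_per_steps = 5000          # value from an old-style config
# The old logic counted dataloader steps, so an approximately equivalent
# cadence under the new per-update convention is:
new_last_per_updates = old_last_per_steps // grad_accumulation_steps   # 2500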
118 changes: 68 additions & 50 deletions src/f5_tts/model/trainer.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import gc
import math
import os

import torch
@@ -42,7 +43,7 @@ def __init__(
wandb_run_name="test_run",
wandb_resume_id: str = None,
log_samples: bool = False,
last_per_steps=None,
last_per_updates=None,
accelerate_kwargs: dict = dict(),
ema_kwargs: dict = dict(),
bnb_optimizer: bool = False,
@@ -57,6 +58,11 @@ def __init__(
print(f"Using logger: {logger}")
self.log_samples = log_samples

if grad_accumulation_steps > 1 and self.is_main:
print(
"Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
)

self.accelerator = Accelerator(
log_with=logger if logger == "wandb" else None,
kwargs_handlers=[ddp_kwargs],
@@ -102,7 +108,7 @@ def __init__(
self.epochs = epochs
self.num_warmup_updates = num_warmup_updates
self.save_per_updates = save_per_updates
self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
self.last_per_updates = default(last_per_updates, save_per_updates)
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")

self.batch_size = batch_size
@@ -132,23 +138,23 @@ def __init__(
def is_main(self):
return self.accelerator.is_main_process

def save_checkpoint(self, step, last=False):
def save_checkpoint(self, update, last=False):
self.accelerator.wait_for_everyone()
if self.is_main:
checkpoint = dict(
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
ema_model_state_dict=self.ema_model.state_dict(),
scheduler_state_dict=self.scheduler.state_dict(),
step=step,
update=update,
)
if not os.path.exists(self.checkpoint_path):
os.makedirs(self.checkpoint_path)
if last:
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
print(f"Saved last checkpoint at step {step}")
print(f"Saved last checkpoint at update {update}")
else:
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")

def load_checkpoint(self):
if (
@@ -177,7 +183,14 @@ def load_checkpoint(self):
if self.is_main:
self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])

if "step" in checkpoint:
if "update" in checkpoint or "step" in checkpoint:
# patch for backward compatibility, with before f992c4e
if "step" in checkpoint:
checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps
if self.grad_accumulation_steps > 1 and self.is_main:
print(
"F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour."
)
# patch for backward compatibility, 305e3ea
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
if key in checkpoint["model_state_dict"]:
@@ -187,19 +200,19 @@
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
if self.scheduler:
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
step = checkpoint["step"]
update = checkpoint["update"]
else:
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
if k not in ["initted", "update", "step"]
}
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
step = 0
update = 0

del checkpoint
gc.collect()
return step
return update

def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
if self.log_samples:
@@ -248,25 +261,26 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int

# accelerator.prepare() dispatches batches to devices;
# which means the length of dataloader calculated before, should consider the number of devices
warmup_steps = (
warmup_updates = (
self.num_warmup_updates * self.accelerator.num_processes
) # consider a fixed warmup steps while using accelerate multi-gpu ddp
# otherwise by default with split_batches=False, warmup steps change with num_processes
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
decay_steps = total_steps - warmup_steps
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs
decay_updates = total_updates - warmup_updates
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
self.scheduler = SequentialLR(
self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_updates]
)
train_dataloader, self.scheduler = self.accelerator.prepare(
train_dataloader, self.scheduler
) # actual steps = 1 gpu steps / gpus
start_step = self.load_checkpoint()
global_step = start_step
) # actual multi_gpu updates = single_gpu updates / gpu nums
start_update = self.load_checkpoint()
global_update = start_update

if exists(resumable_with_seed):
orig_epoch_step = len(train_dataloader)
start_step = start_update * self.grad_accumulation_steps
skipped_epoch = int(start_step // orig_epoch_step)
skipped_batch = start_step % orig_epoch_step
skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
@@ -276,23 +290,21 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
for epoch in range(skipped_epoch, self.epochs):
self.model.train()
if exists(resumable_with_seed) and epoch == skipped_epoch:
progress_bar = tqdm(
skipped_dataloader,
desc=f"Epoch {epoch+1}/{self.epochs}",
unit="step",
disable=not self.accelerator.is_local_main_process,
initial=skipped_batch,
total=orig_epoch_step,
)
progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps)
current_dataloader = skipped_dataloader
else:
progress_bar = tqdm(
train_dataloader,
desc=f"Epoch {epoch+1}/{self.epochs}",
unit="step",
disable=not self.accelerator.is_local_main_process,
)

for batch in progress_bar:
progress_bar_initial = 0
current_dataloader = train_dataloader

progress_bar = tqdm(
range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
desc=f"Epoch {epoch+1}/{self.epochs}",
unit="update",
disable=not self.accelerator.is_local_main_process,
initial=progress_bar_initial,
)

for batch in current_dataloader:
with self.accelerator.accumulate(self.model):
text_inputs = batch["text"]
mel_spec = batch["mel"].permute(0, 2, 1)
@@ -301,7 +313,7 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
# TODO. add duration predictor training
if self.duration_predictor is not None and self.accelerator.is_local_main_process:
dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update)

loss, cond, pred = self.model(
mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
@@ -318,18 +330,20 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
if self.is_main and self.accelerator.sync_gradients:
self.ema_model.update()

global_step += 1
global_update += 1
progress_bar.update(1)
progress_bar.set_postfix(update=str(global_update), loss=loss.item())

if self.accelerator.is_local_main_process:
self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
self.accelerator.log(
{"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update
)
if self.logger == "tensorboard":
self.writer.add_scalar("loss", loss.item(), global_step)
self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)

progress_bar.set_postfix(step=str(global_step), loss=loss.item())
self.writer.add_scalar("loss", loss.item(), global_update)
self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)

if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
self.save_checkpoint(global_step)
if global_update % self.save_per_updates == 0:
self.save_checkpoint(global_update)

if self.log_samples and self.accelerator.is_local_main_process:
ref_audio_len = mel_lengths[0]
@@ -355,12 +369,16 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()

torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
torchaudio.save(
f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate
)
torchaudio.save(
f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
)

if global_step % self.last_per_steps == 0:
self.save_checkpoint(global_step, last=True)
if global_update % self.last_per_updates == 0:
self.save_checkpoint(global_update, last=True)

self.save_checkpoint(global_step, last=True)
self.save_checkpoint(global_update, last=True)

self.accelerator.end_training()
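The checkpoint dict now stores an "update" counter instead of "step", and load_checkpoint converts older checkpoints on the fly. The same idea condensed into a standalone helper (a sketch, not the trainer's exact code; the path argument is a placeholder):

import torch

def resumed_update(ckpt_path: str, grad_accumulation_steps: int) -> int:
    # Return the update count to resume from, handling pre-f992c4e checkpoints
    # that stored a per-batch "step" counter instead of "update".
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if "update" in ckpt:
        return ckpt["update"]
    if "step" in ckpt:
        return ckpt["step"] // grad_accumulation_steps
    return 0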
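The LR schedule is likewise sized in updates: total_updates = ceil(len(dataloader) / grad_accumulation_steps) * epochs, and the warmup/decay LinearLR pair is stitched together at warmup_updates. A self-contained sketch with assumed sizes (the real trainer builds this around its own optimizer and prepared dataloader):

import math
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

dataloader_len = 1000              # assumed batches per epoch, per process
grad_accumulation_steps = 4        # assumed
epochs = 15
num_warmup_updates = 200           # assumed (the configs above use 20000)
num_processes = 1                  # single-GPU sketch

optimizer = AdamW([torch.nn.Parameter(torch.zeros(1))], lr=7.5e-5)

warmup_updates = num_warmup_updates * num_processes
total_updates = math.ceil(dataloader_len / grad_accumulation_steps) * epochs
decay_updates = total_updates - warmup_updates

warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_updates])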
4 changes: 2 additions & 2 deletions src/f5_tts/scripts/count_max_epoch.py
@@ -20,13 +20,13 @@
mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours
steps_per_epoch = updates_per_epoch * grad_accum
# steps_per_epoch = updates_per_epoch * grad_accum

# result
epochs = wanted_max_updates / updates_per_epoch
print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
print(f" or approx. 0/{steps_per_epoch:.0f} steps")
# print(f" or approx. 0/{steps_per_epoch:.0f} steps")

# others
print(f"total {total_hours:.0f} hours")
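count_max_epoch.py now reports updates only. A worked example of the same arithmetic with assumed values (not the script's defaults):

total_hours = 950                  # assumed dataset size, in hours
mel_hop_length = 256
mel_sampling_rate = 24000
frames_per_gpu = 38400             # assumed batch size per GPU, in mel frames
grad_accum = 1
gpus = 8
wanted_max_updates = 1_000_000     # assumed training target

mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours
epochs = wanted_max_updates / updates_per_epoch
print(f"~{updates_per_epoch:.0f} updates per epoch -> set epochs to ~{epochs:.0f}")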
12 changes: 6 additions & 6 deletions src/f5_tts/train/finetune_cli.py
@@ -27,7 +27,7 @@ def parse_args():

# num_warmup_updates = 300 for 5000 sample about 10 hours

# change save_per_updates , last_per_steps change this value what you need ,
# change save_per_updates , last_per_updates change this value what you need ,

parser = argparse.ArgumentParser(description="Train CFM Model")

@@ -44,9 +44,9 @@ def parse_args():
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
parser.add_argument("--finetune", action="store_true", help="Use Finetune")
parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
parser.add_argument(
@@ -61,7 +61,7 @@ def parse_args():
parser.add_argument(
"--log_samples",
action="store_true",
help="Log inferenced samples per ckpt save steps",
help="Log inferenced samples per ckpt save updates",
)
parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
parser.add_argument(
@@ -156,7 +156,7 @@ def main():
wandb_run_name=args.exp_name,
wandb_resume_id=wandb_resume_id,
log_samples=args.log_samples,
last_per_steps=args.last_per_steps,
last_per_updates=args.last_per_updates,
bnb_optimizer=args.bnb_optimizer,
)

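For finetuning runs, the CLI flag is renamed accordingly: --last_per_steps becomes --last_per_updates, and the related help strings now say "updates". A hedged invocation sketch using only flags visible in this diff; any other required arguments are omitted:

import subprocess

subprocess.run([
    "python", "src/f5_tts/train/finetune_cli.py",
    "--epochs", "100",
    "--grad_accumulation_steps", "2",
    "--num_warmup_updates", "300",
    "--save_per_updates", "10000",
    "--last_per_updates", "5000",      # was --last_per_steps before 0.4.0
    "--log_samples",
])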
(2 more changed files are not shown here; their diffs were not loaded in this view.)