diff --git a/hparams.py b/hparams.py
index fe898c8d..886aeed9 100644
--- a/hparams.py
+++ b/hparams.py
@@ -53,6 +53,7 @@
 voc_pad = 2                         # this will pad the input so that the resnet can 'see' wider than input length
 voc_seq_len = hop_length * 5        # must be a multiple of hop_length
 voc_clip_grad_norm = 4              # set to None if no gradient clipping needed
+voc_use_mixed_precision = True      # Enable mixed precision
 
 # Generating / Synthesizing
 voc_gen_batched = True              # very fast (realtime+) single utterance batched generation
@@ -91,7 +92,7 @@
 tts_clip_grad_norm = 1.0            # clips the gradient norm to prevent explosion - set to None if not needed
 tts_checkpoint_every = 2_000        # checkpoints the model every X steps
 # TODO: tts_phoneme_prob = 0.0      # [0 <-> 1] probability for feeding model phonemes vrs graphemes
-
+tts_use_mixed_precision = False     # Enable mixed precision
 
 
 # ------------------------------------------------------------------------------------------------------------------#
diff --git a/train_tacotron.py b/train_tacotron.py
index 0976751b..79bc2082 100644
--- a/train_tacotron.py
+++ b/train_tacotron.py
@@ -62,7 +62,11 @@ def main():
                      stop_threshold=hp.tts_stop_threshold).to(device)
 
     optimizer = optim.Adam(model.parameters())
-    restore_checkpoint('tts', paths, model, optimizer, create_if_missing=True)
+    scaler = torch.cuda.amp.GradScaler() if hp.tts_use_mixed_precision and device.type == 'cuda' else None
+
+    print('Using mixed precision:', scaler is not None)
+
+    restore_checkpoint('tts', paths, model, optimizer, scaler, create_if_missing=True)
 
     if not force_gta:
         for i, session in enumerate(hp.tts_schedule):
@@ -95,7 +99,7 @@
                                 ('Outputs/Step (r)', model.r)])
 
             train_set, attn_example = get_tts_datasets(paths.data, batch_size, r)
-            tts_train_loop(paths, model, optimizer, train_set, lr, training_steps, attn_example)
+            tts_train_loop(paths, model, optimizer, scaler, train_set, lr, training_steps, attn_example)
 
         print('Training Complete.')
         print('To continue training increase tts_total_steps in hparams.py or use --force_train\n')
@@ -109,7 +113,7 @@
     print('\n\nYou can now train WaveRNN on GTA features - use python train_wavernn.py --gta\n')
 
 
-def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr, train_steps, attn_example):
+def tts_train_loop(paths: Paths, model: Tacotron, optimizer, scaler, train_set, lr, train_steps, attn_example):
     device = next(model.parameters()).device  # use same device as model parameters
 
     for g in optimizer.param_groups: g['lr'] = lr
@@ -126,26 +130,40 @@ def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr, trai
         for i, (x, m, ids, _) in enumerate(train_set, 1):
 
             x, m = x.to(device), m.to(device)
-
-            # Parallelize model onto GPUS using workaround due to python bug
-            if device.type == 'cuda' and torch.cuda.device_count() > 1:
-                m1_hat, m2_hat, attention = data_parallel_workaround(model, x, m)
-            else:
-                m1_hat, m2_hat, attention = model(x, m)
-
-            m1_loss = F.l1_loss(m1_hat, m)
-            m2_loss = F.l1_loss(m2_hat, m)
-
-            loss = m1_loss + m2_loss
-
+
             optimizer.zero_grad()
-            loss.backward()
-            if hp.tts_clip_grad_norm is not None:
-                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
-                if np.isnan(grad_norm):
-                    print('grad_norm was NaN!')
-
-            optimizer.step()
+
+            with torch.cuda.amp.autocast(enabled=scaler is not None):
+                # Parallelize model onto GPUS using workaround due to python bug
+                if device.type == 'cuda' and torch.cuda.device_count() > 1:
+                    m1_hat, m2_hat, attention = data_parallel_workaround(model, x, m)
+                else:
+                    m1_hat, m2_hat, attention = model(x, m)
+
+                m1_loss = F.l1_loss(m1_hat, m)
+                m2_loss = F.l1_loss(m2_hat, m)
+
+                loss = m1_loss + m2_loss
+
+            if scaler is not None:
+                scaler.scale(loss).backward()
+                scaler.unscale_(optimizer)
+
+                if hp.tts_clip_grad_norm is not None:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
+                    if np.isnan(grad_norm):
+                        print('grad_norm was NaN!')
+
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                loss.backward()
+                if hp.tts_clip_grad_norm is not None:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
+                    if np.isnan(grad_norm):
+                        print('grad_norm was NaN!')
+
+                optimizer.step()
 
             running_loss += loss.item()
             avg_loss = running_loss / i
@@ -157,20 +175,19 @@ def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr, trai
 
             if step % hp.tts_checkpoint_every == 0:
                 ckpt_name = f'taco_step{k}K'
-                save_checkpoint('tts', paths, model, optimizer,
+                save_checkpoint('tts', paths, model, optimizer, scaler,
                                 name=ckpt_name, is_silent=True)
 
             if attn_example in ids:
                 idx = ids.index(attn_example)
                 save_attention(np_now(attention[idx][:, :160]), paths.tts_attention/f'{step}')
-                save_spectrogram(np_now(m2_hat[idx]), paths.tts_mel_plot/f'{step}', 600)
 
             msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:#.4} | {speed:#.2} steps/s | Step: {k}k | '
             stream(msg)
 
         # Must save latest optimizer state to ensure that resuming training
         # doesn't produce artifacts
-        save_checkpoint('tts', paths, model, optimizer, is_silent=True)
+        save_checkpoint('tts', paths, model, optimizer, scaler, is_silent=True)
 
         model.log(paths.tts_log, msg)
         print(' ')
diff --git a/train_wavernn.py b/train_wavernn.py
index 1acb1fdf..80c8ec14 100644
--- a/train_wavernn.py
+++ b/train_wavernn.py
@@ -68,7 +68,11 @@ def main():
     assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length
 
     optimizer = optim.Adam(voc_model.parameters())
-    restore_checkpoint('voc', paths, voc_model, optimizer, create_if_missing=True)
+    scaler = torch.cuda.amp.GradScaler() if hp.voc_use_mixed_precision and device.type == 'cuda' else None
+
+    print('Using mixed precision:', scaler is not None)
+
+    restore_checkpoint('voc', paths, voc_model, optimizer, scaler, create_if_missing=True)
 
     train_set, test_set = get_vocoder_datasets(paths.data, batch_size, train_gta)
 
@@ -82,13 +86,13 @@
 
     loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss
 
-    voc_train_loop(paths, voc_model, loss_func, optimizer, train_set, test_set, lr, total_steps)
+    voc_train_loop(paths, voc_model, loss_func, optimizer, scaler, train_set, test_set, lr, total_steps)
 
     print('Training Complete.')
     print('To continue training increase voc_total_steps in hparams.py or use --force_train')
 
 
-def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, train_set, test_set, lr, total_steps):
+def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, scaler, train_set, test_set, lr, total_steps):
     # Use same device as model parameters
     device = next(model.parameters()).device
 
@@ -105,30 +109,43 @@ def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, train_set
         for i, (x, y, m) in enumerate(train_set, 1):
             x, m, y = x.to(device), m.to(device), y.to(device)
 
-            # Parallelize model onto GPUS using workaround due to python bug
-            if device.type == 'cuda' and torch.cuda.device_count() > 1:
-                y_hat = data_parallel_workaround(model, x, m)
-            else:
-                y_hat = model(x, m)
+            optimizer.zero_grad()
 
-            if model.mode == 'RAW':
-                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+            with torch.cuda.amp.autocast(enabled=scaler is not None):
+                # Parallelize model onto GPUS using workaround due to python bug
+                if device.type == 'cuda' and torch.cuda.device_count() > 1:
+                    y_hat = data_parallel_workaround(model, x, m)
+                else:
+                    y_hat = model(x, m)
 
-            elif model.mode == 'MOL':
-                y = y.float()
+                if model.mode == 'RAW':
+                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
 
-            y = y.unsqueeze(-1)
+                elif model.mode == 'MOL':
+                    y = y.float()
+                y = y.unsqueeze(-1)
 
-            loss = loss_func(y_hat, y)
+                loss = loss_func(y_hat, y)
 
 
-            optimizer.zero_grad()
-            loss.backward()
-            if hp.voc_clip_grad_norm is not None:
-                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm)
-                if np.isnan(grad_norm):
-                    print('grad_norm was NaN!')
-            optimizer.step()
+            if scaler is not None:
+                scaler.scale(loss).backward()
+                scaler.unscale_(optimizer)
+
+                if hp.voc_clip_grad_norm is not None:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm)
+                    if np.isnan(grad_norm):
+                        print('grad_norm was NaN!')
+
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                loss.backward()
+                if hp.voc_clip_grad_norm is not None:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm)
+                    if np.isnan(grad_norm):
+                        print('grad_norm was NaN!')
+                optimizer.step()
 
             running_loss += loss.item()
             avg_loss = running_loss / i
@@ -142,7 +159,7 @@ def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, train_set
                 gen_testset(model, test_set, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                             hp.voc_target, hp.voc_overlap, paths.voc_output)
                 ckpt_name = f'wave_step{k}K'
-                save_checkpoint('voc', paths, model, optimizer,
+                save_checkpoint('voc', paths, model, optimizer, scaler,
                                 name=ckpt_name, is_silent=True)
 
             msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
@@ -150,7 +167,7 @@ def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, train_set
 
         # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
-        save_checkpoint('voc', paths, model, optimizer, is_silent=True)
+        save_checkpoint('voc', paths, model, optimizer, scaler, is_silent=True)
 
         model.log(paths.voc_log, msg)
         print(' ')
diff --git a/utils/checkpoints.py b/utils/checkpoints.py
index b2e64b17..df4beb37 100644
--- a/utils/checkpoints.py
+++ b/utils/checkpoints.py
@@ -15,18 +15,20 @@ def get_checkpoint_paths(checkpoint_type: str, paths: Paths):
     if checkpoint_type is 'tts':
         weights_path = paths.tts_latest_weights
         optim_path = paths.tts_latest_optim
+        scaler_path = paths.tts_latest_scaler
         checkpoint_path = paths.tts_checkpoints
     elif checkpoint_type is 'voc':
         weights_path = paths.voc_latest_weights
         optim_path = paths.voc_latest_optim
+        scaler_path = paths.voc_latest_scaler
         checkpoint_path = paths.voc_checkpoints
     else:
         raise NotImplementedError
 
-    return weights_path, optim_path, checkpoint_path
+    return weights_path, optim_path, scaler_path, checkpoint_path
 
 
-def save_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
+def save_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, scaler, *,
                     name=None, is_silent=False):
     """Saves the training session to disk.
 
@@ -34,17 +36,18 @@ def save_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
         paths: Provides information about the different paths to use.
         model: A `Tacotron` or `WaveRNN` model to save the parameters and buffers from.
         optimizer: An optmizer to save the state of (momentum, etc).
+        scaler: A scaler to save the state of (mixed precision training).
         name: If provided, will name to a checkpoint with the given name. Note that
             regardless of whether this is provided or not, this function will always
             update the files specified in `paths` that give the location of the latest
             weights and optimizer state. Saving a named checkpoint happens in addition to this update.
     """
 
-    def helper(path_dict, is_named):
+    def helper(required_path_dict, optional_path_dict, is_named):
         s = 'named' if is_named else 'latest'
-        num_exist = sum(p.exists() for p in path_dict.values())
+        num_exist = sum(p.exists() for p in required_path_dict.values())
 
-        if num_exist not in (0,2):
+        if num_exist not in (0,len(required_path_dict)):
             # Checkpoint broken
             raise FileNotFoundError(
                 f'We expected either both or no files in the {s} checkpoint to '
@@ -52,31 +55,45 @@ def helper(path_dict, is_named):
 
         if num_exist == 0:
             if not is_silent: print(f'Creating {s} checkpoint...')
-            for p in path_dict.values():
+            for p in required_path_dict.values():
                 p.parent.mkdir(parents=True, exist_ok=True)
         else:
             if not is_silent: print(f'Saving to existing {s} checkpoint...')
 
-        if not is_silent: print(f'Saving {s} weights: {path_dict["w"]}')
-        model.save(path_dict['w'])
-        if not is_silent: print(f'Saving {s} optimizer state: {path_dict["o"]}')
-        torch.save(optimizer.state_dict(), path_dict['o'])
+        if not is_silent: print(f'Saving {s} weights: {required_path_dict["w"]}')
+        model.save(required_path_dict['w'])
+        if not is_silent: print(f'Saving {s} optimizer state: {required_path_dict["o"]}')
+        torch.save(optimizer.state_dict(), required_path_dict['o'])
+
+        for p in optional_path_dict.values():
+            if not p.exists():
+                p.parent.mkdir(parents=True, exist_ok=True)
+
+        if scaler:
+            if not is_silent: print(f'Saving {s} scaler state: {optional_path_dict["s"]}')
+            torch.save(scaler.state_dict(), optional_path_dict['s'])
 
-    weights_path, optim_path, checkpoint_path = \
+    weights_path, optim_path, scaler_path, checkpoint_path = \
         get_checkpoint_paths(checkpoint_type, paths)
 
-    latest_paths = {'w': weights_path, 'o': optim_path}
-    helper(latest_paths, False)
+    latest_required_paths = {'w': weights_path, 'o': optim_path}
+    latest_optional_paths = {'s': scaler_path}
+
+    helper(latest_required_paths, latest_optional_paths, False)
 
     if name:
-        named_paths = {
+        named_required_paths = {
             'w': checkpoint_path/f'{name}_weights.pyt',
             'o': checkpoint_path/f'{name}_optim.pyt',
         }
-        helper(named_paths, True)
+
+        named_optional_paths = {
+            's': checkpoint_path/f'{name}_scaler.pyt',
+        }
+        helper(named_required_paths, named_optional_paths, True)
 
 
-def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
+def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, scaler, *,
                        name=None, create_if_missing=False):
     """Restores from a training session saved to disk.
 
@@ -88,6 +105,7 @@ def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
         paths: Provides information about the different paths to use.
         model: A `Tacotron` or `WaveRNN` model to save the parameters and buffers from.
         optimizer: An optmizer to save the state of (momentum, etc).
+        scaler: A scaler to load the state to (mixed precision training).
         name: If provided, will restore from a checkpoint with the given name.
            Otherwise, will restore from the latest weights and optimizer state
            as specified in `paths`.
@@ -98,31 +116,41 @@ def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
         `FileNotFoundError`.
     """
 
-    weights_path, optim_path, checkpoint_path = \
+    weights_path, optim_path, scaler_path, checkpoint_path = \
        get_checkpoint_paths(checkpoint_type, paths)
 
     if name:
-        path_dict = {
+        required_path_dict = {
             'w': checkpoint_path/f'{name}_weights.pyt',
             'o': checkpoint_path/f'{name}_optim.pyt',
         }
+        optional_path_dict = {
+            's': checkpoint_path/f'{name}_scaler.pyt',
+        }
         s = 'named'
     else:
-        path_dict = {
+        required_path_dict = {
             'w': weights_path,
             'o': optim_path
         }
+        optional_path_dict = {
+            's': scaler_path
+        }
         s = 'latest'
 
-    num_exist = sum(p.exists() for p in path_dict.values())
-    if num_exist == 2:
+    num_exist = sum(p.exists() for p in required_path_dict.values())
+    if num_exist == len(required_path_dict):
         # Checkpoint exists
         print(f'Restoring from {s} checkpoint...')
-        print(f'Loading {s} weights: {path_dict["w"]}')
-        model.load(path_dict['w'])
-        print(f'Loading {s} optimizer state: {path_dict["o"]}')
-        optimizer.load_state_dict(torch.load(path_dict['o']))
+        print(f'Loading {s} weights: {required_path_dict["w"]}')
+        model.load(required_path_dict['w'])
+        print(f'Loading {s} optimizer state: {required_path_dict["o"]}')
+        optimizer.load_state_dict(torch.load(required_path_dict['o']))
+
+        if scaler and optional_path_dict["s"].exists():
+            print(f'Loading {s} scaler state: {optional_path_dict["s"]}')
+            scaler.load_state_dict(torch.load(optional_path_dict['s']))
     elif create_if_missing:
-        save_checkpoint(checkpoint_type, paths, model, optimizer, name=name, is_silent=False)
+        save_checkpoint(checkpoint_type, paths, model, optimizer, scaler, name=name, is_silent=False)
     else:
         raise FileNotFoundError(f'The {s} checkpoint could not be found!')
\ No newline at end of file
diff --git a/utils/paths.py b/utils/paths.py
index 0be37b83..272db1f7 100644
--- a/utils/paths.py
+++ b/utils/paths.py
@@ -17,6 +17,7 @@ def __init__(self, data_path, voc_id, tts_id):
         self.voc_checkpoints = self.base/'checkpoints'/f'{voc_id}.wavernn'
         self.voc_latest_weights = self.voc_checkpoints/'latest_weights.pyt'
         self.voc_latest_optim = self.voc_checkpoints/'latest_optim.pyt'
+        self.voc_latest_scaler = self.voc_checkpoints/'latest_scaler.pyt'
         self.voc_output = self.base/'model_outputs'/f'{voc_id}.wavernn'
         self.voc_step = self.voc_checkpoints/'step.npy'
         self.voc_log = self.voc_checkpoints/'log.txt'
@@ -25,6 +26,7 @@ def __init__(self, data_path, voc_id, tts_id):
         self.tts_checkpoints = self.base/'checkpoints'/f'{tts_id}.tacotron'
         self.tts_latest_weights = self.tts_checkpoints/'latest_weights.pyt'
         self.tts_latest_optim = self.tts_checkpoints/'latest_optim.pyt'
+        self.tts_latest_scaler = self.tts_checkpoints/'latest_scaler.pyt'
         self.tts_output = self.base/'model_outputs'/f'{tts_id}.tacotron'
         self.tts_step = self.tts_checkpoints/'step.npy'
         self.tts_log = self.tts_checkpoints/'log.txt'
@@ -52,6 +54,10 @@ def get_tts_named_weights(self, name):
     def get_tts_named_optim(self, name):
         """Gets the path for the optimizer state in a named tts checkpoint."""
         return self.tts_checkpoints/f'{name}_optim.pyt'
+
+    def get_tts_named_scaler(self, name):
+        """Gets the path for the scaler state in a named tts checkpoint."""
+        return self.tts_checkpoints/f'{name}_scaler.pyt'
 
     def get_voc_named_weights(self, name):
         """Gets the path for the weights in a named voc checkpoint."""
@@ -60,5 +66,9 @@ def get_voc_named_weights(self, name):
     def get_voc_named_optim(self, name):
         """Gets the path for the optimizer state in a named voc checkpoint."""
         return self.voc_checkpoints/f'{name}_optim.pyt'
+
+    def get_voc_named_scaler(self, name):
+        """Gets the path for the scaler state in a named voc checkpoint."""
+        return self.voc_checkpoints/f'{name}_scaler.pyt'
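
For reference, both training loops in this patch follow the standard torch.cuda.amp recipe: autocast around the forward pass and loss, GradScaler for the backward pass and optimizer step, and unscale_ before gradient clipping. The snippet below is a minimal, self-contained sketch of that step pattern, not part of the patch; the names train_step, loss_fn and clip_grad_norm are illustrative placeholders.

    # Minimal sketch of the mixed-precision update pattern used by the patch (PyTorch AMP).
    # Argument names are illustrative; only the torch.cuda.amp calls mirror the diff.
    import torch

    def train_step(model, optimizer, scaler, x, y, loss_fn, clip_grad_norm=None):
        optimizer.zero_grad()

        # Run the forward pass in reduced precision only when a scaler is in use.
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss = loss_fn(model(x), y)

        if scaler is not None:
            scaler.scale(loss).backward()   # scale the loss so fp16 gradients don't underflow
            scaler.unscale_(optimizer)      # unscale before clipping so the norm threshold is meaningful
            if clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
            scaler.step(optimizer)          # skips the step if gradients contain inf/NaN
            scaler.update()                 # adjusts the loss scale for the next iteration
        else:
            loss.backward()
            if clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
            optimizer.step()

        return loss.item()

Calling unscale_ before clip_grad_norm_ is what keeps the clipping thresholds (tts_clip_grad_norm / voc_clip_grad_norm) comparable between the mixed-precision and full-precision paths, which is why the patch performs it before clipping in both loops.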