
add mixed precision training #229

Open
wants to merge 1 commit into base: master
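This change enables PyTorch automatic mixed precision (torch.cuda.amp) in both the Tacotron and WaveRNN training loops, gated by new hparams flags, and persists the GradScaler state alongside the model weights and optimizer state so that resumed runs keep their adapted loss scale. A minimal sketch of the autocast + GradScaler pattern the diff follows is shown below; the toy model, optimizer and clipping value are placeholders, not names from this repository.

import torch
import torch.nn.functional as F

# Toy stand-ins for the real Tacotron / WaveRNN models and hparams.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.Adam(model.parameters())
clip_grad_norm = 1.0

# GradScaler is only useful on CUDA; fall back to plain fp32 elsewhere.
scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

for _ in range(3):
    x = torch.randn(4, 8, device=device)
    y = torch.randn(4, 1, device=device)

    optimizer.zero_grad()
    # autocast(enabled=False) is a no-op, so one forward path serves both modes.
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        loss = F.l1_loss(model(x), y)

    if scaler is not None:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale first so clipping sees the true gradient norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        scaler.step(optimizer)      # skips the step if gradients overflowed
        scaler.update()
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        optimizer.step()

Unscaling before clip_grad_norm_ is the order recommended in the torch.cuda.amp documentation, and it is also what both training loops in this diff do.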
3 changes: 2 additions & 1 deletion hparams.py
@@ -53,6 +53,7 @@
voc_pad = 2 # this will pad the input so that the resnet can 'see' wider than input length
voc_seq_len = hop_length * 5 # must be a multiple of hop_length
voc_clip_grad_norm = 4 # set to None if no gradient clipping needed
voc_use_mixed_precision = True # use automatic mixed precision (AMP) for vocoder training (only takes effect on CUDA)

# Generating / Synthesizing
voc_gen_batched = True # very fast (realtime+) single utterance batched generation
@@ -91,7 +92,7 @@
tts_clip_grad_norm = 1.0 # clips the gradient norm to prevent explosion - set to None if not needed
tts_checkpoint_every = 2_000 # checkpoints the model every X steps
# TODO: tts_phoneme_prob = 0.0 # [0 <-> 1] probability for feeding model phonemes vrs graphemes

tts_use_mixed_precision = False # set to True to use automatic mixed precision (AMP) for TTS training (only takes effect on CUDA)

# ------------------------------------------------------------------------------------------------------------------#

67 changes: 42 additions & 25 deletions train_tacotron.py
@@ -62,7 +62,11 @@ def main():
stop_threshold=hp.tts_stop_threshold).to(device)

optimizer = optim.Adam(model.parameters())
restore_checkpoint('tts', paths, model, optimizer, create_if_missing=True)
scaler = torch.cuda.amp.GradScaler() if hp.tts_use_mixed_precision and device.type == 'cuda' else None

print('Using mixed precision:', scaler is not None)

restore_checkpoint('tts', paths, model, optimizer, scaler, create_if_missing=True)

if not force_gta:
for i, session in enumerate(hp.tts_schedule):
@@ -95,7 +99,7 @@ def main():
('Outputs/Step (r)', model.r)])

train_set, attn_example = get_tts_datasets(paths.data, batch_size, r)
tts_train_loop(paths, model, optimizer, train_set, lr, training_steps, attn_example)
tts_train_loop(paths, model, optimizer, scaler, train_set, lr, training_steps, attn_example)

print('Training Complete.')
print('To continue training increase tts_total_steps in hparams.py or use --force_train\n')
@@ -109,7 +113,7 @@ def main():
print('\n\nYou can now train WaveRNN on GTA features - use python train_wavernn.py --gta\n')


def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr, train_steps, attn_example):
def tts_train_loop(paths: Paths, model: Tacotron, optimizer, scaler, train_set, lr, train_steps, attn_example):
device = next(model.parameters()).device # use same device as model parameters

for g in optimizer.param_groups: g['lr'] = lr
@@ -126,26 +130,40 @@ def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr, trai
for i, (x, m, ids, _) in enumerate(train_set, 1):

x, m = x.to(device), m.to(device)

# Parallelize model onto GPUS using workaround due to python bug
if device.type == 'cuda' and torch.cuda.device_count() > 1:
m1_hat, m2_hat, attention = data_parallel_workaround(model, x, m)
else:
m1_hat, m2_hat, attention = model(x, m)

m1_loss = F.l1_loss(m1_hat, m)
m2_loss = F.l1_loss(m2_hat, m)

loss = m1_loss + m2_loss


optimizer.zero_grad()
loss.backward()
if hp.tts_clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
if np.isnan(grad_norm):
print('grad_norm was NaN!')

optimizer.step()

with torch.cuda.amp.autocast(enabled=scaler is not None):
# Parallelize model onto GPUs using workaround due to python bug
if device.type == 'cuda' and torch.cuda.device_count() > 1:
m1_hat, m2_hat, attention = data_parallel_workaround(model, x, m)
else:
m1_hat, m2_hat, attention = model(x, m)

m1_loss = F.l1_loss(m1_hat, m)
m2_loss = F.l1_loss(m2_hat, m)

loss = m1_loss + m2_loss

if scaler is not None:
scaler.scale(loss).backward()
scaler.unscale_(optimizer)

if hp.tts_clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
if np.isnan(grad_norm):
print('grad_norm was NaN!')

scaler.step(optimizer)
scaler.update()
else:
loss.backward()
if hp.tts_clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
if np.isnan(grad_norm):
print('grad_norm was NaN!')

optimizer.step()

running_loss += loss.item()
avg_loss = running_loss / i
@@ -157,20 +175,19 @@ def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr, trai

if step % hp.tts_checkpoint_every == 0:
ckpt_name = f'taco_step{k}K'
save_checkpoint('tts', paths, model, optimizer,
save_checkpoint('tts', paths, model, optimizer, scaler,
name=ckpt_name, is_silent=True)

if attn_example in ids:
idx = ids.index(attn_example)
save_attention(np_now(attention[idx][:, :160]), paths.tts_attention/f'{step}')
save_spectrogram(np_now(m2_hat[idx]), paths.tts_mel_plot/f'{step}', 600)

msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:#.4} | {speed:#.2} steps/s | Step: {k}k | '
stream(msg)

# Must save latest optimizer state to ensure that resuming training
# doesn't produce artifacts
save_checkpoint('tts', paths, model, optimizer, is_silent=True)
save_checkpoint('tts', paths, model, optimizer, scaler, is_silent=True)
model.log(paths.tts_log, msg)
print(' ')

63 changes: 40 additions & 23 deletions train_wavernn.py
@@ -68,7 +68,11 @@ def main():
assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

optimizer = optim.Adam(voc_model.parameters())
restore_checkpoint('voc', paths, voc_model, optimizer, create_if_missing=True)
scaler = torch.cuda.amp.GradScaler() if hp.voc_use_mixed_precision and device.type == 'cuda' else None

print('Using mixed precision:', scaler is not None)

restore_checkpoint('voc', paths, voc_model, optimizer, scaler, create_if_missing=True)

train_set, test_set = get_vocoder_datasets(paths.data, batch_size, train_gta)

@@ -82,13 +86,13 @@

loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss

voc_train_loop(paths, voc_model, loss_func, optimizer, train_set, test_set, lr, total_steps)
voc_train_loop(paths, voc_model, loss_func, optimizer, scaler, train_set, test_set, lr, total_steps)

print('Training Complete.')
print('To continue training increase voc_total_steps in hparams.py or use --force_train')


def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, train_set, test_set, lr, total_steps):
def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, scaler, train_set, test_set, lr, total_steps):
# Use same device as model parameters
device = next(model.parameters()).device

@@ -105,30 +109,43 @@ def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, train_set
for i, (x, y, m) in enumerate(train_set, 1):
x, m, y = x.to(device), m.to(device), y.to(device)

# Parallelize model onto GPUS using workaround due to python bug
if device.type == 'cuda' and torch.cuda.device_count() > 1:
y_hat = data_parallel_workaround(model, x, m)
else:
y_hat = model(x, m)
optimizer.zero_grad()

if model.mode == 'RAW':
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
with torch.cuda.amp.autocast(enabled=scaler is not None):
# Parallelize model onto GPUs using workaround due to python bug
if device.type == 'cuda' and torch.cuda.device_count() > 1:
y_hat = data_parallel_workaround(model, x, m)
else:
y_hat = model(x, m)

elif model.mode == 'MOL':
y = y.float()
if model.mode == 'RAW':
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)

y = y.unsqueeze(-1)
elif model.mode == 'MOL':
y = y.float()

y = y.unsqueeze(-1)

loss = loss_func(y_hat, y)
loss = loss_func(y_hat, y)

optimizer.zero_grad()
loss.backward()
if hp.voc_clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm)
if np.isnan(grad_norm):
print('grad_norm was NaN!')
optimizer.step()
if scaler is not None:
scaler.scale(loss).backward()
scaler.unscale_(optimizer)

if hp.voc_clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm)
if np.isnan(grad_norm):
print('grad_norm was NaN!')

scaler.step(optimizer)
scaler.update()
else:
loss.backward()
if hp.voc_clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm)
if np.isnan(grad_norm):
print('grad_norm was NaN!')
optimizer.step()

running_loss += loss.item()
avg_loss = running_loss / i
@@ -142,15 +159,15 @@ def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, train_set
gen_testset(model, test_set, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
hp.voc_target, hp.voc_overlap, paths.voc_output)
ckpt_name = f'wave_step{k}K'
save_checkpoint('voc', paths, model, optimizer,
save_checkpoint('voc', paths, model, optimizer, scaler,
name=ckpt_name, is_silent=True)

msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
stream(msg)

# Must save latest optimizer state to ensure that resuming training
# doesn't produce artifacts
save_checkpoint('voc', paths, model, optimizer, is_silent=True)
save_checkpoint('voc', paths, model, optimizer, scaler, is_silent=True)
model.log(paths.voc_log, msg)
print(' ')

76 changes: 51 additions & 25 deletions utils/checkpoints.py
@@ -15,68 +15,85 @@ def get_checkpoint_paths(checkpoint_type: str, paths: Paths):
if checkpoint_type == 'tts':
weights_path = paths.tts_latest_weights
optim_path = paths.tts_latest_optim
scaler_path = paths.tts_latest_scaler
checkpoint_path = paths.tts_checkpoints
elif checkpoint_type == 'voc':
weights_path = paths.voc_latest_weights
optim_path = paths.voc_latest_optim
scaler_path = paths.voc_latest_scaler
checkpoint_path = paths.voc_checkpoints
else:
raise NotImplementedError

return weights_path, optim_path, checkpoint_path
return weights_path, optim_path, scaler_path, checkpoint_path


def save_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
def save_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, scaler, *,
name=None, is_silent=False):
"""Saves the training session to disk.

Args:
paths: Provides information about the different paths to use.
model: A `Tacotron` or `WaveRNN` model to save the parameters and buffers from.
optimizer: An optimizer to save the state of (momentum, etc.).
scaler: A `GradScaler` to save the state of (for mixed precision training); may be None.
name: If provided, will save a named checkpoint with the given name. Note
that regardless of whether this is provided or not, this function
will always update the files specified in `paths` that give the
location of the latest weights and optimizer state. Saving
a named checkpoint happens in addition to this update.
"""
def helper(path_dict, is_named):
def helper(required_path_dict, optional_path_dict, is_named):
s = 'named' if is_named else 'latest'
num_exist = sum(p.exists() for p in path_dict.values())
num_exist = sum(p.exists() for p in required_path_dict.values())

if num_exist not in (0,2):
if num_exist not in (0,len(required_path_dict)):
# Checkpoint broken
raise FileNotFoundError(
f'We expected either both or no files in the {s} checkpoint to '
'exist, but instead we got exactly one!')

if num_exist == 0:
if not is_silent: print(f'Creating {s} checkpoint...')
for p in path_dict.values():
for p in required_path_dict.values():
p.parent.mkdir(parents=True, exist_ok=True)
else:
if not is_silent: print(f'Saving to existing {s} checkpoint...')

if not is_silent: print(f'Saving {s} weights: {path_dict["w"]}')
model.save(path_dict['w'])
if not is_silent: print(f'Saving {s} optimizer state: {path_dict["o"]}')
torch.save(optimizer.state_dict(), path_dict['o'])
if not is_silent: print(f'Saving {s} weights: {required_path_dict["w"]}')
model.save(required_path_dict['w'])
if not is_silent: print(f'Saving {s} optimizer state: {required_path_dict["o"]}')
torch.save(optimizer.state_dict(), required_path_dict['o'])

for p in optional_path_dict.values():
if not p.exists():
p.parent.mkdir(parents=True, exist_ok=True)

if scaler:
if not is_silent: print(f'Saving {s} scaler state: {optional_path_dict["s"]}')
torch.save(scaler.state_dict(), optional_path_dict['s'])

weights_path, optim_path, checkpoint_path = \
weights_path, optim_path, scaler_path, checkpoint_path = \
get_checkpoint_paths(checkpoint_type, paths)

latest_paths = {'w': weights_path, 'o': optim_path}
helper(latest_paths, False)
latest_required_paths = {'w': weights_path, 'o': optim_path}
latest_optional_paths = {'s': scaler_path}

helper(latest_required_paths, latest_optional_paths, False)

if name:
named_paths = {
named_required_paths = {
'w': checkpoint_path/f'{name}_weights.pyt',
'o': checkpoint_path/f'{name}_optim.pyt',
}
helper(named_paths, True)

named_optional_paths = {
's': checkpoint_path/f'{name}_scaler.pyt',
}
helper(named_required_paths, named_optional_paths, True)


def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, scaler, *,
name=None, create_if_missing=False):
"""Restores from a training session saved to disk.

@@ -88,6 +105,7 @@ def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
paths: Provides information about the different paths to use.
model: A `Tacotron` or `WaveRNN` model to load the parameters and buffers into.
optimizer: An optimizer to load the state into (momentum, etc.).
scaler: A `GradScaler` to load the state into (for mixed precision training); may be None.
name: If provided, will restore from a checkpoint with the given name.
Otherwise, will restore from the latest weights and optimizer state
as specified in `paths`.
@@ -98,31 +116,39 @@ def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
`FileNotFoundError`.
"""

weights_path, optim_path, checkpoint_path = \
weights_path, optim_path, scaler_path, checkpoint_path = \
get_checkpoint_paths(checkpoint_type, paths)

if name:
required_path_dict = {
'w': checkpoint_path/f'{name}_weights.pyt',
'o': checkpoint_path/f'{name}_optim.pyt',
}
optional_path_dict = {
's': checkpoint_path/f'{name}_scaler.pyt',
}
s = 'named'
else:
path_dict = {
required_path_dict = {
'w': weights_path,
'o': optim_path
}
optional_path_dict = {
's': scaler_path
}
s = 'latest'

num_exist = sum(p.exists() for p in path_dict.values())
if num_exist == 2:
num_exist = sum(p.exists() for p in required_path_dict.values())
if num_exist == len(required_path_dict):
# Checkpoint exists
print(f'Restoring from {s} checkpoint...')
print(f'Loading {s} weights: {path_dict["w"]}')
model.load(path_dict['w'])
print(f'Loading {s} optimizer state: {path_dict["o"]}')
optimizer.load_state_dict(torch.load(path_dict['o']))
print(f'Loading {s} weights: {required_path_dict["w"]}')
model.load(required_path_dict['w'])
print(f'Loading {s} optimizer state: {required_path_dict["o"]}')
optimizer.load_state_dict(torch.load(required_path_dict['o']))

if scaler and optional_path_dict["s"].exists():
print(f'Loading {s} scaler state: {optional_path_dict["s"]}')
scaler.load_state_dict(torch.load(optional_path_dict['s']))
elif create_if_missing:
save_checkpoint(checkpoint_type, paths, model, optimizer, name=name, is_silent=False)
save_checkpoint(checkpoint_type, paths, model, optimizer, scaler, name=name, is_silent=False)
else:
raise FileNotFoundError(f'The {s} checkpoint could not be found!')
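Note that get_checkpoint_paths now returns paths.tts_latest_scaler and paths.voc_latest_scaler, which are not defined anywhere in this diff; utils/paths.py presumably needs matching attributes next to the existing latest-weights and latest-optim entries, otherwise the first call raises an AttributeError before training starts. A hypothetical sketch of that companion change is below; the class name, constructor arguments, and file names are assumptions inferred from the pattern in this diff, not taken from the PR.

from pathlib import Path

class PathsWithScaler:
    """Hypothetical excerpt of a Paths-like class, showing only the checkpoint attributes."""
    def __init__(self, tts_checkpoints: Path, voc_checkpoints: Path):
        self.tts_checkpoints = tts_checkpoints
        self.voc_checkpoints = voc_checkpoints
        self.tts_latest_weights = tts_checkpoints / 'latest_weights.pyt'
        self.tts_latest_optim = tts_checkpoints / 'latest_optim.pyt'
        self.tts_latest_scaler = tts_checkpoints / 'latest_scaler.pyt'   # new
        self.voc_latest_weights = voc_checkpoints / 'latest_weights.pyt'
        self.voc_latest_optim = voc_checkpoints / 'latest_optim.pyt'
        self.voc_latest_scaler = voc_checkpoints / 'latest_scaler.pyt'   # new

paths = PathsWithScaler(Path('checkpoints/tts'), Path('checkpoints/voc'))
print(paths.tts_latest_scaler)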