Commit

Alignment level topk
mmueller00 committed Jan 15, 2025
1 parent 0a2b961 commit 3fa3ca2
Showing 4 changed files with 307 additions and 96 deletions.
188 changes: 127 additions & 61 deletions users/mueller/experiments/ctc_baseline/ctc.py
@@ -61,27 +61,33 @@ def py():
use_flashlight = True
use_greedy = False
epochs = 500
self_training_rounds = 4
self_training_rounds = 1
train_small = True
with_prior = True
empirical_prior = True
aux_loss = True
prior_from_max = False
aux_loss = False
alt_decoder = True
calc_last_pseudo_labels = True
tune_hyperparameters = True
calc_last_pseudo_labels = False
tune_hyperparameters = False
from_scratch = False

use_sum_criterion = False
use_sum_criterion = True
horizontal_prior = True
blank_prior = True
prior_gradient = False
empirical_prior_full_sum = False
prior_from_max_full_sum = False
LM_order = 2
top_k = 1
self_train_subset = None
top_k = 3
self_train_subset = 18000 # 18000

assert (empirical_prior_full_sum and empirical_prior) or not empirical_prior_full_sum

if train_small:
epochs = 50
if self_training_rounds > 0:
self_epochs = 113 # 450, 225, 113, 75, 56, 45
self_epochs = 56 # 450, 225, 113, 75, 56, 45

decoder_hyperparameters = None
if use_greedy:
@@ -105,7 +111,7 @@ def py():
if with_prior:
decoder_hyperparameters["prior_weight"] = 0.3 # 0.2 if not using emprirical prior

p0 = f"_p{str(decoder_hyperparameters['prior_weight']).replace('.', '')}" + ("-emp" if empirical_prior else "") if with_prior else ""
p0 = f"_p{str(decoder_hyperparameters['prior_weight']).replace('.', '')}" + ("-emp" if empirical_prior else ("-from_max" if prior_from_max else "")) if with_prior else ""
p1 = "sum" if decoder_hyperparameters['log_add'] else "max"
p2 = f"n{decoder_hyperparameters['nbest']}"
p3 = f"b{decoder_hyperparameters['beam_size']}"
@@ -131,7 +137,7 @@ def py():
else:
str_add = ""

a0 = f"_p{str(alt_decoder_hyperparameters['prior_weight']).replace('.', '')}" + ("-emp" if empirical_prior else "") if with_prior else ""
a0 = f"_p{str(alt_decoder_hyperparameters['prior_weight']).replace('.', '')}" + ("-emp" if empirical_prior else ("-from_max" if prior_from_max else "")) if with_prior else ""
a1 = f"b{alt_decoder_hyperparameters['beam_size']}"
a2 = f"w{str(alt_decoder_hyperparameters['lm_weight']).replace('.', '')}"
a3 = "_tune" if tune_hyperparameters else ""
@@ -157,29 +163,49 @@ def py():
} if self_training_rounds > 0 else None

for am, lm, prior in [
(1.0, 0.0, 0.55)
(8.0, 0.01, 0.08)
]:
if use_sum_criterion:
training_scales = {
"am": am,
"lm": lm,
"prior": prior
}
if am != 1.0 or lm != 1.0 or prior != 1.0:
scales_not_std = True
config_full_sum = {
"am_scale": am,
"lm_scale": lm,
"prior_scale": prior
}
else:
scales_not_std = False
config_full_sum = {}

if list(training_scales.values()) == [1.0] * len(training_scales):
training_scales = None
if not horizontal_prior:
config_full_sum["horizontal_prior"] = horizontal_prior
if not blank_prior:
config_full_sum["blank_prior"] = blank_prior
if not prior_gradient:
config_full_sum["prior_gradient"] = prior_gradient
if top_k > 0:
config_full_sum["top_k"] = top_k
if empirical_prior_full_sum:
config_full_sum["empirical_prior"] = True
if prior_from_max_full_sum:
config_full_sum["max_prior"] = True

# This is to change the hash when we make changes in the loss function
config_full_sum["version"] = 1

sum_str = f"-full_sum" + \
(f"_p{str(training_scales['prior']).replace('.', '')}_l{str(training_scales['lm']).replace('.', '')}_a{str(training_scales['am']).replace('.', '')}" if training_scales else "") + \
(f"_p{str(config_full_sum['prior_scale']).replace('.', '')}_l{str(config_full_sum['lm_scale']).replace('.', '')}_a{str(config_full_sum['am_scale']).replace('.', '')}" if scales_not_std else "") + \
(f"_LMorder{LM_order}" if LM_order > 2 else "") + \
(f"_topK{top_k}" if top_k > 0 else "") + \
("_emp" if empirical_prior_full_sum else "") + \
("_max_pr" if not empirical_prior_full_sum and prior_from_max_full_sum else "") + \
("_wo_hor_pr" if not horizontal_prior else "") + \
("_wo_blank_pr" if not blank_prior else "") + \
("_wo_pr_grad" if not prior_gradient else "")

alias_name = f"ctc-baseline" + \
(sum_str if use_sum_criterion else "") + \
(f"-self_training_{self_training_rounds}" + (f"_s{self_train_subset}" if self_train_subset is not None else "") + (f"_e{self_epochs}" if self_epochs != 450 else "") if self_training_rounds > 0 else "") + \
(f"-self_training_{self_training_rounds}" + ("_from_scratch" if from_scratch else "") + (f"_s{self_train_subset}" if self_train_subset is not None else "") + (f"_e{self_epochs}" if self_epochs != 450 else "") if self_training_rounds > 0 else "") + \
(f"-wo_aux_loss" if not aux_loss else "") + \
(f"-ds100h" if train_small else "") + \
f"-{vocab}" + \
@@ -194,22 +220,20 @@ def py():
model_config = {"enc_conformer_layer": enc_conformer_layer_default, "feature_batch_norm": True},
config_updates = config_updates,
config_updates_self_training = config_updates_self_training,
config_full_sum=config_full_sum if use_sum_criterion else None,
vocab = vocab,
self_training_rounds = self_training_rounds,
train_small = train_small,
with_prior = with_prior,
empirical_prior=empirical_prior,
prior_from_max=prior_from_max,
use_sum_criterion=use_sum_criterion,
aux_loss=aux_loss,
horizontal_prior=horizontal_prior,
blank_prior=blank_prior,
prior_gradient=prior_gradient,
LM_order=LM_order,
top_k=top_k,
training_scales=training_scales if use_sum_criterion else None,
self_train_subset=self_train_subset,
calc_last_pseudo_labels=calc_last_pseudo_labels,
tune_hyperparameters=tune_hyperparameters,
from_scratch=from_scratch,
)
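With the flag values chosen at the top of py() in this commit (the (8.0, 0.01, 0.08) scale tuple, top_k = 3, prior_gradient = False, everything else left at its default), the construction above yields roughly the following config_full_sum and alias fragment. This is only a trace of the code above, not an extra configuration:

```python
# Derived by following the flags through the construction above.
config_full_sum = {
    "am_scale": 8.0,          # scales differ from 1.0, so scales_not_std is True
    "lm_scale": 0.01,
    "prior_scale": 0.08,
    "prior_gradient": False,  # written only because the flag deviates from its default
    "top_k": 3,
    "version": 1,             # bumped manually to change the job hash after loss changes
}
sum_str = "-full_sum_p008_l001_a80_topK3_wo_pr_grad"
```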


@@ -231,6 +255,7 @@ def train_exp(
model_config: Optional[Dict[str, Any]] = None,
config_updates: Optional[Dict[str, Any]] = None,
config_updates_self_training: Optional[Dict[str, Any]] = None,
config_full_sum: Optional[Dict[str, Any]] = None,
config_deletes: Optional[Sequence[str]] = None,
post_config_updates: Optional[Dict[str, Any]] = None,
epilog: Sequence[serialization.SerializerObject] = (),
@@ -244,17 +269,14 @@
train_small: bool = False,
with_prior: bool = False,
empirical_prior: bool = False,
prior_from_max: bool = False,
use_sum_criterion: bool = False,
aux_loss: bool = False,
horizontal_prior: bool = True,
blank_prior: bool = True,
prior_gradient: bool = True,
LM_order: int = 2,
top_k: int = 0,
training_scales: Optional[Dict[str, float]] = None,
self_train_subset: Optional[int] = None,
calc_last_pseudo_labels: bool = False,
tune_hyperparameters: bool = False,
from_scratch: bool = False,
) -> Optional[ModelWithCheckpoints]:
"""
Train experiment
@@ -329,10 +351,11 @@
save_pseudo_labels=(pseudo_labels_ds, train_100_ds) if calc_last_pseudo_labels or self_training_rounds > 0 else None,
calculate_pseudo_label_scores=True, # NOTE: breaks hash
recog_post_proc_funcs=recog_post_proc_funcs,
num_shards_recog=16, # NOTE: breaks hash
# num_shards_recog=16, # NOTE: breaks hash
num_shards_pseudo=64,
# num_shards_prior=64,
is_last=self_training_rounds == 0,
prior_from_max=prior_from_max,
empirical_prior=emp_prior if with_prior and empirical_prior else None,
)

@@ -361,26 +384,19 @@

if use_sum_criterion:
train_def = ctc_sum_training
config_self = dict_update_deep(config_self, config_full_sum)
config_self["lm_path"] = get_count_based_n_gram(task.train_dataset.vocab, LM_order)

if not horizontal_prior:
config_self["horizontal_prior"] = horizontal_prior
if not blank_prior:
config_self["blank_prior"] = blank_prior
if training_scales:
config_self["am_scale"] = training_scales["am"]
config_self["lm_scale"] = training_scales["lm"]
config_self["prior_scale"] = training_scales["prior"]
if not prior_gradient:
config_self["prior_gradient"] = prior_gradient
if top_k > 0:
config_self["top_k"] = top_k
if config_self.get("empirical_prior", False):
config_self["empirical_prior"] = emp_prior

# When testing on a smaller subset we only want one gpu
if self_train_subset is not None:
config_self["__num_processes"] = 1
# config_self["learning_rate_piecewise_steps"] = [4_500, 9_000, 10_000]
config_self["learning_rate_piecewise_steps"] = [2_250, 4_500, 5_000]
peak_lr = 1e-4
config_self["learning_rate_piecewise_values"] = [peak_lr * 1.001e-1, peak_lr, peak_lr * 3e-2, peak_lr * 3e-3]
if not aux_loss:
config_self.pop("aux_loss_layers")
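For the small-subset runs the schedule above replaces the default one with steps [2_250, 4_500, 5_000] around peak_lr = 1e-4. Assuming the usual piecewise-linear reading of learning_rate_piecewise_steps / learning_rate_piecewise_values (first value at step 0, each later value reached at the corresponding step, linear in between), the ramp looks like this; a standalone sketch, not RETURNN's scheduler:

```python
# Sketch of a piecewise-linear LR schedule under the assumed interpretation
# of learning_rate_piecewise_steps / learning_rate_piecewise_values.
peak_lr = 1e-4
steps = [2_250, 4_500, 5_000]
values = [peak_lr * 1.001e-1, peak_lr, peak_lr * 3e-2, peak_lr * 3e-3]

def lr_at(step: int) -> float:
    """Linearly interpolate between the schedule points."""
    prev_step, prev_val = 0, values[0]
    for boundary, val in zip(steps, values[1:]):
        if step <= boundary:
            frac = (step - prev_step) / (boundary - prev_step)
            return prev_val + frac * (val - prev_val)
        prev_step, prev_val = boundary, val
    return values[-1]

for step in (0, 1_000, 2_250, 4_500, 5_000):
    print(step, f"{lr_at(step):.2e}")  # warm-up to 1e-4 at 2,250, then decay
```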

@@ -395,7 +411,10 @@
config_self["learning_rate_piecewise_values"] = [peak_lr * 1e-1, peak_lr, peak_lr * 3e-2, peak_lr * 3e-3]
config_self["learning_rate_piecewise_steps"] = [20_000] + config_self["learning_rate_piecewise_steps"][1:]

init_checkpoint = model_with_checkpoint[i].get_last_fixed_epoch().checkpoint
if i == 0 and from_scratch:
init_checkpoint = None
else:
init_checkpoint = model_with_checkpoint[i].get_last_fixed_epoch().checkpoint

model_with_checkpoint.append(train(
prefix_self_training,
Expand All @@ -409,7 +428,7 @@ def train_exp(
num_epochs=num_epochs,
gpu_mem=gpu_mem,
num_processes=num_processes,
time_rqmt=time_rqmt if time_rqmt else ((10 if self_train_subset else 156) if use_sum_criterion else 156),
time_rqmt=time_rqmt if time_rqmt else ((4 if self_train_subset else 156) if use_sum_criterion else 156),
))
train_job = model_with_checkpoint[i + 1].get_training_job()
if env_updates:
Expand Down Expand Up @@ -438,6 +457,7 @@ def train_exp(
recog_post_proc_funcs=recog_post_proc_funcs,
num_shards_recog=16, # NOTE: breaks hash
num_shards_prior=64,
prior_from_max=prior_from_max,
empirical_prior=emp_prior if with_prior and empirical_prior else None,
return_summary = True
)
@@ -457,6 +477,7 @@
recog_post_proc_funcs=recog_post_proc_funcs,
num_shards_recog=16, # NOTE: breaks hash
num_shards_prior=64,
prior_from_max=prior_from_max,
empirical_prior=emp_prior if with_prior and empirical_prior else None,
return_summary = True
)
@@ -480,6 +501,7 @@
num_shards_pseudo=64,
num_shards_prior=64,
is_last=i+1 == self_training_rounds,
prior_from_max=prior_from_max,
empirical_prior=emp_prior if with_prior and empirical_prior else None,
)

@@ -1004,28 +1026,44 @@ def ctc_training(*, model: Model, data: rf.Tensor, data_spatial_dim: Dim, target
ctc_training: TrainDef[Model]
ctc_training.learning_rate_control_error_measure = "ctc"

def ctc_sum_training(*, model: Model, data: rf.Tensor, data_spatial_dim: Dim, lm_path: tk.Path):
def ctc_sum_training(*, model: Model, data: rf.Tensor, data_spatial_dim: Dim, lm_path: tk.Path, seq_tags: rf.Tensor = None):
"""Function is run within RETURNN."""
from returnn.config import get_global_config
from .sum_criterion import sum_loss, safe_logsumexp

# torch.autograd.set_detect_anomaly(True)

def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor, use_max: bool = False, separate_eos: bool = False) -> torch.Tensor:
lengths = lengths.to(log_probs.device)
assert lengths.size(0) == log_probs.size(0), "Prior calculation batch lengths are not the same (full_sum)!"

mask = torch.arange(log_probs.size(1), device=log_probs.device).expand(log_probs.size(0), -1) < lengths.unsqueeze(1)
mask = torch.where(mask, 0.0, float("-inf"))
mask_bool = torch.arange(log_probs.size(1), device=log_probs.device).expand(log_probs.size(0), -1) < lengths.unsqueeze(1)
mask = torch.where(mask_bool, 0.0, float("-inf"))
mask = mask.unsqueeze(-1).expand(-1, -1, log_probs.size(2))
log_probs = log_probs + mask

sum_frames = lengths.sum()
log_sum_probs = torch.full([log_probs.size(2) + 1,], float("-inf"), device=log_probs.device)
log_sum_probs[1:-1] = safe_logsumexp(safe_logsumexp(log_probs[:,:,1:], dim=0), dim=0) # Sum over batch and time
log_sum_probs[0] = safe_logsumexp(log_probs[:,0,0], dim=0) # BOS prob
log_sum_probs[-1] = safe_logsumexp(safe_logsumexp(log_probs[:,1:,0], dim=0), dim=0) # EOS prob

if use_max:
if separate_eos:
raise NotImplementedError("Separate EOS not implemented for max prior")
else:
argmaxs = log_probs.argmax(dim=2)
argmaxs = argmaxs.flatten()
argmaxs = argmaxs[mask_bool.flatten()]
assert argmaxs.size(0) == sum_frames, f"Prior calculation frame count does not match (max) ({argmaxs.size(0)} != {sum_frames})"
sum_probs = argmaxs.bincount(minlength=log_probs.size(2))
sum_frames += (sum_probs == 0).sum()
sum_probs = torch.where(sum_probs == 0, 1, sum_probs)
log_sum_probs = sum_probs.log()
else:
if separate_eos:
log_sum_probs = torch.full((log_probs.size(2) + 1,), float("-inf"), device=log_probs.device)
log_sum_probs[1:-1] = safe_logsumexp(safe_logsumexp(log_probs[:,:,1:], dim=0), dim=0) # Sum over batch and time
log_sum_probs[0] = safe_logsumexp(log_probs[:,0,0], dim=0) # BOS prob
log_sum_probs[-1] = safe_logsumexp(safe_logsumexp(log_probs[:,1:,0], dim=0), dim=0) # EOS prob
else:
log_sum_probs = safe_logsumexp(safe_logsumexp(log_probs, dim=0), dim=0)

log_mean_probs = log_sum_probs - sum_frames.log()

with torch.no_grad():
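The main change to _calc_log_prior is the use_max branch: instead of averaging the posteriors over all unpadded frames via logsumexp, the prior can now be estimated by counting framewise argmax labels, with empty bins clamped to 1 so the log stays finite. The following is a stripped-down, standalone sketch of the two estimates on dummy data; it uses plain torch.logsumexp instead of the safe_logsumexp helper and ignores the BOS/EOS handling of the separate_eos branch:

```python
import torch

torch.manual_seed(0)
B, T, F = 2, 5, 4                                   # batch, frames, labels (blank at index 0)
log_probs = torch.randn(B, T, F).log_softmax(dim=-1)
lengths = torch.tensor([5, 3])

mask_bool = torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)   # (B, T), True for real frames
sum_frames = lengths.sum()

# Averaged ("soft") prior: logsumexp over batch and time, normalized by the frame count.
masked = log_probs + torch.where(mask_bool, 0.0, float("-inf")).unsqueeze(-1)
log_prior_avg = masked.logsumexp(dim=0).logsumexp(dim=0) - sum_frames.float().log()

# Max-based prior: count framewise argmax labels, clamp empty bins to 1 (as in the diff).
argmaxs = log_probs.argmax(dim=-1)[mask_bool]       # one label id per unpadded frame
counts = argmaxs.bincount(minlength=F)
total = sum_frames + (counts == 0).sum()            # clamped bins enlarge the denominator
counts = torch.where(counts == 0, torch.ones_like(counts), counts)
log_prior_max = counts.float().log() - total.float().log()

print(log_prior_avg.exp())   # both are proper distributions over the F labels
print(log_prior_max.exp())
```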
@@ -1047,6 +1085,8 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
horizontal_prior = config.bool("horizontal_prior", True)
blank_prior = config.bool("blank_prior", True)
prior_gradient = config.bool("prior_gradient", True)
empirical_prior = config.typed_value("empirical_prior", None)
max_prior = config.bool("max_prior", False)
top_k = config.int("top_k", 0)
use_prior = prior_scale > 0.0
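top_k is only read here and passed through to sum_loss in sum_criterion.py, which is not part of this diff, so the exact pruning is not visible; the commit title "Alignment level topk" suggests the sum over alignments is restricted to the k best labels per frame. The sketch below shows one generic way such a frame-level restriction can look, purely as an assumption about what the option does, not the actual implementation:

```python
import torch

def restrict_to_top_k(log_probs: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k highest-scoring labels per frame, set all others to -inf.

    log_probs: (T, B, F) framewise log posteriors. Illustrative stand-in for
    whatever sum_loss does with its top_k argument.
    """
    if k <= 0 or k >= log_probs.size(-1):
        return log_probs
    topk_vals, topk_idx = log_probs.topk(k, dim=-1)
    pruned = torch.full_like(log_probs, float("-inf"))
    return pruned.scatter(-1, topk_idx, topk_vals)

T, B, F = 6, 2, 10
log_probs = torch.randn(T, B, F).log_softmax(dim=-1)
pruned = restrict_to_top_k(log_probs, k=3)
print(torch.isfinite(pruned).sum(dim=-1))  # 3 surviving labels per frame
```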

@@ -1066,6 +1106,7 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:

collected_outputs = {}
logits, enc, enc_spatial_dim = model(data, in_spatial_dim=data_spatial_dim, collected_outputs=collected_outputs)

if aux_loss_layers:
for i, layer_idx in enumerate(aux_loss_layers):
if layer_idx > len(model.encoder.layers):
@@ -1075,9 +1116,14 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
aux_log_probs = model.log_probs_wb_from_logits(aux_logits)
aux_log_probs = aux_log_probs.raw_tensor
if use_prior:
aux_log_prior = _calc_log_prior(aux_log_probs, enc_spatial_dim.dyn_size_ext.raw_tensor)
if not prior_gradient:
aux_log_prior = aux_log_prior.detach()
if empirical_prior is not None:
aux_log_prior = np.loadtxt(empirical_prior, dtype="float32")
aux_log_prior = torch.tensor(aux_log_prior, device=aux_log_probs.device)
assert aux_log_prior.size(0) == aux_log_probs.size(2), "Empirical prior size does not match (full_sum)!"
else:
aux_log_prior = _calc_log_prior(aux_log_probs, enc_spatial_dim.dyn_size_ext.raw_tensor, use_max=max_prior)
if not prior_gradient:
aux_log_prior = aux_log_prior.detach()
else:
aux_log_prior = None
# (B, T, F) -> (T, B, F)
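When empirical_prior is set in the full-sum config, the prior is not estimated from the batch at all but loaded from a file. From the way it is used here, the file is presumably a plain text vector with one log probability per output label (including blank), so that its length matches the label dimension of the posteriors. A tiny sketch of writing and reading such a file under that assumption (the file name and label count are placeholders):

```python
import numpy as np
import torch

num_labels = 185  # placeholder: vocab size incl. blank
# Assumed format: whitespace/newline-separated log probabilities, one per label.
uniform_log_prior = np.log(np.full(num_labels, 1.0 / num_labels, dtype="float32"))
np.savetxt("empirical_prior.txt", uniform_log_prior)

loaded = np.loadtxt("empirical_prior.txt", dtype="float32")
log_prior = torch.tensor(loaded)
assert log_prior.size(0) == num_labels
```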
@@ -1106,13 +1152,32 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
custom_inv_norm_factor=enc_spatial_dim.get_size_tensor(),
use_normalized_loss=use_normalized_loss,
)

fixed_seqs = ["train-other-500/5756-305214-0041/5756-305214-0041"] # MONICA DREW FRESH HOPE FROM HER SON'S WRITINGS THEY WERE FULL OF NOBLE THOUGHTS AND HIGH ASPIRATIONS
print_for_idx = []

seq_tags = seq_tags.raw_tensor
for seq in fixed_seqs:
if seq in seq_tags:
idx = np.where(seq_tags == seq)[0]
print("Found seq", seq, enc_spatial_dim.dyn_size_ext.raw_tensor[idx])
print_for_idx.append(idx[0])

# seq = seq_tags[0]
# idx = np.where(seq_tags == seq)[0]
# print_for_idx.append(idx[0])

log_probs = model.log_probs_wb_from_logits(logits)
log_probs = log_probs.raw_tensor
if use_prior:
log_prior = _calc_log_prior(log_probs, enc_spatial_dim.dyn_size_ext.raw_tensor)
if not prior_gradient:
log_prior = log_prior.detach()
if empirical_prior is not None:
log_prior = np.loadtxt(empirical_prior, dtype="float32")
log_prior = torch.tensor(log_prior, device=log_probs.device)
assert log_prior.size(0) == log_probs.size(2), "Empirical prior size does not match (full_sum)!"
else:
log_prior = _calc_log_prior(log_probs, enc_spatial_dim.dyn_size_ext.raw_tensor, use_max=max_prior)
if not prior_gradient:
log_prior = log_prior.detach()
else:
log_prior = None
# (B, T, F) -> (T, B, F)
Expand All @@ -1132,7 +1197,8 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Ten
blank_idx=model.blank_idx,
eos_idx=model.eos_idx,
unk_idx=1,
device=log_probs.device
device=log_probs.device,
print_best_path_for_idx=print_for_idx
)
loss = rtf.TorchBackend.convert_to_tensor(loss, dims = [batch_dim], dtype = "float32", name=f"full_sum")
loss.mark_as_loss(
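For orientation, the quantity that sum_loss presumably computes from these inputs (framewise AM log posteriors, an n-gram LM and a label prior, each with its own scale) is a sum over CTC alignments rather than a loss against fixed pseudo labels. A hedged sketch of that objective, with blank/EOS handling and the top-k restriction left to sum_criterion.py:

```latex
% Assumed form of the full-sum objective; y(a) is the label sequence obtained by
% collapsing the alignment a, and the scales correspond to am_scale, lm_scale, prior_scale.
\mathcal{L}(x) = -\log \sum_{a} \exp\Big(
      \lambda_{\mathrm{am}}    \log p_{\mathrm{AM}}(a \mid x)
    + \lambda_{\mathrm{lm}}    \log p_{\mathrm{LM}}\big(y(a)\big)
    - \lambda_{\mathrm{prior}} \log p_{\mathrm{prior}}(a) \Big)
```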