This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Support choosing a unigram or bigram LM for P in LF-MMI training. #218

Open · wants to merge 3 commits into master
22 changes: 20 additions & 2 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
@@ -71,6 +71,7 @@
from snowfall.models.contextnet import ContextNet
from snowfall.training.ctc_graph import build_ctc_topo
from snowfall.training.mmi_graph import create_bigram_phone_lm
from snowfall.training.mmi_graph import create_unigram_phone_lm
from snowfall.training.mmi_graph import get_phone_symbols

def nbest_decoding(lats: k2.Fsa, num_paths: int):
@@ -401,6 +402,15 @@ def get_parser():
type=str2bool,
default=True,
help='When enabled, it uses vgg style network for subsampling')

parser.add_argument(
'--use-unigram-lm',
type=str2bool,
default=False,
help='True to use unigram LM for P. False to use bigram LM for P. '
'This is used only for checkpoint-loading.'
)

return parser


@@ -423,7 +433,10 @@ def main():

output_beam_size = args.output_beam_size

exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
if args.use_unigram_lm:
exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer-unigram')
else:
exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

logging.info(f'output_beam_size: {output_beam_size}')
@@ -434,7 +447,12 @@
phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

phone_ids = get_phone_symbols(phone_symbol_table)
P = create_bigram_phone_lm(phone_ids)
if args.use_unigram_lm:
logging.info('Use unigram LM for P')
P = create_unigram_phone_lm(phone_ids)
else:
logging.info('Use bigram LM for P')
P = create_bigram_phone_lm(phone_ids)

phone_ids_with_blank = [0] + phone_ids
ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
23 changes: 21 additions & 2 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
@@ -44,6 +44,7 @@
from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change
from snowfall.training.mmi_graph import MmiTrainingGraphCompiler
from snowfall.training.mmi_graph import create_bigram_phone_lm
from snowfall.training.mmi_graph import create_unigram_phone_lm


def get_objf(batch: Dict,
@@ -461,6 +462,14 @@ def get_parser():
'so that they can be simply loaded with torch.jit.load(). '
'-1 disables this option.'
)

parser.add_argument(
'--use-unigram-lm',
type=str2bool,
default=False,
help='True to use unigram LM for P. False to use bigram LM for P.'
)

return parser


@@ -487,7 +496,10 @@ def run(rank, world_size, args):
fix_random_seed(42)
setup_dist(rank, world_size, args.master_port)

exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
if args.use_unigram_lm:
exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer-unigram')
else:
exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
setup_logger(f'{exp_dir}/log/log-train-{rank}')
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
@@ -507,7 +519,14 @@
device=device,
)
phone_ids = lexicon.phone_symbols()
P = create_bigram_phone_lm(phone_ids)

if args.use_unigram_lm:
logging.info('Use unigram LM for P')
P = create_unigram_phone_lm(phone_ids)
else:
logging.info('Use bigram LM for P')
P = create_bigram_phone_lm(phone_ids)

P.scores = torch.zeros_like(P.scores)
P = P.to(device)

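Both scripts apply the same pattern: `--use-unigram-lm` selects the builder for P and appends `-unigram` to the experiment directory name, so decoding must be run with the same flag value as training in order to find its checkpoints. Below is a minimal consolidated sketch of that logic; it is not part of this diff, and the helper name `select_p_and_exp_dir` is invented for illustration.

```python
from pathlib import Path
from typing import List, Tuple

import k2

# Builders from snowfall.training.mmi_graph, as imported in both scripts;
# create_unigram_phone_lm is the one added by this PR.
from snowfall.training.mmi_graph import (create_bigram_phone_lm,
                                         create_unigram_phone_lm)


def select_p_and_exp_dir(use_unigram_lm: bool, model_type: str,
                         phone_ids: List[int]) -> Tuple[k2.Fsa, Path]:
    """Hypothetical helper mirroring the flag handling in train and decode."""
    suffix = '-unigram' if use_unigram_lm else ''
    exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer' + suffix)
    if use_unigram_lm:
        P = create_unigram_phone_lm(phone_ids)
    else:
        P = create_bigram_phone_lm(phone_ids)
    return P, exp_dir
```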
28 changes: 27 additions & 1 deletion snowfall/training/mmi_graph.py
@@ -38,7 +38,33 @@ def create_bigram_phone_lm(phones: List[int]) -> k2.Fsa:
rules += f'{i} {j} {phones[j-1]} 0.0\n'
rules += f'{i} {final_state} -1 0.0\n'
rules += f'{final_state}'
return k2.Fsa.from_str(rules)
ans = k2.Fsa.from_str(rules)
return k2.arc_sort(ans)

def create_unigram_phone_lm(phones: List[int]) -> k2.Fsa:
'''Create a unigram phone LM.
The resulting FSA (P) has two states: a start state and a
final state. For each phone, there is a corresponding self-loop
at the start state.

Caution:
blank is not a phone.

Args:
  phones:
    A list of phone IDs.

Returns:
An FSA representing the unigram phone LM.
'''
assert 0 not in phones

rules = '0 1 -1 0.0\n'
for i in phones:
rules += f'0 0 {i} 0.0\n'
rules += '1\n'

ans = k2.Fsa.from_str(rules)
return k2.arc_sort(ans)


class MmiTrainingGraphCompiler(object):
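To make the shape of the new P concrete, here is a small usage sketch (not part of this diff; it assumes k2 and snowfall are installed and uses a toy phone list). For three phones the resulting FSA has two states and four arcs: one self-loop per phone at the start state plus the arc labelled -1 into the final state.

```python
import k2
import torch

from snowfall.training.mmi_graph import create_unigram_phone_lm

# Toy phone IDs chosen only for illustration; 0 (blank) must not appear here.
phone_ids = [1, 2, 3]

P = create_unigram_phone_lm(phone_ids)
print(P.num_arcs)  # 4: three self-loops at state 0 plus the -1 arc to state 1
print(P)           # two states: 0 (start, with the self-loops) and 1 (final)

# As in the training script, the scores can then be zeroed and P moved to a device.
P.scores = torch.zeros_like(P.scores)
```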