
Commit

decoder simplified/dropout added, remove long text
Alexander-H-Liu committed Sep 27, 2019
1 parent 57cf91d commit 28d8021
Showing 7 changed files with 52 additions and 87 deletions.
11 changes: 6 additions & 5 deletions bin/train_asr.py
@@ -40,7 +40,7 @@ def set_model(self):
         model_paras = [{'params':self.model.parameters()}]

         # Losses
-        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0, reduction='none')
+        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
         self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False) # Note: zero_infinity=False is unstable?

         # Plug-ins
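
The zero_infinity flag on the CTC loss matters because an utterance whose encoded length is too short for its transcript has no valid CTC alignment and yields an infinite loss. A minimal standalone sketch (toy shapes, not part of this commit) of the call signature and that failure mode:

import torch

# Toy shapes: T=50 encoder frames, N=2 utterances, C=30 token classes (blank=0)
log_probs = torch.randn(50, 2, 30).log_softmax(dim=-1)    # (T, N, C) log-probabilities
targets = torch.randint(1, 30, (2, 10))                    # target ids, 0 is reserved for blank
input_lengths = torch.tensor([50, 3])                      # 2nd uttr: 3 frames for a 10-token target
target_lengths = torch.tensor([10, 10])

ctc = torch.nn.CTCLoss(blank=0, zero_infinity=False)       # as in the commit: keep inf losses
print(ctc(log_probs, targets, input_lengths, target_lengths))   # inf leaks into the batch mean
# zero_infinity=True would instead zero out such losses (and their gradients)
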
@@ -52,7 +52,7 @@ def set_model(self):
             model_paras.append({'params':self.emb_decoder.parameters()})
             self.emb_fuse = self.emb_decoder.apply_fuse
             if self.emb_fuse:
-                self.seq_loss = torch.nn.NLLLoss(ignore_index=0, reduction='none')
+                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
             self.verbose(self.emb_decoder.create_msg())

         # Optimizer
@@ -119,7 +119,7 @@ def exec(self):
                     att_output = fuse_output if self.emb_fuse else att_output
                     att_loss = self.seq_loss(att_output.view(b*t,-1),txt.view(-1))
                     # Sum each uttr and devide by length then mean over batch
-                    att_loss = torch.mean(torch.sum(att_loss.view(b,t),dim=-1)/torch.sum(txt!=0,dim=-1).float())
+                    # att_loss = torch.mean(torch.sum(att_loss.view(b,t),dim=-1)/torch.sum(txt!=0,dim=-1).float())
                     total_loss += att_loss*(1-self.model.ctc_weight)

                 self.timer.cnt('fw')
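
Dropping reduction='none' works because CrossEntropyLoss's default reduction='mean' already averages over the non-ignored (non-padding) targets, which is why the manual per-utterance normalization is now commented out. A toy comparison of the two reductions (made-up tensors, not part of this commit):

import torch

# Toy batch: b=2 utterances, t=3 decode steps, 0 = padding
logits = torch.randn(2, 3, 5)                        # (b, t, vocab)
txt = torch.tensor([[4, 2, 0],                       # uttr 1: 2 real tokens + 1 pad
                    [1, 3, 2]])                      # uttr 2: 3 real tokens
b, t = txt.shape

# New behaviour: default reduction='mean' averages over all non-ignored tokens at once
mean_loss = torch.nn.CrossEntropyLoss(ignore_index=0)(logits.view(b*t, -1), txt.view(-1))

# Old behaviour: per-token losses, per-utterance average, then mean over the batch
per_tok = torch.nn.CrossEntropyLoss(ignore_index=0, reduction='none')(logits.view(b*t, -1), txt.view(-1))
old_loss = torch.mean(torch.sum(per_tok.view(b, t), dim=-1) / torch.sum(txt != 0, dim=-1).float())

# Equal only when every utterance has the same length; otherwise the new loss
# weights every token equally while the old one weighted every utterance equally.
print(mean_loss.item(), old_loss.item())
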
@@ -150,7 +150,8 @@ def exec(self):
                 self.timer.set()
                 if self.step > self.max_step:break
             n_epochs +=1
-
+        self.log.writer.close()
+
     def validate(self):
         # Eval mode
         self.model.eval()
@@ -174,7 +175,7 @@ def validate(self):
             # Show some example on tensorboard
             if i == len(self.dv_set)//2:
                 for i in range(min(len(txt),self.DEV_N_EXAMPLE)):
-                    if self.step ==0:
+                    if self.step==1:
                         self.write_log('true_text{}'.format(i),self.tokenizer.decode(txt[i].tolist()))
                     if att_output is not None:
                         self.write_log('att_align{}'.format(i),feat_to_fig(att_align[i,0,:,:].cpu().detach()))
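
For reference, a hypothetical sketch of the underlying torch.utils.tensorboard calls a write_log() helper like the one above would forward to (the repo's own wrapper may differ); the 'HWC' format string matches what feat_to_fig in src/util.py returns:

from torch.utils.tensorboard import SummaryWriter
import torch

writer = SummaryWriter('log/demo')
step = 1
writer.add_text('true_text0', 'hello world', global_step=step)             # reference transcript, logged once
att_map = torch.rand(80, 300, 3)                                            # (H, W, C) attention image in [0, 1]
writer.add_image('att_align0', att_map, global_step=step, dataformats='HWC')
writer.flush()
writer.close()
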
3 changes: 3 additions & 0 deletions corpus/librispeech.py
@@ -5,6 +5,7 @@
 from torch.utils.data import Dataset

 OFFICIAL_TXT_SRC = ['librispeech-lm-norm.txt'] # Additional (official) text src provided
+REMOVE_TOP_N_TXT = 5000000 # Remove longest N sentence in librispeech-lm-norm.txt
 READ_FILE_THREADS = 4 # Default num. of threads used for loading LibriSpeech

 def read_text(file):
@@ -87,6 +88,8 @@ def __init__(self, path, split, tokenizer, bucket_size):

         # Read file size and sort dataset by file size (Note: feature len. may be different)
         self.text = sorted(self.text, reverse=True, key=lambda x:len(x))
+        if self.encode_on_fly:
+            del self.text[:REMOVE_TOP_N_TXT]

     def __getitem__(self,index):
         if self.bucket_size>1:
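
A toy illustration (not repo code) of what the two added lines do: with the sentences already sorted longest-first, del text[:N] discards the N longest LM sentences, so the extremely long lines of librispeech-lm-norm.txt are skipped when tokens are encoded on the fly:

sentences = ["a much much longer sentence", "a mid length one", "short", "tiny"]
REMOVE_TOP_N = 2

sentences = sorted(sentences, reverse=True, key=lambda x: len(x))
del sentences[:REMOVE_TOP_N]          # drops the 2 longest entries, in place
print(sentences)                      # ['short', 'tiny']
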
111 changes: 37 additions & 74 deletions src/asr.py
@@ -10,7 +10,7 @@

 class ASR(nn.Module):
     ''' ASR model, including Encoder/Decoder(s)'''
-    def __init__(self, input_size, vocab_size, ctc_weight, encoder, attention, decoder):
+    def __init__(self, input_size, vocab_size, ctc_weight, encoder, attention, decoder, emb_drop=0.0):
         super(ASR, self).__init__()

         # Setup
@@ -28,13 +28,16 @@ def __init__(self, input_size, vocab_size, ctc_weight, encoder, attention, decod
         if self.enable_att:
             self.dec_dim = decoder['dim']
             self.pre_embed = nn.Embedding(vocab_size, self.dec_dim)
+            self.embed_drop = nn.Dropout(emb_drop)
             self.decoder = Decoder(self.encoder.out_dim+self.dec_dim, vocab_size, **decoder)
-            self.attention = Attention(self.encoder.out_dim, self.dec_dim, **attention)
+            query_dim = self.dec_dim*self.decoder.layer
+            self.attention = Attention(self.encoder.out_dim, query_dim, **attention)

         # Init
         self.apply(init_weights)
-        for l in self.decoder.layers:
-            l.bias_ih = init_gate(l.bias_ih)
+        for l in range(self.decoder.layer):
+            bias = getattr(self.decoder.layers,'bias_ih_l{}'.format(l))
+            bias = init_gate(bias)

     def set_state(self, prev_state, prev_attn):
         ''' Setting up all memory states for beam decoding'''
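
Two PyTorch details this hunk leans on, in a small standalone sketch (toy sizes, not from the repo): a multi-layer nn.LSTM exposes its per-layer input biases as bias_ih_l{k}, and an attention query that spans all layers is just the (num_layers, batch, dim) hidden state flattened to (batch, dim*num_layers), matching the new query_dim:

import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=320, hidden_size=256, num_layers=2, batch_first=True)
for l in range(rnn.num_layers):
    bias = getattr(rnn, 'bias_ih_l{}'.format(l))   # shape (4*hidden,), gate order i, f, g, o
    print(bias.shape)                               # torch.Size([1024])

h = torch.zeros(2, 3, 256)                          # (num_layers, batch, dim) hidden state
query = h.transpose(0, 1).reshape(-1, 256 * 2)      # (batch, dim*num_layers)
print(query.shape)                                  # torch.Size([3, 512])

# Note: getattr returns the Parameter itself, so init_gate presumably edits the
# forget-gate slice in place; the local name 'bias' is only a handle to it.
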
@@ -82,14 +85,14 @@ def forward(self, audio_feature, feature_len, decode_step, tf_rate=0.0, teacher=
         # Attention based decoding
         if self.enable_att:
             # Init (init char = <SOS>, reset all rnn state and cell)
-            self.decoder.init_state(encode_feature)
+            self.decoder.init_state(bs)
             self.attention.reset_mem()
             last_char = self.pre_embed(torch.zeros((bs),dtype=torch.long, device=encode_feature.device))
             att_seq, output_seq = [], []

             # Preprocess data for teacher forcing
             if teacher is not None:
-                teacher = self.pre_embed(teacher)
+                teacher = self.embed_drop(self.pre_embed(teacher))

             # Decode
             for t in range(decode_step):
@@ -112,7 +115,7 @@ def forward(self, audio_feature, feature_len, decode_step, tf_rate=0.0, teacher=
                             else:
                                 cur_prob = cur_char.softmax(dim=-1)
                             sampled_char = Categorical(cur_prob).sample()
-                        last_char = self.pre_embed(sampled_char)
+                        last_char = self.embed_drop(self.pre_embed(sampled_char))
                 else:
                     # Inference stage
                     if (emb_decoder is not None) and emb_decoder.apply_fuse:
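
A hypothetical standalone sketch (toy sizes, not part of this commit) of one scheduled-sampling step like the branch above: with probability tf_rate the ground-truth embedding is fed back, otherwise the embedding of a token sampled from the model's own output, and the new embedding dropout is applied either way:

import torch
from torch.distributions import Categorical

vocab, dim, bs, tf_rate = 30, 8, 4, 0.9
pre_embed = torch.nn.Embedding(vocab, dim)
embed_drop = torch.nn.Dropout(0.1)
cur_char = torch.randn(bs, vocab)                 # decoder logits at step t
teacher_t = torch.randint(0, vocab, (bs,))        # ground-truth token ids at step t

if torch.rand(1).item() <= tf_rate:
    last_char = embed_drop(pre_embed(teacher_t))                      # teacher forcing
else:
    sampled_char = Categorical(cur_char.softmax(dim=-1)).sample()     # (bs,)
    last_char = embed_drop(pre_embed(sampled_char))                   # self-sampling
print(last_char.shape)                             # torch.Size([4, 8])
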
@@ -137,100 +140,60 @@ def forward(self, audio_feature, feature_len, decode_step, tf_rate=0.0, teacher=
 class Decoder(nn.Module):
     ''' Decoder (a.k.a. Speller in LAS) '''
     # ToDo: More elegant way to implement decoder
-    def __init__(self, input_dim, vocab_size, module, dim, layer, dropout, layer_norm):
+    def __init__(self, input_dim, vocab_size, module, dim, layer, dropout):
         super(Decoder, self).__init__()
         self.in_dim = input_dim
         self.layer = layer
         self.dim = dim
         self.dropout = dropout
-        self.layer_norm = layer_norm

         # Init
-        self.module = module+'Cell'
-        self.state_list = []
-        self.enable_cell = False
-        if module == 'LSTM':
-            self.enable_cell = True
-            self.cell_list = []
-        elif module not in ['LSTM','GRU']:
-            raise NotImplementedError
+        assert module in ['LSTM','GRU'], NotImplementedError
+        self.hidden_state = None
+        self.enable_cell = module=='LSTM'

         # Modules
-        module_list = []
-        in_dim = input_dim
-        for i in range(layer):
-            module_list.append(getattr(nn,self.module)(in_dim,dim))
-            in_dim = dim
-
-        # Regularization
-        if self.layer_norm:
-            self.ln_list = nn.ModuleList([nn.LayerNorm(dim) for l in range(layer)])
-        if self.dropout > 0:
-            self.dp = nn.Dropout(self.dropout)
-
-        self.layers = nn.ModuleList(module_list)
+        self.layers = getattr(nn,module)(input_dim,dim, num_layers=layer, dropout=dropout, batch_first=True)
         self.char_trans = nn.Linear(dim,vocab_size)
+        self.final_dropout = nn.Dropout(dropout)


-    def init_state(self, context):
+    def init_state(self, bs):
         ''' Set all hidden states to zeros '''
-        self.state_list = [torch.zeros((context.shape[0],self.dim),device=context.device)]*self.layer
+        device = next(self.parameters()).device
         if self.enable_cell:
-            self.cell_list = [torch.zeros((context.shape[0],self.dim),device=context.device)]*self.layer
-            return self.state_list, self.cell_list
-        return self.state_list
+            self.hidden_state = (torch.zeros((self.layer,bs,self.dim),device=device),
+                                 torch.zeros((self.layer,bs,self.dim),device=device))
+        else:
+            self.hidden_state = torch.zeros((self.layer,bs,self.dim),device=device)
+        return self.get_state()

-    def set_state(self, state_list):
+    def set_state(self, hidden_state):
         ''' Set all hidden states/cells, for decoding purpose'''
         device = next(self.parameters()).device
         if self.enable_cell:
-            state_list, cell_list = state_list
-            self.cell_list = [c.to(device) for c in cell_list]
-        self.state_list = [s.to(device) for s in state_list]
+            self.hidden_state = (hidden_state[0].to(device),hidden_state[1].to(device))
+        else:
+            self.hidden_state = hidden_state.to(device)

     def get_state(self):
         ''' Return all hidden states/cells, for decoding purpose'''
-        state_list = [s.cpu() for s in self.state_list]
         if self.enable_cell:
-            cell_list = [c.cpu() for c in self.cell_list]
-            return state_list, cell_list
-        return state_list
-
-    def get_query(self):
-        ''' Return state of layer 0 as query for attention '''
-        return self.state_list[0]
-
-    def _get_layer_state(self, layer_idx):
-        ''' Get hidden state of specified layer '''
-        if self.enable_cell:
-            return (self.state_list[layer_idx],self.cell_list[layer_idx])
+            return (self.hidden_state[0].cpu(),self.hidden_state[1].cpu())
         else:
-            return self.state_list[layer_idx]
+            return self.hidden_state.cpu()

-    def _store_layer_state(self, layer_idx, state):
-        ''' Replace hidden state of specified layer '''
+    def get_query(self):
+        ''' Return state of all layers as query for attention '''
         if self.enable_cell:
-            self.state_list[layer_idx] = state[0]
-            self.cell_list[layer_idx] = state[1]
-            return state[0]
+            return self.hidden_state[0].transpose(0,1).reshape(-1,self.dim*self.layer)
         else:
-            self.state_list[layer_idx] = state
-            return state
+            return self.hidden_state.transpose(0,1).reshape(-1,self.dim*self.layer)

     def forward(self, x):
-        ''' Manually forward through all layers '''
-        for i, layers in enumerate(self.layers):
-            state = self._get_layer_state(i)
-            x = layers(x,state)
-            x = self._store_layer_state(i,x)
-
-            if self.layer_norm:
-                x = self.ln_list[i](x)
-            if self.dropout > 0:
-                x = self.dp(x)
-
-        char = self.char_trans(x)
-
+        ''' Decode and transform into vocab '''
+        x, self.hidden_state = self.layers(x.unsqueeze(1),self.hidden_state)
+        x = x.squeeze(1)
+        char = self.char_trans(self.final_dropout(x))
         return char, x
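
A standalone sketch (toy sizes, not repo code) of driving the simplified decoder one token per step: a single nn.LSTM(num_layers=...) replaces the manual per-layer *Cell loop, its (h, c) tuple carries every layer's state across steps, and the attention query is that state flattened across layers:

import torch
import torch.nn as nn

layer, dim, vocab, bs, in_dim = 2, 256, 30, 4, 512
rnn = nn.LSTM(in_dim, dim, num_layers=layer, dropout=0.0, batch_first=True)
char_trans = nn.Linear(dim, vocab)

hidden = (torch.zeros(layer, bs, dim), torch.zeros(layer, bs, dim))   # init_state(bs)
x = torch.randn(bs, in_dim)                        # [attention context ; last char embedding]

x, hidden = rnn(x.unsqueeze(1), hidden)            # add a length-1 time axis, step once
x = x.squeeze(1)                                   # back to (bs, dim)
char = char_trans(x)                               # logits over the vocabulary
query = hidden[0].transpose(0, 1).reshape(bs, -1)  # (bs, dim*layer) attention query
print(char.shape, query.shape)                     # torch.Size([4, 30]) torch.Size([4, 512])
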


8 changes: 3 additions & 5 deletions src/data.py
@@ -5,7 +5,6 @@
 from torch.utils.data import DataLoader
 from torch.nn.utils.rnn import pad_sequence

-DEV_N_JOBS = 2 # Number of threads used for dev set
 HALF_BATCHSIZE_AUDIO_LEN = 800 # Batch size will be halfed if the longest wavefile surpasses threshold
 # Note: Bucketing may cause random sampling to be biased (less sampled for those length > HALF_BATCHSIZE_AUDIO_LEN )
 HALF_BATCHSIZE_TEXT_LEN = 150
@@ -131,12 +130,11 @@ def load_dataset(n_jobs, use_gpu, pin_memory, ascending, corpus, audio, text):
     # Shuffle/drop applied to training set only
     shuffle = (mode=='train' and not ascending)
     drop_last = shuffle
-    num_workers = max(0,n_jobs-DEV_N_JOBS) if mode=='train' else DEV_N_JOBS
     # Create data loader
     tr_set = DataLoader(tr_set, batch_size=tr_loader_bs, shuffle=shuffle, drop_last=drop_last, collate_fn=collect_tr,
-                        num_workers=num_workers, pin_memory=use_gpu)
+                        num_workers=n_jobs, pin_memory=use_gpu)
     dv_set = DataLoader(dv_set, batch_size=dv_loader_bs, shuffle=False, drop_last=False, collate_fn=collect_dv,
-                        num_workers=DEV_N_JOBS, pin_memory=pin_memory)
+                        num_workers=n_jobs, pin_memory=pin_memory)
     # Messages to show
     data_msg.append('I/O spec. | Audio feature = {}\t| feature dim = {}\t| Token type = {}\t| Vocab size = {}'\
                     .format(audio['feat_type'],feat_dim,tokenizer.token_type,tokenizer.vocab_size))
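
A toy sketch (not repo code) of the loader setup after this change: both loaders now receive the full n_jobs worker count instead of splitting it with the removed DEV_N_JOBS constant:

import torch
from torch.utils.data import DataLoader, TensorDataset

n_jobs, use_gpu = 4, torch.cuda.is_available()
dummy = TensorDataset(torch.randn(32, 10), torch.randint(0, 5, (32,)))

tr_loader = DataLoader(dummy, batch_size=8, shuffle=True, drop_last=True,
                       num_workers=n_jobs, pin_memory=use_gpu)
dv_loader = DataLoader(dummy, batch_size=8, shuffle=False, drop_last=False,
                       num_workers=n_jobs, pin_memory=use_gpu)
for feat, label in tr_loader:
    print(feat.shape, label.shape)   # torch.Size([8, 10]) torch.Size([8])
    break
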
@@ -156,7 +154,7 @@ def load_textset(n_jobs, use_gpu, pin_memory, corpus, text):
     tr_set = DataLoader(tr_set, batch_size=tr_loader_bs, shuffle=True, drop_last=True, collate_fn=collect_tr,
                         num_workers=0, pin_memory=use_gpu)
     dv_set = DataLoader(dv_set, batch_size=dv_loader_bs, shuffle=False, drop_last=False, collate_fn=collect_dv,
-                        num_workers=DEV_N_JOBS, pin_memory=pin_memory)
+                        num_workers=0, pin_memory=pin_memory)

     # Messages to show
     data_msg.append('I/O spec. | Token type = {}\t| Vocab size = {}'\
2 changes: 1 addition & 1 deletion src/decode.py
@@ -61,7 +61,7 @@ def forward(self, audio_feature, feature_len):
         assert audio_feature.shape[0]==1, "Batchsize == 1 is required for beam search"
         batch_size = audio_feature.shape[0]
         device = audio_feature.device
-        dec_state = self.asr.decoder.init_state(audio_feature.cpu()) # Init zero states
+        dec_state = self.asr.decoder.init_state(batch_size) # Init zero states
         max_output_len = int(np.ceil(feature_len.cpu().item()*self.max_len_ratio)) # Max output len set w/ hyper param.
         min_output_len = int(np.ceil(feature_len.cpu().item()*self.min_len_ratio)) # Min output len set w/ hyper param.
         store_att = self.asr.attention.mode == 'loc' # Store attention map if location-aware
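
A toy calculation (made-up numbers, not the repo's defaults) of how the two length-ratio hyper-parameters bound the number of beam-search decode steps from the encoded feature length:

import numpy as np

feature_len, max_len_ratio, min_len_ratio = 612, 1.2, 0.3
max_output_len = int(np.ceil(feature_len * max_len_ratio))   # 735
min_output_len = int(np.ceil(feature_len * min_len_ratio))   # 184
print(max_output_len, min_output_len)
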
2 changes: 1 addition & 1 deletion src/option.py
@@ -2,7 +2,7 @@
 default_hparas = {
     'GRAD_CLIP': 5.0, # Grad. clip threshold
     'PROGRESS_STEP': 100, # Std. output refresh freq.
-    'DEV_STEP_RATIO': 1.2, # Decode steps for objective validation (step = ratio*input_len)
+    'DEV_STEP_RATIO': 1.2, # Decode steps for objective validation (step = ratio*input_txt_len)
     'DEV_N_EXAMPLE': 4, # Number of examples (alignment/text) to show in tensorboard
     'TB_FLUSH_FREQ': 180 # Update frequency of tensorboard (secs)
 }
2 changes: 1 addition & 1 deletion src/util.py
@@ -76,7 +76,7 @@ def feat_to_fig(feat):
     return torch.FloatTensor(data),"HWC"

 def _save_canvas(data, meta=None):
-    fig, ax = plt.subplots(figsize=(16, 10))
+    fig, ax = plt.subplots(figsize=(16, 8))
     if meta is None:
         ax.imshow(data, aspect="auto", origin="lower")
     else:
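
One common recipe (an assumption, not necessarily the repo's exact code) for turning such a matplotlib canvas into the (H, W, C) float tensor and "HWC" format string that feat_to_fig returns for TensorBoard's add_image:

import numpy as np
import matplotlib
matplotlib.use('Agg')                      # render off-screen
import matplotlib.pyplot as plt
import torch

def feat_to_hwc(feat):
    fig, ax = plt.subplots(figsize=(16, 8))
    ax.imshow(feat, aspect="auto", origin="lower")
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))  # (H, W, 3)
    plt.close(fig)
    return torch.FloatTensor(data.copy()), "HWC"

img, fmt = feat_to_hwc(np.random.rand(80, 300))
print(img.shape, fmt)
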
