Skip to content

Commit

Permalink
add Conv2dGT & CIF
Browse files (browse the repository at this point in the history)
  • Loading branch information
vectominist committed Dec 6, 2022
1 parent 25dbbf5 commit df5d672
Show file tree
Hide file tree
Showing 10 changed files with 709 additions and 20 deletions.
20 changes: 13 additions & 7 deletions egs/timit/config/ctc_train_transformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@ model:
extractor:
name: fbank
train: false
# cnn:
# hid_dim: 64
cnn:
hid_dim: 64
out_dim: 256
# cif:
# hid_dim: 256
# threshold: 1.0
# downsample: 8.0
# calc_weight: sim-l2
# window_size: 5
# cnngt:
# hid_dim: 32
# out_dim: 256
cif:
hid_dim: 256
threshold: 1.0
downsample: 8.0
# downsample: 4.0
encoder:
module: conformer
d_model: 256
Expand Down Expand Up @@ -68,7 +74,7 @@ trainer:
gpus: 1
precision: 16
logger: wandb
default_root_dir: model/ctc_conf-cif_timit_phone_9
default_root_dir: model/ctc_conf-cnngt_timit_phone_1
deterministic: false

wandb:
Expand Down
6 changes: 4 additions & 2 deletions egs/timit/get_sample.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

. ./path.sh || exit 1;

name=ctc_conf-cif-inn-w5_timit_phone_1

get_sample.py \
/data/sls/r/u/hengjui/home/scratch/dataset/miniasr_data/timit_phone/dev/data_list_sorted.json \
model/ctc_conf-cif_timit_phone_7/epoch=129-step=18720.ckpt \
model/ctc_conf-cif_timit_phone_7/samples \
model/$name/epoch=304-step=43920.ckpt \
model/$name/samples \
-i 0
7 changes: 7 additions & 0 deletions miniasr/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import logging
from collections import defaultdict
from functools import partial

import torch
Expand Down Expand Up @@ -34,6 +35,7 @@ def audio_collate_fn(data_list, mode="train"):

waves, wave_len = [], []
texts, text_len = [], []
other = defaultdict(list)
for data in data_list:
# Load raw waveform
waves.append(load_waveform(data["file"], TARGET_SR))
Expand All @@ -46,6 +48,10 @@ def audio_collate_fn(data_list, mode="train"):
texts.append(data.get("text", ""))
text_len.append(0)

for k, v in data.items():
if k not in {"file", "text"}:
other[k].append(v)

waves = pad_sequence(waves, batch_first=True)
wave_len = torch.LongTensor(wave_len)

Expand All @@ -59,6 +65,7 @@ def audio_collate_fn(data_list, mode="train"):
"wave_len": wave_len,
"text": texts,
"text_len": text_len,
"other": other,
}


Expand Down
7 changes: 7 additions & 0 deletions miniasr/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ def __init__(self, paths, tokenizer: _BaseTextEncoder, mode="train", max_len=160

self.data_list = [d for d in data_list if len(d.get(trans_key, [0])) > 0]

for key in ["align_phone", "align_word"]:
if key in self.data_list[0]:
for i, d in enumerate(self.data_list):
self.data_list[i][key] = [
int(float(t) / 160) for (_, t, _) in d[key][:-1]
]

logging.info(
f"{len(self.data_list)} audio files found " f"(mode = {self.mode})"
)
Expand Down
6 changes: 3 additions & 3 deletions miniasr/model/base_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def extract_features(self, wave, wave_len):

return feat, feat_len

def forward(self, wave, wave_len):
def forward(self, wave, wave_len, **kwargs):
"""Forward function to compute logits."""
raise NotImplementedError
# Should return logits, enc_len, feat, feat_len
Expand All @@ -146,7 +146,7 @@ def training_step(self, batch, batch_idx):
wave_len, text_len = batch["wave_len"], batch["text_len"]

# Compute logits
logits, enc_len, feat, feat_len, other = self(wave, wave_len)
logits, enc_len, feat, feat_len, other = self(**batch)

# Compute loss
loss = self.cal_loss(logits, enc_len, feat, feat_len, text, text_len)
Expand Down Expand Up @@ -187,7 +187,7 @@ def validation_step(self, batch, batch_idx):

with torch.no_grad():
# Compute logits
logits, enc_len, feat, feat_len, other = self(wave, wave_len)
logits, enc_len, feat, feat_len, other = self(**batch)

# Compute loss
loss = self.cal_loss(logits, enc_len, feat, feat_len, text, text_len)
Expand Down
20 changes: 15 additions & 5 deletions miniasr/model/ctc_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from miniasr.module import (
DownsampleCIF,
DownsampleConv2d,
DownsampleConv2dGT,
RNNEncoder,
TransformerEncoder,
)
Expand All @@ -30,13 +31,20 @@ def __init__(self, tokenizer, args: EasyDict):

# Conv Layer
hid_dim = self.in_dim
self.cnn = None
self.cnn, self.cif = None, None
self.conv_type = ""
if self.args.model.get("cnn", None) is not None:
self.conv_type = "cnn"
self.cnn = DownsampleConv2d(self.in_dim, **args.model.cnn)
hid_dim = self.cnn.out_dim
elif self.args.model.get("cif", None) is not None:
self.conv_type = "cif"
self.cif = DownsampleCIF(self.in_dim, **args.model.cif)
hid_dim = self.cif.out_dim
elif self.args.model.get("cnngt", None) is not None:
self.conv_type = "cnngt"
self.cnn = DownsampleConv2dGT(self.in_dim, **args.model.cnngt)
hid_dim = self.cnn.out_dim

# Encoder Layer
if self.args.model.encoder.module in {"RNN", "GRU", "LSTM"}:
Expand Down Expand Up @@ -128,14 +136,14 @@ def setup_flashlight(self):
f"Word score {self.args.decode.word_score}"
)

def forward(self, wave, wave_len):
def forward(self, wave, wave_len, **kwargs):
"""
Forward function to compute logits.
Input:
wave [list]: list of waveform files
wave_len [long tensor]: waveform lengths
Output:
logtis [float tensor]: Batch x Time x Vocabs
logits [float tensor]: Batch x Time x Vocabs
enc_len [long tensor]: encoded length (logits' lengths)
feat [float tensor]: extracted features
feat_len [long tensor]: length of extracted features
Expand All @@ -147,14 +155,16 @@ def forward(self, wave, wave_len):
feat, feat_len = self.extract_features(wave, wave_len)

# CNN/CIF features
if self.cnn:
if self.conv_type == "cnn":
feat, feat_len = self.cnn(feat, feat_len)
elif self.cif:
if self.conv_type == "cif":
res = self.cif(feat, feat_len)
feat, feat_len = res["x"], res["x_len"]
other["quantity_loss"] = res["loss"]
other["cif_prob"] = res["prob"]
other["cif_indices"] = res["indices"]
if self.conv_type == "cnngt":
feat, feat_len = self.cnn(feat, feat_len, kwargs["other"]["align_phone"])

# Encode features
if self.args.model.encoder.module in {"RNN", "GRU", "LSTM"}:
Expand Down
2 changes: 1 addition & 1 deletion miniasr/module/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .cif import DownsampleCIF
from .cnn import DownsampleConv2d
from .cnn import DownsampleConv2d, DownsampleConv2dGT
from .rnn import RNNEncoder
from .scheduler import create_lambda_lr_warmup
from .transformer.masking import len_to_mask
Expand Down
Loading

0 comments on commit df5d672

Please sign in to comment.