
Commit a016617

Merge pull request #165 from alibaba-damo-academy/dev_cmz
punctuation: add training code, support large dataset
2 parents be7230f + ee06cb9 commit a016617

16 files changed: +1043 additions, −46 deletions


funasr/bin/punc_train.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
#!/usr/bin/env python3
import os
from funasr.tasks.punctuation import PunctuationTask


def parse_args():
    parser = PunctuationTask.get_parser()
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="local gpu id.",
    )
    parser.add_argument(
        "--punc_list",
        type=str,
        default=None,
        help="Punctuation list",
    )
    args = parser.parse_args()
    return args


def main(args=None, cmd=None):
    """
    punc training.
    """
    PunctuationTask.main(args=args, cmd=cmd)


if __name__ == "__main__":
    args = parse_args()

    # setup local gpu_id
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    # DDP settings
    if args.ngpu > 1:
        args.distributed = True
    else:
        args.distributed = False

    main(args=args)

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
#!/usr/bin/env python3
import os
from funasr.tasks.punctuation import PunctuationTask


def parse_args():
    parser = PunctuationTask.get_parser()
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="local gpu id.",
    )
    parser.add_argument(
        "--punc_list",
        type=str,
        default=None,
        help="Punctuation list",
    )
    args = parser.parse_args()
    return args


def main(args=None, cmd=None):
    """
    punc training.
    """
    PunctuationTask.main(args=args, cmd=cmd)


if __name__ == "__main__":
    args = parse_args()

    # setup local gpu_id
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    # DDP settings
    if args.ngpu > 1:
        args.distributed = True
    else:
        args.distributed = False
    assert args.num_worker_count == 1

    main(args=args)

funasr/datasets/large_datasets/build_dataloader.py

Lines changed: 6 additions & 2 deletions
@@ -34,16 +34,20 @@ def load_seg_dict(seg_dict_file):
     return seg_dict

 class ArkDataLoader(AbsIterFactory):
-    def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, mode="train"):
+    def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, punc_dict_file=None, mode="train"):
         symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
         if seg_dict_file is not None:
             seg_dict = load_seg_dict(seg_dict_file)
         else:
             seg_dict = None
+        if punc_dict_file is not None:
+            punc_dict = read_symbol_table(punc_dict_file)
+        else:
+            punc_dict = None
         self.dataset_conf = dataset_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))
         batch_mode = self.dataset_conf.get("batch_mode", "padding")
-        self.dataset = Dataset(data_list, symbol_table, seg_dict,
+        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict,
                                self.dataset_conf, mode=mode, batch_mode=batch_mode)

     def build_iter(self, epoch, shuffle=True):
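
A minimal usage sketch, not part of the commit, showing how the new punc_dict_file argument could be threaded through ArkDataLoader. The file paths are hypothetical and most dataset_conf keys are assumptions; only batch_mode, int_pad_value, and float_pad_value appear elsewhere in this diff.

from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader

# hypothetical paths; only the argument names come from the diff above
loader = ArkDataLoader(
    data_list="data/train/data.list",
    dict_file="data/lang/tokens.txt",            # token vocabulary, read via read_symbol_table
    dataset_conf={"batch_mode": "padding",
                  "int_pad_value": -1,
                  "float_pad_value": 0.0},
    punc_dict_file="data/lang/punc_tokens.txt",  # new: punctuation vocabulary
    mode="train",
)
iterator = loader.build_iter(epoch=0, shuffle=True)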

funasr/datasets/large_datasets/dataset.py

Lines changed: 12 additions & 4 deletions
@@ -127,14 +127,17 @@ def __iter__(self):
                 sample_dict["key"] = key
             else:
                 text = item
-                sample_dict[data_name] = text.strip().split()[1:]
+                segs = text.strip().split()
+                sample_dict[data_name] = segs[1:]
+                if "key" not in sample_dict:
+                    sample_dict["key"] = segs[0]
             yield sample_dict

         self.close_reader(reader_list)


 def len_fn_example(data):
-    return len(data)
+    return 1


 def len_fn_token(data):
@@ -148,6 +151,7 @@ def len_fn_token(data):
 def Dataset(data_list_file,
             dict,
             seg_dict,
+            punc_dict,
             conf,
             mode="train",
             batch_mode="padding"):
@@ -162,7 +166,7 @@ def Dataset(data_list_file,
     dataset = FilterIterDataPipe(dataset, fn=filter_fn)

     if "text" in data_names:
-        vocab = {'vocab': dict, 'seg_dict': seg_dict}
+        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict}
         tokenize_fn = partial(tokenize, **vocab)
         dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)

@@ -191,6 +195,10 @@ def Dataset(data_list_file,
                           sort_size=sort_size,
                           batch_mode=batch_mode)

-    dataset = MapperIterDataPipe(dataset, fn=padding if batch_mode == "padding" else clipping)
+    int_pad_value = conf.get("int_pad_value", -1)
+    float_pad_value = conf.get("float_pad_value", 0.0)
+    padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
+    padding_fn = partial(padding, **padding_conf)
+    dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)

     return dataset
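
As a side note (not from the commit), the padding step above simply binds the two conf-level pad values into the collate function with functools.partial, so the pipeline keeps handing it a plain batch list. A minimal sketch of that pattern, with illustrative values:

from functools import partial
from funasr.datasets.large_datasets.utils.padding import padding

conf = {"int_pad_value": -1, "float_pad_value": 0.0}   # illustrative config values
padding_fn = partial(padding,
                     int_pad_value=conf.get("int_pad_value", -1),
                     float_pad_value=conf.get("float_pad_value", 0.0))
# padding_fn(batch) is what MapperIterDataPipe applies when batch_mode == "padding"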

funasr/datasets/large_datasets/utils/padding.py

Lines changed: 2 additions & 3 deletions
@@ -6,9 +6,8 @@
 def padding(data, float_pad_value=0.0, int_pad_value=-1):
     assert isinstance(data, list)
     assert "key" in data[0]
-    assert "speech" in data[0]
-    assert "text" in data[0]
-
+    assert "speech" in data[0] or "text" in data[0]
+
     keys = [x["key"] for x in data]

     batch = {}
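
Illustrative only, not from the commit: after the relaxed assert, a text-only batch from the punctuation pipeline passes the input checks; the rest of the padding logic is not shown in this diff, so only the assert behaviour is claimed here.

import numpy as np

batch = [  # hypothetical text-only samples, no "speech" key
    {"key": "utt1", "text": np.array([5, 9, 2]), "punc": np.array([0, 0, 1])},
    {"key": "utt2", "text": np.array([7, 1]),    "punc": np.array([0, 2])},
]
# padding(batch, int_pad_value=-1) now gets past the "speech"/"text" assert for such batches;
# previously the missing "speech" key would have raised an AssertionError.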

funasr/datasets/large_datasets/utils/tokenize.py

Lines changed: 25 additions & 4 deletions
@@ -31,22 +31,43 @@ def seg_tokenize(txt, seg_dict):

 def tokenize(data,
              vocab=None,
-             seg_dict=None):
+             seg_dict=None,
+             punc_dict=None):
     assert "text" in data
     assert isinstance(vocab, dict)
     text = data["text"]
     token = []
+    vad = -2

     if seg_dict is not None:
         assert isinstance(seg_dict, dict)
         txt = forward_segment("".join(text).lower(), seg_dict)
         text = seg_tokenize(txt, seg_dict)
-
-    for x in text:
-        if x in vocab:
+
+    length = len(text)
+    for i in range(length):
+        x = text[i]
+        if i == length-1 and "punc" in data and text[i].startswith("vad:"):
+            vad = x[-1][4:]
+            if len(vad) == 0:
+                vad = -1
+            else:
+                vad = int(vad)
+        elif x in vocab:
             token.append(vocab[x])
         else:
             token.append(vocab['<unk>'])

+    if "punc" in data and punc_dict is not None:
+        punc_token = []
+        for punc in data["punc"]:
+            if punc in punc_dict:
+                punc_token.append(punc_dict[punc])
+            else:
+                punc_token.append(punc_dict["_"])
+        data["punc"] = np.array(punc_token)
+
     data["text"] = np.array(token)
+    if vad is not -2:
+        data["vad_indexes"] = np.array([vad], dtype=np.int64)
     return data
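
Illustrative only, none of this is in the commit: roughly what a sample looks like going into and out of tokenize() once punc_dict is supplied. The concrete symbols are invented; only the field names and output types follow the code above.

sample = {
    "key":  "utt_0001",
    "text": ["hello", "world", "vad:5"],  # optional trailing "vad:<index>" marker
    "punc": ["_", "。"],                   # punctuation labels aligned to the text upstream
}
# After tokenize(sample, vocab=..., punc_dict=...):
#   sample["text"]        -> np.array of token ids (<unk> for out-of-vocabulary words)
#   sample["punc"]        -> np.array of punctuation ids ("_" used for unknown labels)
#   sample["vad_indexes"] -> np.array([...], dtype=np.int64) when the vad marker is present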

funasr/datasets/preprocessor.py

Lines changed: 100 additions & 0 deletions
@@ -704,3 +704,103 @@ def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
         del data[self.split_text_name]
         return result

+class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
+    def __init__(
+        self,
+        train: bool,
+        token_type: List[str] = [None],
+        token_list: List[Union[Path, str, Iterable[str]]] = [None],
+        bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
+        text_cleaner: Collection[str] = None,
+        g2p_type: str = None,
+        unk_symbol: str = "<unk>",
+        space_symbol: str = "<space>",
+        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+        delimiter: str = None,
+        rir_scp: str = None,
+        rir_apply_prob: float = 1.0,
+        noise_scp: str = None,
+        noise_apply_prob: float = 1.0,
+        noise_db_range: str = "3_10",
+        speech_volume_normalize: float = None,
+        speech_name: str = "speech",
+        text_name: List[str] = ["text"],
+        vad_name: str = "vad_indexes",
+    ):
+        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
+        super().__init__(
+            train=train,
+            token_type=token_type[0],
+            token_list=token_list[0],
+            bpemodel=bpemodel[0],
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            speech_name=speech_name,
+            text_name=text_name[0],
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+        )
+
+        assert (
+            len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
+        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
+        self.num_tokenizer = len(token_type)
+        self.tokenizer = []
+        self.token_id_converter = []
+
+        for i in range(self.num_tokenizer):
+            if token_type[i] is not None:
+                if token_list[i] is None:
+                    raise ValueError("token_list is required if token_type is not None")
+
+                self.tokenizer.append(
+                    build_tokenizer(
+                        token_type=token_type[i],
+                        bpemodel=bpemodel[i],
+                        delimiter=delimiter,
+                        space_symbol=space_symbol,
+                        non_linguistic_symbols=non_linguistic_symbols,
+                        g2p_type=g2p_type,
+                    )
+                )
+                self.token_id_converter.append(
+                    TokenIDConverter(
+                        token_list=token_list[i],
+                        unk_symbol=unk_symbol,
+                    )
+                )
+            else:
+                self.tokenizer.append(None)
+                self.token_id_converter.append(None)
+
+        self.text_cleaner = TextCleaner(text_cleaner)
+        self.text_name = text_name  # override the text_name from CommonPreprocessor
+        self.vad_name = vad_name
+
+    def _text_process(
+        self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        for i in range(self.num_tokenizer):
+            text_name = self.text_name[i]
+            if text_name in data and self.tokenizer[i] is not None:
+                text = data[text_name]
+                text = self.text_cleaner(text)
+                tokens = self.tokenizer[i].text2tokens(text)
+                if "vad:" in tokens[-1]:
+                    vad = tokens[-1][4:]
+                    tokens = tokens[:-1]
+                    if len(vad) == 0:
+                        vad = -1
+                    else:
+                        vad = int(vad)
+                    data[self.vad_name] = np.array([vad], dtype=np.int64)
+                text_ints = self.token_id_converter[i].tokens2ids(tokens)
+                data[text_name] = np.array(text_ints, dtype=np.int64)
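
A construction sketch, not from the commit; the token types and file paths are assumptions. The parallel-list interface lets "text" and "punc" each get their own token list, which is what the length assert above enforces, and the preprocessor emits vad_indexes whenever the last token carries a "vad:" marker.

from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor

preprocessor = PuncTrainTokenizerCommonPreprocessor(
    train=True,
    token_type=["word", "word"],                         # assumed tokenizer types
    token_list=["exp/tokens.txt", "exp/punc_list.txt"],  # hypothetical vocab files
    bpemodel=[None, None],
    text_name=["text", "punc"],
    vad_name="vad_indexes",
)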

funasr/modules/attention.py

Lines changed: 12 additions & 0 deletions
@@ -439,6 +439,18 @@ def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
         att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
         return att_outs + fsmn_memory

+class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
+        return att_outs + fsmn_memory
+
 class MultiHeadedAttentionSANMDecoder(nn.Module):
     """Multi-Head Attention layer.
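
In the subclass above, mask is expected to be a pair: mask[0] feeds the FSMN memory branch and mask[1] the attention scores. A sketch of assembling such a pair; the tensor shapes are assumptions, not taken from the commit.

import torch
from funasr.modules.mask import vad_mask

seq_len, vad_pos = 6, 4                              # illustrative sizes
fsmn_mask = torch.ones(1, seq_len, 1)                # assumed non-padding mask shape
att_mask = vad_mask(seq_len, vad_pos).unsqueeze(0)   # (1, L, L), from the new vad_mask below
mask = (fsmn_mask, att_mask)                         # what forward(x, mask, ...) receives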

funasr/modules/mask.py

Lines changed: 17 additions & 0 deletions
@@ -33,3 +33,20 @@ def target_mask(ys_in_pad, ignore_id):
     ys_mask = ys_in_pad != ignore_id
     m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
     return ys_mask.unsqueeze(-2) & m
+
+def vad_mask(size, vad_pos, device="cpu", dtype=torch.bool):
+    """Create mask for decoder self-attention.
+
+    :param int size: size of mask
+    :param int vad_pos: index of vad index
+    :param str device: "cpu" or "cuda" or torch.Tensor.device
+    :param torch.dtype dtype: result dtype
+    :rtype: torch.Tensor (B, Lmax, Lmax)
+    """
+    ret = torch.ones(size, size, device=device, dtype=dtype)
+    if vad_pos <= 0 or vad_pos >= size:
+        return ret
+    sub_corner = torch.zeros(
+        vad_pos - 1, size - vad_pos, device=device, dtype=dtype)
+    ret[0:vad_pos - 1, vad_pos:] = sub_corner
+    return ret
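
A quick illustrative check, not part of the commit, of what the new mask looks like: the first vad_pos - 1 query positions are blocked from attending to position vad_pos and everything after it, while all other positions attend freely.

import torch
from funasr.modules.mask import vad_mask

print(vad_mask(5, 3).int())  # values shown as ints for readability
# [[1, 1, 1, 0, 0],
#  [1, 1, 1, 0, 0],
#  [1, 1, 1, 1, 1],
#  [1, 1, 1, 1, 1],
#  [1, 1, 1, 1, 1]]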

funasr/punctuation/abs_model.py

Lines changed: 4 additions & 0 deletions
@@ -25,3 +25,7 @@ class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
     @abstractmethod
     def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError
+
+    @abstractmethod
+    def with_vad(self) -> bool:
+        raise NotImplementedError
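
A minimal sketch, not in the commit, of what the new abstract method means for concrete models: every AbsPunctuation subclass now has to report whether it consumes the vad_indexes produced by the data pipeline.

import torch
from typing import Tuple
from funasr.punctuation.abs_model import AbsPunctuation

class DummyPuncModel(AbsPunctuation):  # hypothetical subclass for illustration only
    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return input, hidden           # placeholder; a real model predicts punctuation here

    def with_vad(self) -> bool:
        return True                    # this model expects vad_indexes in its batch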
