@@ -704,3 +704,103 @@ def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
         del data[self.split_text_name]
         return result

+class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
+    def __init__(
+        self,
+        train: bool,
+        token_type: List[str] = [None],
+        token_list: List[Union[Path, str, Iterable[str]]] = [None],
+        bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
+        text_cleaner: Collection[str] = None,
+        g2p_type: str = None,
+        unk_symbol: str = "<unk>",
+        space_symbol: str = "<space>",
+        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+        delimiter: str = None,
+        rir_scp: str = None,
+        rir_apply_prob: float = 1.0,
+        noise_scp: str = None,
+        noise_apply_prob: float = 1.0,
+        noise_db_range: str = "3_10",
+        speech_volume_normalize: float = None,
+        speech_name: str = "speech",
+        text_name: List[str] = ["text"],
+        vad_name: str = "vad_indexes",
+    ):
+        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
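+        # The parent class handles only a single text stream, so initialize it
+        # with the first entry of each per-stream list; the remaining entries
+        # are consumed below.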
+        super().__init__(
+            train=train,
+            token_type=token_type[0],
+            token_list=token_list[0],
+            bpemodel=bpemodel[0],
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            speech_name=speech_name,
+            text_name=text_name[0],
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+        )
+
+        assert (
+            len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
+        ), "token_type, token_list, bpemodel, and text_name must have equal lengths"
+        self.num_tokenizer = len(token_type)
+        self.tokenizer = []
+        self.token_id_converter = []
+
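+        # Build one tokenizer / token-id converter pair per text stream;
+        # streams whose token_type is None get a None placeholder and are
+        # skipped during text processing.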
+        for i in range(self.num_tokenizer):
+            if token_type[i] is not None:
+                if token_list[i] is None:
+                    raise ValueError("token_list is required if token_type is not None")
+
+                self.tokenizer.append(
+                    build_tokenizer(
+                        token_type=token_type[i],
+                        bpemodel=bpemodel[i],
+                        delimiter=delimiter,
+                        space_symbol=space_symbol,
+                        non_linguistic_symbols=non_linguistic_symbols,
+                        g2p_type=g2p_type,
+                    )
+                )
+                self.token_id_converter.append(
+                    TokenIDConverter(
+                        token_list=token_list[i],
+                        unk_symbol=unk_symbol,
+                    )
+                )
+            else:
+                self.tokenizer.append(None)
+                self.token_id_converter.append(None)
+
+        self.text_cleaner = TextCleaner(text_cleaner)
+        self.text_name = text_name  # override the text_name from CommonPreprocessor
+        self.vad_name = vad_name
+
+    def _text_process(
+        self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
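+        # Tokenize each configured text stream in place, converting raw
+        # strings to int64 token-id arrays.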
+        for i in range(self.num_tokenizer):
+            text_name = self.text_name[i]
+            if text_name in data and self.tokenizer[i] is not None:
+                text = data[text_name]
+                text = self.text_cleaner(text)
+                tokens = self.tokenizer[i].text2tokens(text)
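+                # A trailing "vad:<index>" token carries a VAD segment index:
+                # strip it from the token sequence and store it separately
+                # under vad_name (-1 when no index is given).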
+                if len(tokens) > 0 and "vad:" in tokens[-1]:
+                    vad = tokens[-1][4:]
+                    tokens = tokens[:-1]
+                    if len(vad) == 0:
+                        vad = -1
+                    else:
+                        vad = int(vad)
+                    data[self.vad_name] = np.array([vad], dtype=np.int64)
+                text_ints = self.token_id_converter[i].tokens2ids(tokens)
+                data[text_name] = np.array(text_ints, dtype=np.int64)
+        return data