Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

disable speed perturbation by default #1176

Merged
merged 5 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
torch.set_num_interop_threads(1)


def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, speed_perturb: bool = False):
src_dir = Path("data/manifests/aidatatang_200zh")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
Expand Down Expand Up @@ -86,9 +86,12 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
if speed_perturb:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
Expand All @@ -109,7 +112,12 @@ def get_args():
default=80,
help="""The number of mel bins for Fbank""",
)

parser.add_argument(
"--speed-perturb",
type=bool,
JinZr marked this conversation as resolved.
Show resolved Hide resolved
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()


Expand All @@ -119,4 +127,6 @@ def get_args():
logging.basicConfig(format=formatter, level=logging.INFO)

args = get_args()
compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
compute_fbank_aidatatang_200zh(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
)
22 changes: 16 additions & 6 deletions egs/aishell/ASR/local/compute_fbank_aishell.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
torch.set_num_interop_threads(1)


def compute_fbank_aishell(num_mel_bins: int = 80):
def compute_fbank_aishell(num_mel_bins: int = 80, speed_perturb: bool = False):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
Expand Down Expand Up @@ -82,9 +82,12 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
if speed_perturb:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
Expand All @@ -104,7 +107,12 @@ def get_args():
default=80,
help="""The number of mel bins for Fbank""",
)

parser.add_argument(
"--speed-perturb",
type=bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()


Expand All @@ -114,4 +122,6 @@ def get_args():
logging.basicConfig(format=formatter, level=logging.INFO)

args = get_args()
compute_fbank_aishell(num_mel_bins=args.num_mel_bins)
compute_fbank_aishell(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
)
21 changes: 16 additions & 5 deletions egs/aishell2/ASR/local/compute_fbank_aishell2.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
torch.set_num_interop_threads(1)


def compute_fbank_aishell2(num_mel_bins: int = 80):
def compute_fbank_aishell2(num_mel_bins: int = 80, speed_perturb: bool = False):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
Expand Down Expand Up @@ -82,9 +82,12 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
if speed_perturb:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
Expand All @@ -104,6 +107,12 @@ def get_args():
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--speed-perturb",
type=bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)

return parser.parse_args()

Expand All @@ -114,4 +123,6 @@ def get_args():
logging.basicConfig(format=formatter, level=logging.INFO)

args = get_args()
compute_fbank_aishell2(num_mel_bins=args.num_mel_bins)
compute_fbank_aishell2(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
)
22 changes: 17 additions & 5 deletions egs/aishell4/ASR/local/compute_fbank_aishell4.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
torch.set_num_interop_threads(1)


def compute_fbank_aishell4(num_mel_bins: int = 80):
def compute_fbank_aishell4(num_mel_bins: int = 80, speed_perturb: bool = False):
src_dir = Path("data/manifests/aishell4")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
Expand Down Expand Up @@ -84,9 +84,13 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
if speed_perturb:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)

cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
Expand All @@ -113,6 +117,12 @@ def get_args():
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--speed-perturb",
type=bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)

return parser.parse_args()

Expand All @@ -123,4 +133,6 @@ def get_args():
logging.basicConfig(format=formatter, level=logging.INFO)

args = get_args()
compute_fbank_aishell4(num_mel_bins=args.num_mel_bins)
compute_fbank_aishell4(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
)
21 changes: 16 additions & 5 deletions egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
torch.set_num_interop_threads(1)


def compute_fbank_alimeeting(num_mel_bins: int = 80):
def compute_fbank_alimeeting(num_mel_bins: int = 80, speed_perturb: bool = False):
src_dir = Path("data/manifests/alimeeting")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
Expand Down Expand Up @@ -83,9 +83,12 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
if speed_perturb:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cur_num_jobs = num_jobs if ex is None else 80
cur_num_jobs = min(cur_num_jobs, len(cut_set))

Expand Down Expand Up @@ -114,6 +117,12 @@ def get_args():
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--speed-perturb",
type=bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)

return parser.parse_args()

Expand All @@ -124,4 +133,6 @@ def get_args():
logging.basicConfig(format=formatter, level=logging.INFO)

args = get_args()
compute_fbank_alimeeting(num_mel_bins=args.num_mel_bins)
compute_fbank_alimeeting(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb
)
30 changes: 23 additions & 7 deletions egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import re
from pathlib import Path
Expand Down Expand Up @@ -45,7 +46,7 @@ def has_no_oov(
return oov_pattern.search(sup.text) is None


def preprocess_wenet_speech():
def preprocess_wenet_speech(speed_perturb: bool = False):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
output_dir.mkdir(exist_ok=True)
Expand Down Expand Up @@ -111,19 +112,34 @@ def preprocess_wenet_speech():
# Run data augmentation that needs to be done in the
# time domain.
if partition not in ["DEV", "TEST_NET", "TEST_MEETING"]:
logging.info(
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
)
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
if speed_perturb:
logging.info(
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
)
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)


def get_args():
    """Parse the command-line options of this script.

    Options:
      --speed-perturb  Boolean given as text (e.g. ``True``/``False``,
                       ``1``/``0``, ``yes``/``no``). Defaults to False.

    Returns:
      The namespace produced by ``parser.parse_args()``; its
      ``speed_perturb`` attribute is a real ``bool``.
    """

    def _str2bool(value):
        # argparse passes the raw command-line string to ``type``. The
        # builtin ``bool`` maps every non-empty string -- including "False"
        # and "0" -- to True, so the text has to be interpreted explicitly.
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ("yes", "true", "t", "y", "1", "on"):
            return True
        # The empty string stays False, exactly as it was under ``type=bool``.
        if text in ("no", "false", "f", "n", "0", "off", ""):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--speed-perturb",
        type=_str2bool,
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )
    return parser.parse_args()


def main():
setup_logger(log_filename="./log-preprocess-wenetspeech")

preprocess_wenet_speech()
args = get_args()
preprocess_wenet_speech(speed_perturb=args.speed_perturb)
logging.info("Done")


Expand Down