Skip to content

Commit 939b9fe

Browse files
authored
Merge pull request #166 from magicharry/main
unit test pr new
2 parents a016617 + d694654 commit 939b9fe

File tree

3 files changed

+134
-0
lines changed

3 files changed

+134
-0
lines changed

.github/workflows/UnitTest.yml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
name: FunASR Unit Test
2+
run-name: ${{ github.actor }} is testing out FunASR Unit Test 🚀
3+
on:
4+
pull_request:
5+
branches:
6+
- main
7+
push:
8+
branches:
9+
- dev_wjm
10+
- dev_jy
11+
12+
jobs:
13+
build:
14+
15+
runs-on: ubuntu-latest
16+
strategy:
17+
matrix:
18+
python-version: ["3.7"]
19+
20+
steps:
21+
- uses: actions/checkout@v3
22+
- name: Set up Python ${{ matrix.python-version }}
23+
uses: actions/setup-python@v4
24+
with:
25+
python-version: ${{ matrix.python-version }}
26+
- name: Install dependencies
27+
run: |
28+
python -m pip install --upgrade pip
29+
pip install torch torchvision torchaudio
30+
pip install "modelscope[audio_asr]" --upgrade -f \
31+
https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
32+
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
33+
pip install -e ./
34+
- name: Testing
35+
run:
36+
python tests/run_test.py

tests/run_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/env python
2+
3+
import argparse
4+
import os
5+
import sys
6+
import unittest
7+
from fnmatch import fnmatch
8+
9+
10+
def gather_test_cases(test_dir, pattern, list_tests):
11+
case_list = []
12+
for dirpath, dirnames, filenames in os.walk(test_dir):
13+
for file in filenames:
14+
if fnmatch(file, pattern):
15+
case_list.append(file)
16+
17+
test_suite = unittest.TestSuite()
18+
19+
for case in case_list:
20+
test_case = unittest.defaultTestLoader.discover(start_dir=test_dir, pattern=case)
21+
test_suite.addTest(test_case)
22+
if hasattr(test_case, '__iter__'):
23+
for subcase in test_case:
24+
if list_tests:
25+
print(subcase)
26+
else:
27+
if list_tests:
28+
print(test_case)
29+
return test_suite
30+
31+
32+
def main(args):
33+
runner = unittest.TextTestRunner()
34+
test_suite = gather_test_cases(os.path.abspath(args.test_dir), args.pattern, args.list_tests)
35+
if not args.list_tests:
36+
result = runner.run(test_suite)
37+
if len(result.failures) > 0:
38+
sys.exit(len(result.failures))
39+
if len(result.errors) > 0:
40+
sys.exit(len(result.errors))
41+
42+
43+
if __name__ == '__main__':
44+
parser = argparse.ArgumentParser('test runner')
45+
parser.add_argument('--list_tests', action='store_true', help='list all tests')
46+
parser.add_argument('--pattern', default='test_*.py', help='test file pattern')
47+
parser.add_argument('--test_dir', default='tests', help='directory to be tested')
48+
parser.add_argument('--disable_profile', action='store_true', help='disable profiling')
49+
args = parser.parse_args()
50+
print(f'working dir: {os.getcwd()}')
51+
main(args)

tests/test_inference_pipeline.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import unittest
2+
3+
from modelscope.pipelines import pipeline
4+
from modelscope.utils.constant import Tasks
5+
from modelscope.utils.logger import get_logger
6+
7+
logger = get_logger()
8+
9+
class TestInferencePipelines(unittest.TestCase):
10+
def test_funasr_path(self):
11+
import funasr
12+
import os
13+
logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
14+
15+
def test_asr_inference_pipeline(self):
16+
inference_pipeline = pipeline(
17+
task=Tasks.auto_speech_recognition,
18+
model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
19+
rec_result = inference_pipeline(
20+
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
21+
logger.info("asr inference result: {0}".format(rec_result))
22+
23+
def test_asr_inference_pipeline_with_vad_punc(self):
24+
inference_pipeline = pipeline(
25+
task=Tasks.auto_speech_recognition,
26+
model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
27+
vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
28+
vad_model_revision="v1.1.8",
29+
punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
30+
punc_model_revision="v1.1.6")
31+
rec_result = inference_pipeline(
32+
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav')
33+
logger.info("asr inference with vad punc result: {0}".format(rec_result))
34+
35+
def test_vad_inference_pipeline(self):
36+
inference_pipeline = pipeline(
37+
task=Tasks.voice_activity_detection,
38+
model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
39+
model_revision='v1.1.8',
40+
)
41+
segments_result = inference_pipeline(
42+
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav')
43+
logger.info("vad inference result: {0}".format(segments_result))
44+
45+
46+
if __name__ == '__main__':
47+
unittest.main()

0 commit comments

Comments
 (0)