Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions .github/workflows/UnitTest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: FunASR Unit Test
run-name: ${{ github.actor }} is testing out FunASR Unit Test 🚀
on:
pull_request:
branches:
- main
push:
branches:
- dev_wjm
- dev_jy

jobs:
build:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio
pip install "modelscope[audio_asr]" --upgrade -f \
https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install -e ./
- name: Testing
run:
python tests/run_test.py
51 changes: 51 additions & 0 deletions tests/run_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python

import argparse
import os
import sys
import unittest
from fnmatch import fnmatch


def gather_test_cases(test_dir, pattern, list_tests):
case_list = []
for dirpath, dirnames, filenames in os.walk(test_dir):
for file in filenames:
if fnmatch(file, pattern):
case_list.append(file)

test_suite = unittest.TestSuite()

for case in case_list:
test_case = unittest.defaultTestLoader.discover(start_dir=test_dir, pattern=case)
test_suite.addTest(test_case)
if hasattr(test_case, '__iter__'):
for subcase in test_case:
if list_tests:
print(subcase)
else:
if list_tests:
print(test_case)
return test_suite


def main(args):
runner = unittest.TextTestRunner()
test_suite = gather_test_cases(os.path.abspath(args.test_dir), args.pattern, args.list_tests)
if not args.list_tests:
result = runner.run(test_suite)
if len(result.failures) > 0:
sys.exit(len(result.failures))
if len(result.errors) > 0:
sys.exit(len(result.errors))


if __name__ == '__main__':
parser = argparse.ArgumentParser('test runner')
parser.add_argument('--list_tests', action='store_true', help='list all tests')
parser.add_argument('--pattern', default='test_*.py', help='test file pattern')
parser.add_argument('--test_dir', default='tests', help='directory to be tested')
parser.add_argument('--disable_profile', action='store_true', help='disable profiling')
args = parser.parse_args()
print(f'working dir: {os.getcwd()}')
main(args)
47 changes: 47 additions & 0 deletions tests/test_inference_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

class TestInferencePipelines(unittest.TestCase):
def test_funasr_path(self):
import funasr
import os
logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))

def test_asr_inference_pipeline(self):
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))

def test_asr_inference_pipeline_with_vad_punc(self):
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
vad_model_revision="v1.1.8",
punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
punc_model_revision="v1.1.6")
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav')
logger.info("asr inference with vad punc result: {0}".format(rec_result))

def test_vad_inference_pipeline(self):
inference_pipeline = pipeline(
task=Tasks.voice_activity_detection,
model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
model_revision='v1.1.8',
)
segments_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav')
logger.info("vad inference result: {0}".format(segments_result))


if __name__ == '__main__':
unittest.main()