Skip to content

Commit

Permalink
Generate test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
naglis committed Nov 13, 2024
1 parent 821c4f6 commit 98a6783
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 41 deletions.
74 changes: 41 additions & 33 deletions aeneas/tests/base_ttswrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,26 @@
import tempfile
import contextlib
import typing
import itertools

from aeneas.exacttiming import TimeValue
from aeneas.textfile import TextFile, TextFragment
from aeneas.ttswrappers.basettswrapper import BaseTTSWrapper
from aeneas.runtimeconfiguration import RuntimeConfiguration


class SynthesizeCase(typing.NamedTuple):
c_ext: bool
cew_subprocess: bool
cache: bool


class TestBaseTTSWrapper(unittest.TestCase):
def test_not_implemented(self):
with self.assertRaises(NotImplementedError):
BaseTTSWrapper()


class BaseTTSWrapperCase(unittest.TestCase):
TTS = ""
TTS_PATH = ""

Expand All @@ -43,20 +55,20 @@ class TestBaseTTSWrapper(unittest.TestCase):
def synthesize(
self,
text_file,
ofp=None,
quit_after=None,
backwards=False,
zero_length=False,
ofp: str | None = None,
quit_after: TimeValue | None = None,
backwards: bool = False,
zero_length: bool = False,
expected_exc=None,
):
if (
(self.TTS == "")
or (self.TTS_PATH == "")
or (not os.path.exists(self.TTS_PATH))
):
return

def inner(c_ext, cew_subprocess, cache):
if not self.TTS:
self.skipTest("`self.TTS` is not set")
elif not self.TTS_PATH:
self.skipTest("`self.TTS_PATH` is not set")
elif not os.path.isfile(self.TTS_PATH):
self.skipTest(f"`self.TTS_PATH` ({self.TTS_PATH}) does not exist")

def inner(case: SynthesizeCase):
with contextlib.ExitStack() as exit_stack:
if ofp is None:
tmp_file = tempfile.NamedTemporaryFile(suffix=".wav")
Expand All @@ -69,36 +81,36 @@ def inner(c_ext, cew_subprocess, cache):
rconf = RuntimeConfiguration()
rconf[RuntimeConfiguration.TTS] = self.TTS
rconf[RuntimeConfiguration.TTS_PATH] = self.TTS_PATH
rconf[RuntimeConfiguration.C_EXTENSIONS] = c_ext
rconf[RuntimeConfiguration.CEW_SUBPROCESS_ENABLED] = cew_subprocess
rconf[RuntimeConfiguration.TTS_CACHE] = cache
rconf[RuntimeConfiguration.C_EXTENSIONS] = case.c_ext
rconf[RuntimeConfiguration.CEW_SUBPROCESS_ENABLED] = (
case.cew_subprocess
)
rconf[RuntimeConfiguration.TTS_CACHE] = case.cache
tts_engine = self.TTS_CLASS(rconf=rconf)
anchors, total_time, num_chars = tts_engine.synthesize_multiple(
text_file, output_file_path, quit_after, backwards
)
if cache:
if case.cache:
tts_engine.clear_cache()

if zero_length:
self.assertEqual(total_time, 0.0)
else:
self.assertGreater(total_time, 0.0)

except (OSError, TypeError, UnicodeDecodeError, ValueError) as exc:
if cache and tts_engine is not None:
if case.cache and tts_engine is not None:
tts_engine.clear_cache()
with self.assertRaises(expected_exc):
raise exc

if self.TTS == "espeak":
for c_ext, cew_subprocess, cache in itertools.product(
[True, False], repeat=3
):
inner(c_ext=c_ext, cew_subprocess=cew_subprocess, cache=cache)
elif self.TTS == "festival":
for c_ext, cache in itertools.product([True, False], repeat=2):
inner(c_ext=c_ext, cew_subprocess=False, cache=cache)
else:
for cache in [True, False]:
inner(c_ext=True, cew_subprocess=False, cache=cache)
for case in self.iter_synthesize_cases():
with self.subTest(case=case):
inner(case)

def iter_synthesize_cases(self) -> typing.Iterator[SynthesizeCase]:
yield SynthesizeCase(c_ext=True, cew_subprocess=False, cache=True)
yield SynthesizeCase(c_ext=True, cew_subprocess=False, cache=False)

def tfl(self, frags):
tfl = TextFile()
Expand All @@ -108,10 +120,6 @@ def tfl(self, frags):
)
return tfl

def test_not_implemented(self):
with self.assertRaises(NotImplementedError):
BaseTTSWrapper()

def test_use_cache(self):
if self.TTS == "":
self.skipTest("`self.TTS` is not set")
Expand Down
4 changes: 2 additions & 2 deletions aeneas/tests/test_espeakngttswrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.


from aeneas.tests.base_ttswrapper import TestBaseTTSWrapper
from aeneas.ttswrappers.espeakngttswrapper import ESPEAKNGTTSWrapper
from aeneas.tests.base_ttswrapper import BaseTTSWrapperCase


class TestESPEAKNGTTSWrapper(TestBaseTTSWrapper):
class TestESPEAKNGTTSWrapper(BaseTTSWrapperCase):
TTS = "espeak-ng"
TTS_PATH = "/usr/bin/espeak-ng"
TTS_CLASS = ESPEAKNGTTSWrapper
Expand Down
9 changes: 7 additions & 2 deletions aeneas/tests/test_espeakttswrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,23 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import itertools

from aeneas.tests.base_ttswrapper import TestBaseTTSWrapper
from aeneas.tests.base_ttswrapper import BaseTTSWrapperCase, SynthesizeCase
from aeneas.ttswrappers.espeakttswrapper import ESPEAKTTSWrapper


class TestESPEAKTTSWrapper(TestBaseTTSWrapper):
class TestESPEAKTTSWrapper(BaseTTSWrapperCase):
TTS = "espeak"
TTS_PATH = "/usr/bin/espeak"
TTS_CLASS = ESPEAKTTSWrapper
TTS_LANGUAGE = ESPEAKTTSWrapper.ENG
TTS_LANGUAGE_VARIATION = ESPEAKTTSWrapper.ENG_GBR

def iter_synthesize_cases(self):
for v in itertools.product([True, False], repeat=3):
yield SynthesizeCase(*v)

def test_multiple_replace_language(self):
tfl = self.tfl(
[(ESPEAKTTSWrapper.UKR, ["Временами Сашке хотелось перестать делать то"])]
Expand Down
9 changes: 7 additions & 2 deletions aeneas/tests/test_festivalttswrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,19 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import itertools

from aeneas.tests.base_ttswrapper import TestBaseTTSWrapper
from aeneas.tests.base_ttswrapper import BaseTTSWrapperCase, SynthesizeCase
from aeneas.ttswrappers.festivalttswrapper import FESTIVALTTSWrapper


class TestFESTIVALTTSWrapper(TestBaseTTSWrapper):
class TestFESTIVALTTSWrapper(BaseTTSWrapperCase):
TTS = "festival"
TTS_PATH = "/usr/bin/text2wave"
TTS_CLASS = FESTIVALTTSWrapper
TTS_LANGUAGE = FESTIVALTTSWrapper.ENG
TTS_LANGUAGE_VARIATION = FESTIVALTTSWrapper.ENG_GBR

def iter_synthesize_cases(self):
for c_ext, cache in itertools.product([True, False], repeat=2):
yield SynthesizeCase(c_ext=c_ext, cew_subprocess=False, cache=cache)
4 changes: 2 additions & 2 deletions aeneas/tests/test_macosttswrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.


from aeneas.tests.base_ttswrapper import TestBaseTTSWrapper
from aeneas.tests.base_ttswrapper import BaseTTSWrapperCase
from aeneas.ttswrappers.macosttswrapper import MacOSTTSWrapper


class TestESPEAKNGTTSWrapper(TestBaseTTSWrapper):
class TestMacOSTTSWrapper(BaseTTSWrapperCase):
TTS = "macos"
TTS_PATH = "/usr/bin/say"
TTS_CLASS = MacOSTTSWrapper
Expand Down

0 comments on commit 98a6783

Please sign in to comment.