Skip to content

Commit

Permalink
refactor: let us parse and load XML in just one place
Browse files Browse the repository at this point in the history
  • Loading branch information
joanise committed Jun 19, 2024
1 parent 69458c9 commit 94d49ae
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 58 deletions.
5 changes: 2 additions & 3 deletions readalongs/text/make_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from readalongs._version import __version__
from readalongs.log import LOGGER
from readalongs.text.util import parse_xml

JS_BUNDLE_URL = "https://unpkg.com/@readalongs/web-component@^1.4.0/dist/bundle.js"
FONTS_BUNDLE_URL = (
Expand Down Expand Up @@ -65,9 +66,7 @@ def encode_from_path(path: str) -> str:
with open(path, "rb") as f:
path_bytes = f.read()
if str(path).endswith("xml") or str(path).endswith(".readalong"):
root = etree.fromstring(
path_bytes, parser=etree.XMLParser(resolve_entities=False)
)
root = parse_xml(path_bytes)
for img in root.xpath("//graphic"):
url = img.get("url")
if url.startswith("http"):
Expand Down
6 changes: 3 additions & 3 deletions readalongs/text/make_smil.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import chevron
from lxml import etree

from readalongs.text.util import parse_xml

SMIL_TEMPLATE = """\
<smil xmlns="http://www.w3.org/ns/SMIL" version="3.0">
<body>
Expand Down Expand Up @@ -74,9 +76,7 @@ def parse_smil(formatted_smil: str) -> List[dict]:
please_msg = "Please make sure your SMIL file is valid."

try:
xml = etree.fromstring(
formatted_smil, parser=etree.XMLParser(resolve_entities=False)
)
xml = parse_xml(formatted_smil)
except etree.ParseError as e:
raise ValueError(f"Invalid SMIL file: {e}. {please_msg}")
ns = {"smil": "http://www.w3.org/ns/SMIL"}
Expand Down
21 changes: 20 additions & 1 deletion readalongs/text/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,29 @@ def load_xml(input_path: Union[str, IO]) -> etree.ElementTree:
"""
# resolve_entities=False is a safety measure: it prevents XML bombs.
return etree.parse(
input_path, parser=etree.XMLParser(resolve_entities=False)
input_path,
parser=etree.XMLParser(resolve_entities=False),
).getroot()


def parse_xml(xml_text: Union[str, bytes]) -> etree.ElementTree:
    """Safely parse XML supplied as in-memory text.

    Args:
        xml_text: the XML contents to parse; bytes are handed to the parser
            unchanged, anything else is encoded to UTF-8 bytes first

    Returns:
        the root of the parsed XML etree

    Raises:
        etree.ParseError: if there is a problem parsing the XML contents
    """
    if isinstance(xml_text, bytes):
        raw_xml = xml_text
    else:
        raw_xml = bytes(xml_text, encoding="utf8")
    # Entity resolution is switched off as a safety measure against XML bombs.
    safe_parser = etree.XMLParser(resolve_entities=False)
    return etree.fromstring(raw_xml, parser=safe_parser)


def load_xml_zip(zip_path, input_path) -> etree.ElementTree:
with zipfile.ZipFile(zip_path, "r") as fin_zip:
with fin_zip.open(input_path, "r") as fin:
Expand Down
12 changes: 2 additions & 10 deletions readalongs/waveform2svg/audio_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
import os
from io import open

import librosa
import librosa # type: ignore
import numpy as np
from lxml import etree

from readalongs.text.util import xpath_default
from readalongs.text.util import load_xml, xpath_default

SAMPLE_RATE = 16000

Expand Down Expand Up @@ -75,13 +74,6 @@ def ensure_dirs(path):
os.makedirs(dirname)


def load_xml(input_path):
    """Parse the XML file at input_path and return what etree.fromstring yields.

    The file is read as raw bytes instead of decoded text: lxml refuses str
    input that carries an XML encoding declaration (it raises ValueError),
    whereas with bytes input the parser handles the declaration itself.

    Args:
        input_path: path of the XML file to read
    """
    with open(input_path, "rb") as fin:
        xml_bytes = fin.read()
    # resolve_entities=False is a safety measure against XML bombs.
    return etree.fromstring(
        xml_bytes, parser=etree.XMLParser(resolve_entities=False)
    )


def save_txt(output_path, txt):
ensure_dirs(output_path)
with open(output_path, "w", encoding="utf-8") as fout:
Expand Down
12 changes: 2 additions & 10 deletions readalongs/waveform2svg/units2svg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
from collections import OrderedDict

import chevron
import librosa
import librosa # type: ignore
from audio_util import save_txt
from lxml import etree

from readalongs.text.util import xpath_default
from readalongs.text.util import load_xml, xpath_default

FMIN = 80
FMAX = 1000
Expand Down Expand Up @@ -57,13 +56,6 @@ def render_svg(data, width=512, height=100, radius=4):
return chevron.render(SVG_TEMPLATE, result)


def load_xml(input_path):
    """Parse the XML file at input_path and return what etree.fromstring yields.

    The file is read as raw bytes instead of decoded text: lxml refuses str
    input that carries an XML encoding declaration (it raises ValueError),
    whereas with bytes input the parser handles the declaration itself.

    Args:
        input_path: path of the XML file to read
    """
    with open(input_path, "rb") as fin:
        xml_bytes = fin.read()
    # resolve_entities=False is a safety measure against XML bombs.
    return etree.fromstring(
        xml_bytes, parser=etree.XMLParser(resolve_entities=False)
    )


def parse_smil(input_path):
"""Figure out the overall start and end of every unit, even if the whole
sequence plays out over multiple audio files"""
Expand Down
3 changes: 1 addition & 2 deletions test/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def test_arbitrary_xml(self):
)

# if xpath isn't valid, log warning
log_output = io.StringIO()
with redirect_stderr(log_output):
with redirect_stderr(io.StringIO()) as log_output:
new_xml = add_supplementary_xml(
self.readalong,
{
Expand Down
11 changes: 6 additions & 5 deletions test/test_dna_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from readalongs.text import tokenize_xml
from readalongs.text.add_ids_to_xml import add_ids
from readalongs.text.util import parse_xml


class TestDNAText(BasicTestCase):
Expand All @@ -21,7 +22,7 @@ def test_tok_all_words(self):
<s>Bonjour! Comment ça va?</s>
<s>Voici une deuxième phrase.</s>
</document>"""
xml = etree.fromstring(txt)
xml = parse_xml(txt)
tokenized = tokenize_xml.tokenize_xml(xml)
as_txt = etree.tounicode(tokenized)
# print(etree.tounicode(tokenized))
Expand Down Expand Up @@ -52,7 +53,7 @@ def test_tok_some_words(self):
<s do-not-align="TRUE">Voici une deuxième phrase.</s>
<s>Un <foo do-not-align="1">mot ou deux</foo> à exclure.</s>
</document>"""
xml = etree.fromstring(txt)
xml = parse_xml(txt)
tokenized = tokenize_xml.tokenize_xml(xml)
as_txt = etree.tounicode(tokenized)
# print('as_txt="' + as_txt +'"')
Expand Down Expand Up @@ -94,7 +95,7 @@ def test_tok_div_p_s(self):
<p> <s>Trois phrases.</s> </p>
</div>
</document>"""
xml = etree.fromstring(txt)
xml = parse_xml(txt)
tokenized = tokenize_xml.tokenize_xml(xml)
as_txt = etree.tounicode(tokenized)
# print('as_txt="' + as_txt +'"')
Expand Down Expand Up @@ -141,15 +142,15 @@ def test_dna_word(self):
"""You can't have a DNA <w> element, that's reserved for tokens to align"""

txt = """<s xml:lang="fra">Une <w do-not-align="true">exclude</w> phrase.</s>"""
xml = etree.fromstring(txt)
xml = parse_xml(txt)
tokenized = tokenize_xml.tokenize_xml(xml)
self.assertRaises(RuntimeError, add_ids, tokenized)

def test_dna_word_nested(self):
"""You also can't have a <w> element inside a DNA element"""

txt = """<s xml:lang="fra">Une <foo do-not-align="true"><bar><w>exclude</w></bar></foo> phrase.</s>"""
xml = etree.fromstring(txt)
xml = parse_xml(txt)
tokenized = tokenize_xml.tokenize_xml(xml)
self.assertRaises(RuntimeError, add_ids, tokenized)

Expand Down
11 changes: 6 additions & 5 deletions test/test_g2p_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from readalongs.cli import align, g2p, make_xml, tokenize
from readalongs.log import LOGGER
from readalongs.text.convert_xml import convert_xml
from readalongs.text.util import parse_xml


def run_convert_xml(input_string):
"""wrap convert_xml to make unit testing easier"""
return etree.tounicode(convert_xml(etree.fromstring(input_string))[0])
return etree.tounicode(convert_xml(parse_xml(input_string))[0])


def two_xml_elements(xml_text):
Expand Down Expand Up @@ -362,7 +363,7 @@ def test_convert_xml_with_newlines(self):
"""Newlines inside words are weird, but they should not cause errors"""

def compact_arpabet(xml_string: str) -> str:
etree_root = etree.fromstring(xml_string)
etree_root = parse_xml(xml_string)
arpabet = etree_root[0].attrib["ARPABET"]
return re.sub(r"\s+", " ", arpabet)

Expand Down Expand Up @@ -438,22 +439,22 @@ def test_convert_xml_subwords(self):

def test_convert_xml_invalid(self):
"""test readalongs.text.convert_xml.convert_xml() with invalid input"""
xml = etree.fromstring('<s><w ARPABET="V AA L IY D">valid</w></s>')
xml = parse_xml('<s><w ARPABET="V AA L IY D">valid</w></s>')
c_xml, valid = convert_xml(xml)
self.assertEqual(
etree.tounicode(c_xml), '<s><w ARPABET="V AA L IY D">valid</w></s>'
)
self.assertTrue(valid, "convert_xml with valid pre-g2p'd text")

xml = etree.fromstring('<s><w ARPABET="invalid">invalid</w></s>')
xml = parse_xml('<s><w ARPABET="invalid">invalid</w></s>')
c_xml, valid = convert_xml(xml)
self.assertEqual(
etree.tounicode(c_xml), '<s><w ARPABET="invalid">invalid</w></s>'
)
self.assertFalse(valid, "convert_xml with invalid pre-g2p'd text")

def test_invalid_langs_in_xml(self):
xml = etree.fromstring(
xml = parse_xml(
"""
<s>
<w lang="eng" fallback-langs="foo">français falls back to invalid foo</w>
Expand Down
25 changes: 18 additions & 7 deletions test/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
load_xml,
load_xml_zip,
parse_time,
parse_xml,
save_txt,
save_xml,
)
Expand Down Expand Up @@ -101,7 +102,7 @@ def test_get_attrib_recursive(self):
</text>
</read-along>
"""
xml = etree.fromstring(raw_xml)
xml = parse_xml(raw_xml)
for i, s, lang in zip(
itertools.count(),
xml.xpath(".//s"),
Expand Down Expand Up @@ -176,23 +177,23 @@ def test_joiner_callback(self):

def test_get_word_text(self):
self.assertEqual(
get_word_text(etree.fromstring("<w>basicword</w>")),
get_word_text(parse_xml("<w>basicword</w>")),
"basicword",
)
self.assertEqual(
get_word_text(etree.fromstring("<w><subw>subwcase</subw></w>")),
get_word_text(parse_xml("<w><subw>subwcase</subw></w>")),
"subwcase",
)
self.assertEqual(
get_word_text(etree.fromstring("<w><syl>syl1</syl><syl>syl2</syl></w>")),
get_word_text(parse_xml("<w><syl>syl1</syl><syl>syl2</syl></w>")),
"syl1syl2",
)
self.assertEqual(
get_word_text(etree.fromstring("<w>text<subw>sub</subw>tail</w>")),
get_word_text(parse_xml("<w>text<subw>sub</subw>tail</w>")),
"textsubtail",
)
self.assertEqual(
get_word_text(etree.fromstring("<w><a>a<b>b</b>c</a>d</w>")),
get_word_text(parse_xml("<w><a>a<b>b</b>c</a>d</w>")),
"abcd",
)

Expand Down Expand Up @@ -242,9 +243,19 @@ def test_load_xml_errors(self):
# b'<explode>AAAAAAAAAAAAAAAA</explode>'
# See https://en.wikipedia.org/wiki/Billion_laughs_attack

def test_parse_xml(self):
    """parse_xml() gives the same tree for str and bytes input, and raises on malformed XML"""
    xml_text = '<foo attrib="value">text</foo>'
    from_str = parse_xml(xml_text)
    from_bytes = parse_xml(bytes(xml_text, encoding="latin1"))
    self.assertEqual(etree.tostring(from_str), etree.tostring(from_bytes))

    # An unterminated attribute must surface as a parse error.
    malformed_xml_text = "<foo attrib="
    with self.assertRaises(etree.ParseError):
        parse_xml(malformed_xml_text)

def test_save_xml(self):
xml_text = '<foo attrib="value">text</foo>'
xml = etree.fromstring(xml_text)
xml = parse_xml(xml_text)
filename = self.tempdir / "foo.readalong"
save_xml(filename, xml)
loaded_xml = load_xml(filename)
Expand Down
6 changes: 2 additions & 4 deletions test/test_silence.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from unittest import main

from basic_test_case import BasicTestCase
from lxml import etree
from pydub import AudioSegment

from readalongs.cli import align
from readalongs.text.util import load_xml


class TestSilence(BasicTestCase):
Expand Down Expand Up @@ -41,9 +41,7 @@ def test_basic_silence_insertion(self):
self.assertEqual(results.exit_code, 0)
self.assertTrue(os.path.exists(os.path.join(output, "silence.m4a")))
# test silence spans in output xml
with open(os.path.join(output, "silence.readalong"), "rb") as f:
xml_bytes = f.read()
root = etree.fromstring(xml_bytes)
root = load_xml(os.path.join(output, "silence.readalong"))
silence_spans = root.xpath("//silence")
self.assertEqual(len(silence_spans), 3)
# test audio has correct amount of silence added
Expand Down
9 changes: 5 additions & 4 deletions test/test_tokenize_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from lxml import etree

from readalongs.text import tokenize_xml
from readalongs.text.util import parse_xml


class TestTokenizer(TestCase):
Expand All @@ -21,7 +22,7 @@ def test_simple(self):
ref = """<document>
<s xml:lang="atj"><w>Kwei</w>! <w>Tan</w> <w>e</w> <w>ici</w> <w>matisihin</w>?</s>
</document>"""
xml = etree.fromstring(txt)
xml = parse_xml(txt)
tokenized = tokenize_xml.tokenize_xml(xml)
# print(etree.tounicode(tokenized))
self.assertEqual(etree.tounicode(tokenized), ref)
Expand All @@ -37,7 +38,7 @@ def test_mixed_lang(self):
<s xml:lang="atj"><w>Kwei</w>! <w>Tan</w> <w>e</w> <w>ici</w> <w>matisihin</w>?</s>
<s xml:lang="fra"><w>Bonjour</w>! <w>Comment</w> <w>ça</w> <w>va</w>?</s>
</document>"""
xml = etree.fromstring(txt)
xml = parse_xml(txt)
tokenized = tokenize_xml.tokenize_xml(xml)
# print(etree.tounicode(tokenized))
self.assertEqual(etree.tounicode(tokenized), ref)
Expand All @@ -53,7 +54,7 @@ def test_mixed_words(self):
<s xml:lang="atj">Kwei! (<w xml:lang="fra">Bonjour</w>!)</s>
<s xml:lang="atj">Tan e ici matisihin?</s>
</document>"""
xml = etree.fromstring(txt)
xml = parse_xml(txt)
tokenized = tokenize_xml.tokenize_xml(xml)
# print(etree.tounicode(tokenized))
self.assertEqual(etree.tounicode(tokenized), ref)
Expand All @@ -71,7 +72,7 @@ def test_comments(self):
<!--<s>comments</s> <w>should</w> <p>be ignored</p>-->
<s xml:lang="atj"><w>Tan</w> <w>e</w> <w>ici</w> <w>matisihin</w>?</s>
</document>"""
xml = etree.fromstring(txt)
xml = parse_xml(txt)
tokenized = tokenize_xml.tokenize_xml(xml)
# print(etree.tounicode(tokenized))
self.assertEqual(etree.tounicode(tokenized), ref)
Expand Down
6 changes: 2 additions & 4 deletions test/test_web_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from unittest import main

from basic_test_case import BasicTestCase
from lxml import etree

from readalongs.log import LOGGER
from readalongs.text.add_ids_to_xml import add_ids
from readalongs.text.convert_xml import convert_xml
from readalongs.text.tokenize_xml import tokenize_xml
from readalongs.text.util import parse_xml
from readalongs.util import get_langs


Expand Down Expand Up @@ -86,9 +86,7 @@ def test_invalid_ras(self):

def test_create_grammar(self):
# Test the create grammar function
parsed = etree.fromstring(
bytes(self.slurp_data_file("ej-fra.readalong"), encoding="utf8")
)
parsed = parse_xml(self.slurp_data_file("ej-fra.readalong"))
tokenized = tokenize_xml(parsed)
ids_added = add_ids(tokenized)
g2ped, valid = convert_xml(ids_added)
Expand Down

0 comments on commit 94d49ae

Please sign in to comment.