Skip to content

Commit

Permalink
fix: seeing match_pattern or intermediate_form is an error
Browse files Browse the repository at this point in the history
  • Loading branch information
dhdaines authored and joanise committed Sep 12, 2024
1 parent b786567 commit 302a0b8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 8 deletions.
13 changes: 12 additions & 1 deletion g2p/mappings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def model_post_init(self, *_args, **_kwargs) -> None:
if self.rules_path is not None and not self.rules:
# make sure self.rules is always a List[Rule] like we say it is!
self.rules = [Rule(**obj) for obj in load_from_file(self.rules_path)]
# This is required so that we don't keep escaping special characters for example
# Process the rules, keeping only non-empty ones, and
# expanding abbreviations. This is also required so that
# we don't keep escaping special characters for example
self.rules = self.process_model_specs()
elif self.type == MAPPING_TYPE.lexicon:
# load alignments from path
Expand Down Expand Up @@ -190,6 +192,15 @@ def apply_to_attributes(rule: Rule, func: Callable, *attrs):

non_empty_mappings: List[Rule] = []
for i, rule in enumerate(self.rules):
# We explicitly exclude match_pattern and
# intermediate_form when saving rules. Seeing either of
# them is a programmer error.
assert (
rule.match_pattern is None
), "Either match_pattern was specified explicitly or process_model_specs was called more than once"
assert (
rule.intermediate_form is None
), "Either intermediate_form was specified explicitly or process_model_specs was called more than once"
# Expand Abbreviations
if self.abbreviations:
apply_to_attributes(
Expand Down
17 changes: 10 additions & 7 deletions g2p/mappings/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,17 @@ class Rule(BaseModel):
prevent_feeding: bool = False
"""Whether to prevent the rule from feeding other rules"""

match_pattern: Optional[Pattern] = None
"""An automatically generated match_pattern based on the rule_input, context_before and context_after"""

intermediate_form: Optional[str] = None
"""An optional intermediate form. Should be automatically generated only when prevent_feeding is True"""
match_pattern: Optional[Pattern] = Field(
None,
exclude=True,
description="""An automatically generated match_pattern based on the rule_input, context_before and context_after""",
)

intermediate_form: Optional[str] = Field(
None,
exclude=True,
description="""An intermediate form, automatically generated only when prevent_feeding is True""",
)
comment: Optional[str] = None
"""An optional comment about the rule."""

Expand All @@ -69,8 +74,6 @@ def export_to_dict(
self, exclude=None, exclude_none=True, exclude_defaults=True, by_alias=True
):
"""All the options for exporting are tedious to keep track of so this is a helper function"""
if exclude is None:
exclude = {"match_pattern": True, "intermediate_form": True}
return self.model_dump(
exclude=exclude,
exclude_none=exclude_none,
Expand Down
20 changes: 20 additions & 0 deletions g2p/tests/test_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io
import json
import os
import re
import unicodedata as ud
from contextlib import redirect_stderr
from tempfile import NamedTemporaryFile
Expand Down Expand Up @@ -406,6 +407,25 @@ def test_g2p_studio_csv(self):
)
os.unlink(tf.name)

def test_no_reprocess(self):
    """Ensure that reprocessing a mapping, or presetting generated fields, is an error.

    ``process_model_specs`` asserts that every rule's ``match_pattern`` and
    ``intermediate_form`` are still ``None`` when it runs.  This test covers:

    * calling ``process_model_specs`` again on a mapping fixture
      (presumably already processed when it was constructed), and
    * constructing a ``Mapping`` whose rules explicitly supply either of
      the two generated fields.
    """
    # A second processing pass must trip the assertion directly.
    with self.assertRaises(AssertionError):
        self.test_mapping_norm.process_model_specs()
    # match_pattern is declared Optional[Pattern], so a compiled regex is a
    # well-typed value: the failure has to come from the "must be None" check.
    with self.assertRaises(ValidationError):
        _ = Mapping(
            rules=[{"in": "a", "out": "b", "match_pattern": re.compile("XOR OTA")}]
        )
    # intermediate_form is declared Optional[str], so it must be given a str
    # here.  A compiled regex would be rejected by plain type validation
    # before the "must be None" check is ever reached, and the test would
    # pass without exercising that check.
    with self.assertRaises(ValidationError):
        _ = Mapping(
            rules=[
                {
                    "in": "a",
                    "out": "b",
                    "intermediate_form": "HACKEM MUCHE",
                }
            ]
        )


# Run this module's tests when it is executed directly as a script
# (main is presumably unittest's entry point — confirm against the imports).
if __name__ == "__main__":
main()

0 comments on commit 302a0b8

Please sign in to comment.