diff --git a/g2p/mappings/langs/kwk/config-g2p.yaml b/g2p/mappings/langs/kwk/config-g2p.yaml index 050995b2..59b3bfaa 100644 --- a/g2p/mappings/langs/kwk/config-g2p.yaml +++ b/g2p/mappings/langs/kwk/config-g2p.yaml @@ -56,6 +56,8 @@ mappings: out_lang: kwk-umista rule_ordering: apply-longest-first prevent_feeding: true + case_sensitive: false + preserve_case: true authors: - Fineen Davis - Olivia Chen diff --git a/g2p/mappings/langs/langs.json.gz b/g2p/mappings/langs/langs.json.gz index d6c8b09d..1c41c13f 100644 Binary files a/g2p/mappings/langs/langs.json.gz and b/g2p/mappings/langs/langs.json.gz differ diff --git a/g2p/mappings/utils.py b/g2p/mappings/utils.py index ab006fbf..bc938155 100644 --- a/g2p/mappings/utils.py +++ b/g2p/mappings/utils.py @@ -661,7 +661,13 @@ class _MappingModelDefinition(BaseModel): """Deprecated: Please use rule_ordering='as_written' """ case_sensitive: bool = True - """Lower all rules and conversion input""" + """When false, lowercase all rules and conversion input""" + + case_equivalencies: dict = {} + """List of case equivalencies for preserve_case that are not already in the Unicode standard""" + + preserve_case: bool = False + """Preserve source case in output (requires case_sensitive=False)""" escape_special: bool = False """Escape special characters in rules""" @@ -725,7 +731,7 @@ def check_mapping_types(self) -> "_MappingModelDefinition": and not self.rules and self.rules_path is None ): - LOGGER.warn( + LOGGER.warning( exceptions.MalformedMapping( "You have to either specify some rules or a path to a file containing rules." ) @@ -755,6 +761,27 @@ def validate_norm_form(cls, v): v = "none" return v + @field_validator("case_equivalencies", mode="before") + @classmethod + def validate_case_equivalencies(cls, v): + if not v or v is None: + v = {} + for lower_case, upper_case in v.items(): + if len(lower_case) != len(upper_case): + raise exceptions.MalformedMapping( + f"Sorry, the case equivalency between {lower_case} and {upper_case} is not valid because it is not the same length, please write rules such that any case equivalent is of equal length." + ) + return v + + @model_validator(mode="after") + def validate_preserve_case(self): + """preserve_case=True requires case_sensitive=False""" + if self.preserve_case and self.case_sensitive: + raise exceptions.MalformedMapping( + "Sorry, preserve_case=True requires case_sensitive=False." + ) + return self + # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. @validator("rules_path", "abbreviations_path", "alignments_path", pre=True) diff --git a/g2p/static/custom.js b/g2p/static/custom.js index c22addf8..ba0de2c3 100644 --- a/g2p/static/custom.js +++ b/g2p/static/custom.js @@ -53,6 +53,7 @@ function createSettings(index, data) { let include = 'checked'; let rule_ordering = ''; let case_sensitive = ''; + let preserve_case = ''; let escape_special = ''; let reverse = ''; let active = ''; @@ -76,6 +77,9 @@ function createSettings(index, data) { if (data['case_sensitive']) { case_sensitive = 'checked' } + if (data['preserve_case']) { + preserve_case = 'checked' + } if (data['escape_special']) { escape_special = 'checked' } @@ -115,6 +119,10 @@ function createSettings(index, data) { value='case_sensitive'> +
+ + +
@@ -161,9 +169,20 @@ function createSettings(index, data) { document.getElementById(`case_sensitive-${index}`).addEventListener('click', function(event) { const case_sensitive = event.target.checked + if (case_sensitive) { + document.getElementById(`preserve_case-${index}`).checked = false + } setKwargs(index, { case_sensitive }) }) + document.getElementById(`preserve_case-${index}`).addEventListener('click', function(event) { + const preserve_case = event.target.checked + if (preserve_case) { + document.getElementById(`case_sensitive-${index}`).checked = false + } + setKwargs(index, { preserve_case }) + }) + document.getElementById(`escape_special-${index}`).addEventListener('click', function(event) { const escape_special = event.target.checked setKwargs(index, { escape_special }) @@ -276,7 +295,7 @@ function createTable(index, data) { var size = 10; var dataObject = [] var varsObject = [] -var settingsObject = { 'include': true, 'rule_ordering': "as-written", 'case_sensitive': true, 'escape_special': false, 'reverse': false } +var settingsObject = { 'include': true, 'rule_ordering': "as-written", 'case_sensitive': true, 'preserve_case': false, 'escape_special': false, 'reverse': false } for (var j = 0; j < size; j++) { dataObject.push({ "in": '', @@ -339,6 +358,7 @@ getIncludedMappings = function() { var getKwargs = function(index) { const rule_ordering = $(`#rule_ordering-${index}`).val() const case_sensitive = document.getElementById(`case_sensitive-${index}`).checked + const preserve_case = document.getElementById(`preserve_case-${index}`).checked const escape_special = document.getElementById(`escape_special-${index}`).checked const reverse = document.getElementById(`reverse-${index}`).checked const include = document.getElementById(`include-${index}`).checked @@ -355,6 +375,7 @@ var getKwargs = function(index) { return { rule_ordering, case_sensitive, + preserve_case, escape_special, reverse, include, @@ -372,6 +393,9 @@ var setKwargs = function(index, kwargs) { if ('case_sensitive' in kwargs) { document.getElementById(`case_sensitive-${index}`).checked = kwargs['case_sensitive'] } + if ('preserve_case' in kwargs) { + document.getElementById(`preserve_case-${index}`).checked = kwargs['preserve_case'] + } if ('escape_special' in kwargs) { document.getElementById(`escape_special-${index}`).checked = kwargs['escape_special'] } diff --git a/g2p/tests/public/data/kwk.psv b/g2p/tests/public/data/kwk.psv index 5edd2894..6a87ca6b 100644 --- a/g2p/tests/public/data/kwk.psv +++ b/g2p/tests/public/data/kwk.psv @@ -11,10 +11,17 @@ kwk-boas|kwk-umista|g·āyaxalisē|gayax̱alisi kwk-boas|kwk-umista|x\u0323wēlaxᵋw\u1D07sdes|xwilax̱ʼwa̱sdis kwk-boas|kwk-umista|ăwŭnagwīsē ʟ̣ēg̣adēs|a̱wunagwisi dłig̱adis kwk-boas|kwk-umista|yîx ōmpas ōᵋmaxt!ālaʟēᵋyēxa|yix̱ umpas uʼmax̱t̓alatłiʼyix̱a -kwk-boas|kwk-umista|tsāg̣ᴇmas g·ōkwas Ts!ᴇxᵋēdē|tsag̱a̱mas gukwas Ts!a̱x̱ʼidi +kwk-boas|kwk-umista|tsāg̣ᴇmas g·ōkwas Ts!ᴇxᵋēdē|tsag̱a̱mas gukwas Tʼsa̱x̱ʼidi kwk-boas|kwk-umista|lāx̣wa ᵋnāx̣wax|laxwa ʼnaxwax̱ kwk-boas|kwk-umista|g·ig̣ŭmaᵋyasa ᵋnᴇᵋmēmotasa|gig̱umaʼyasa ʼna̱ʼmimutasa kwk-boas|kwk-umista|yîxs sēsᴇyūʟaēs|yix̱s sisa̱yutłaʼis kwk-napa|kwk-ipa|gam̓ən|ɡaʔmən kwk-napa|kwk-ipa|c̓ay̓ux̌ʷ|tʼsaʔyuχʷ kwk-napa|kwk-ipa|wəq̓ʷɛʔs|wəqʼʷɛʔs + +# Artificial data to test capitalization of kwk BOAS->Umista +kwk-boas|kwk-umista|TAtap!Aʟa|TAtap̓Atła +# A real word, capitalized +kwk-boas|kwk-umista|G·āyaxalisē|Gayax̱alisi +# This case not activated because it doesn't actually currently work +#kwk-boas|kwk-umista|TᴇAtᴇapʟ!Aʟa|TA̱ʼAta̱ʼaptʼłAtła diff --git a/g2p/tests/test_cli.py b/g2p/tests/test_cli.py index e369dd50..8656e601 100755 --- a/g2p/tests/test_cli.py +++ b/g2p/tests/test_cli.py @@ -143,6 +143,7 @@ def test_convert(self): out_lang, word_to_convert, tok_option, + reference_string, ) error_count += 1 @@ -152,6 +153,7 @@ def test_convert(self): out_lang, word_to_convert, tok_option, + reference_string, ) = first_failed_test output_string = self.runner.invoke( convert, @@ -160,8 +162,7 @@ def test_convert(self): self.assertEqual( output_string, reference_string.strip(), - f"{in_lang}->{out_lang} mapping error " - "for '{word_to_convert}'.\n" + f"{in_lang}->{out_lang} mapping error for '{word_to_convert}'.\n" "Look for warnings in the log for any more mapping errors", ) diff --git a/g2p/tests/test_mappings.py b/g2p/tests/test_mappings.py index 63aaedf0..dc91f992 100755 --- a/g2p/tests/test_mappings.py +++ b/g2p/tests/test_mappings.py @@ -194,6 +194,10 @@ def test_case_sensitive(self): self.assertEqual(transducer_case_sensitive("a").output_string, "a") self.assertEqual(transducer("A").output_string, "b") + def test_case_equivalencies(self): + with self.assertRaises(exceptions.MalformedMapping): + Mapping(rules=[{"in": "a", "out": "b"}], case_equivalencies={"a": "AA"}) + def test_escape_special(self): mapping = Mapping(rules=[{"in": r"\d", "out": "digit"}]) mapping_escaped = Mapping( diff --git a/g2p/tests/test_transducer.py b/g2p/tests/test_transducer.py index 73b61904..37a012bc 100755 --- a/g2p/tests/test_transducer.py +++ b/g2p/tests/test_transducer.py @@ -3,6 +3,7 @@ import os from unittest import TestCase, main +from g2p.exceptions import MalformedMapping from g2p.mappings import Mapping from g2p.tests.public import PUBLIC_DIR from g2p.transducer import CompositeTransducer, Transducer, normalize_edges @@ -218,6 +219,45 @@ def test_deletion(self): self.assertEqual(self.test_deletion_transducer_csv("a").output_string, "") self.assertEqual(self.test_deletion_transducer_json("a").output_string, "") + def test_case_preservation(self): + mapping = Mapping( + rules=[ + {"in": "'a", "out": "b"}, + {"in": "e\u0301", "out": "f"}, + {"in": "tl", "out": "λ"}, + ], + case_sensitive=False, + preserve_case=True, + norm_form="NFC", + case_equivalencies={"λ": "\u2144"}, + ) + transducer = Transducer(mapping) + self.assertEqual(transducer("'a").output_string, "b") + self.assertEqual(transducer("'A").output_string, "B") + self.assertEqual(transducer("E\u0301").output_string, "F") + self.assertEqual(transducer("e\u0301").output_string, "f") + # Test what happens in Heiltsuk. \u03BB (λ) should be capitalized as \u2144 (⅄) + self.assertEqual(transducer("TLaba").output_string, "\u2144aba") + self.assertEqual(transducer("tlaba").output_string, "λaba") + # I guess it's arguable what should happen here, but I'll just change case if any of the characters are differently cased + self.assertEqual(transducer("Tlaba").output_string, "\u2144aba") + # case equivalencies that are not the same length cause indexing errors in the current implementation + with self.assertRaises(MalformedMapping): + Mapping( + rules=[ + {"in": "'a", "out": "b"}, + {"in": "e\u0301", "out": "f"}, + {"in": "tl", "out": "λ"}, + ], + case_sensitive=False, + preserve_case=True, + norm_form="NFC", + case_equivalencies={"λ": "\u2144\u2144\u2144"}, + ) + + with self.assertRaises(MalformedMapping): + _ = Mapping(rules=[], case_sensitive=True, preserve_case=True) + def test_normalize_edges(self): # Remove non-deletion edges with the same index as deletions bad_edges = [ diff --git a/g2p/transducer/__init__.py b/g2p/transducer/__init__.py index e3ec7923..45bb99ef 100644 --- a/g2p/transducer/__init__.py +++ b/g2p/transducer/__init__.py @@ -420,6 +420,7 @@ class Transducer: def __init__(self, mapping: Mapping): self.mapping = mapping self.case_sensitive = mapping.case_sensitive + self.preserve_case = mapping.preserve_case self.norm_form = mapping.norm_form self.out_delimiter = mapping.out_delimiter self._index_match_pattern = re.compile(r"(?<={)\d+(?=})") @@ -428,7 +429,7 @@ def __init__(self, mapping: Mapping): def __repr__(self): return f"{self.__class__} between {self.mapping.in_lang} and {self.mapping.out_lang}" - def __call__(self, to_convert: str, index: bool = False, debugger: bool = False): + def __call__(self, to_convert: str): """The basic method to transduce an input. A proxy for self.apply_rules. Args: @@ -439,7 +440,11 @@ def __call__(self, to_convert: str, index: bool = False, debugger: bool = False) and output characters and their corresponding edges representing the indices of the transformation. """ - return self.apply_rules(to_convert) + tg = self.apply_rules(to_convert) + if self.preserve_case: + return preserve_case(tg, self.mapping.case_equivalencies) + else: + return tg @staticmethod def _pua_to_index(string: str) -> int: @@ -1257,3 +1262,53 @@ def check(self, tg: TransductionGraph, shallow=False, display_warnings=False): else: return False return result + + +def preserve_case( + tg: TransductionGraph, case_equivalencies: Dict[str, str] = None +) -> TransductionGraph: + if case_equivalencies is None: + case_equivalencies = {} + reverse_case_equivalencies = {v: k for k, v in case_equivalencies.items()} + all_lower_case_equivalencies = case_equivalencies.keys() + all_upper_case_equivalencies = case_equivalencies.values() + new_string = "" + for item in tg.substring_alignments(): + in_sub = item[0] + out_sub = item[1] + any_in_upper = any(x.isupper() for x in in_sub) + any_in_lower = any(x.islower() for x in in_sub) + any_out_upper = any(x.isupper() for x in out_sub) + any_out_lower = any(x.islower() for x in out_sub) + # continue if character is un-caseable + if ( + out_sub not in case_equivalencies + and not any_out_upper + and not any_out_lower + ): + new_string += out_sub + continue + # lower case using case equivalencies if they exist + if ( + any_in_lower or in_sub in all_lower_case_equivalencies + ) and out_sub in all_upper_case_equivalencies: + new_string += reverse_case_equivalencies[out_sub] + continue + # upper case using case equivalencies if they exist + elif ( + any_in_upper or in_sub in all_upper_case_equivalencies + ) and out_sub in all_lower_case_equivalencies: + new_string += case_equivalencies[out_sub] + continue + # change to upper if required + if any_in_upper and any_out_lower: + new_string += out_sub.upper() + continue + # change to lower if required + if any_in_lower and any_out_upper: + new_string += out_sub.lower() + continue + # just in case, append the out_sub + new_string += out_sub + tg.output_string = new_string + return tg