Skip to content

Commit

Permalink
Merge pull request #274 from roedoejet/dev.casing
Browse files Browse the repository at this point in the history
merge preserve_case functionality into main
  • Loading branch information
joanise authored Nov 15, 2023
2 parents d1aa6dd + c31c66b commit daf33b8
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 8 deletions.
2 changes: 2 additions & 0 deletions g2p/mappings/langs/kwk/config-g2p.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ mappings:
out_lang: kwk-umista
rule_ordering: apply-longest-first
prevent_feeding: true
case_sensitive: false
preserve_case: true
authors:
- Fineen Davis
- Olivia Chen
Expand Down
Binary file modified g2p/mappings/langs/langs.json.gz
Binary file not shown.
31 changes: 29 additions & 2 deletions g2p/mappings/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,13 @@ class _MappingModelDefinition(BaseModel):
"""Deprecated: Please use rule_ordering='as_written' """

case_sensitive: bool = True
"""Lower all rules and conversion input"""
"""When false, lowercase all rules and conversion input"""

case_equivalencies: dict = {}
"""List of case equivalencies for preserve_case that are not already in the Unicode standard"""

preserve_case: bool = False
"""Preserve source case in output (requires case_sensitive=False)"""

escape_special: bool = False
"""Escape special characters in rules"""
Expand Down Expand Up @@ -725,7 +731,7 @@ def check_mapping_types(self) -> "_MappingModelDefinition":
and not self.rules
and self.rules_path is None
):
LOGGER.warn(
LOGGER.warning(
exceptions.MalformedMapping(
"You have to either specify some rules or a path to a file containing rules."
)
Expand Down Expand Up @@ -755,6 +761,27 @@ def validate_norm_form(cls, v):
v = "none"
return v

@field_validator("case_equivalencies", mode="before")
@classmethod
def validate_case_equivalencies(cls, v):
    """Normalise and sanity-check the ``case_equivalencies`` field.

    Runs in "before" mode, i.e. on the raw value handed to the model.
    Any falsy value (including ``None``) is replaced by an empty dict.

    Raises:
        exceptions.MalformedMapping: if a key and its value differ in length.
    """
    v = v or {}
    # Locate the first pair whose two sides are of unequal length, if any.
    offending = next(
        ((lower, upper) for lower, upper in v.items() if len(lower) != len(upper)),
        None,
    )
    if offending is not None:
        lower_case, upper_case = offending
        raise exceptions.MalformedMapping(
            f"Sorry, the case equivalency between {lower_case} and {upper_case} is not valid because it is not the same length, please write rules such that any case equivalent is of equal length."
        )
    return v

@model_validator(mode="after")
def validate_preserve_case(self):
    """Reject the combination preserve_case=True with case_sensitive=True.

    Returns:
        The model instance itself when the two settings are compatible.

    Raises:
        exceptions.MalformedMapping: if both flags are truthy.
    """
    settings_compatible = not (self.preserve_case and self.case_sensitive)
    if settings_compatible:
        return self
    raise exceptions.MalformedMapping(
        "Sorry, preserve_case=True requires case_sensitive=False."
    )

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("rules_path", "abbreviations_path", "alignments_path", pre=True)
Expand Down
26 changes: 25 additions & 1 deletion g2p/static/custom.js
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ function createSettings(index, data) {
let include = 'checked';
let rule_ordering = '';
let case_sensitive = '';
let preserve_case = '';
let escape_special = '';
let reverse = '';
let active = '';
Expand All @@ -76,6 +77,9 @@ function createSettings(index, data) {
if (data['case_sensitive']) {
case_sensitive = 'checked'
}
if (data['preserve_case']) {
preserve_case = 'checked'
}
if (data['escape_special']) {
escape_special = 'checked'
}
Expand Down Expand Up @@ -115,6 +119,10 @@ function createSettings(index, data) {
value='case_sensitive'>
<label for='case_sensitive'>Rules are case sensitive</label>
</div>
<div>
<input ${preserve_case} id='preserve_case-${index}' type='checkbox' name='preserve_case' value='preserve_case'>
<label for='preserve_case'>Preserve input case in output</label>
</div>
<div>
<input ${escape_special} id='escape_special-${index}' type='checkbox' name='escape_special' value='escape_special'>
<label for='escape_special'>Escape special characters</label>
Expand Down Expand Up @@ -161,9 +169,20 @@ function createSettings(index, data) {

document.getElementById(`case_sensitive-${index}`).addEventListener('click', function(event) {
const case_sensitive = event.target.checked
if (case_sensitive) {
document.getElementById(`preserve_case-${index}`).checked = false
}
setKwargs(index, { case_sensitive })
})

document.getElementById(`preserve_case-${index}`).addEventListener('click', function(event) {
const preserve_case = event.target.checked
if (preserve_case) {
document.getElementById(`case_sensitive-${index}`).checked = false
}
setKwargs(index, { preserve_case })
})

document.getElementById(`escape_special-${index}`).addEventListener('click', function(event) {
const escape_special = event.target.checked
setKwargs(index, { escape_special })
Expand Down Expand Up @@ -276,7 +295,7 @@ function createTable(index, data) {
var size = 10;
var dataObject = []
var varsObject = []
var settingsObject = { 'include': true, 'rule_ordering': "as-written", 'case_sensitive': true, 'escape_special': false, 'reverse': false }
var settingsObject = { 'include': true, 'rule_ordering': "as-written", 'case_sensitive': true, 'preserve_case': false, 'escape_special': false, 'reverse': false }
for (var j = 0; j < size; j++) {
dataObject.push({
"in": '',
Expand Down Expand Up @@ -339,6 +358,7 @@ getIncludedMappings = function() {
var getKwargs = function(index) {
const rule_ordering = $(`#rule_ordering-${index}`).val()
const case_sensitive = document.getElementById(`case_sensitive-${index}`).checked
const preserve_case = document.getElementById(`preserve_case-${index}`).checked
const escape_special = document.getElementById(`escape_special-${index}`).checked
const reverse = document.getElementById(`reverse-${index}`).checked
const include = document.getElementById(`include-${index}`).checked
Expand All @@ -355,6 +375,7 @@ var getKwargs = function(index) {
return {
rule_ordering,
case_sensitive,
preserve_case,
escape_special,
reverse,
include,
Expand All @@ -372,6 +393,9 @@ var setKwargs = function(index, kwargs) {
if ('case_sensitive' in kwargs) {
document.getElementById(`case_sensitive-${index}`).checked = kwargs['case_sensitive']
}
if ('preserve_case' in kwargs) {
document.getElementById(`preserve_case-${index}`).checked = kwargs['preserve_case']
}
if ('escape_special' in kwargs) {
document.getElementById(`escape_special-${index}`).checked = kwargs['escape_special']
}
Expand Down
9 changes: 8 additions & 1 deletion g2p/tests/public/data/kwk.psv
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@ kwk-boas|kwk-umista|g·āyaxalisē|gayax̱alisi
kwk-boas|kwk-umista|x\u0323wēlaxᵋw\u1D07sdes|xwilax̱ʼwa̱sdis
kwk-boas|kwk-umista|ăwŭnagwīsē ʟ̣ēg̣adēs|a̱wunagwisi dłig̱adis
kwk-boas|kwk-umista|yîx ōmpas ōᵋmaxt!ālaʟēᵋyēxa|yix̱ umpas uʼmax̱t̓alatłiʼyix̱a
kwk-boas|kwk-umista|tsāg̣ᴇmas g·ōkwas Ts!ᴇxᵋēdē|tsag̱a̱mas gukwas Ts!a̱x̱ʼidi
kwk-boas|kwk-umista|tsāg̣ᴇmas g·ōkwas Ts!ᴇxᵋēdē|tsag̱a̱mas gukwas Tʼsa̱x̱ʼidi
kwk-boas|kwk-umista|lāx̣wa ᵋnāx̣wax|laxwa ʼnaxwax̱
kwk-boas|kwk-umista|g·ig̣ŭmaᵋyasa ᵋnᴇᵋmēmotasa|gig̱umaʼyasa ʼna̱ʼmimutasa
kwk-boas|kwk-umista|yîxs sēsᴇyūʟaēs|yix̱s sisa̱yutłaʼis
kwk-napa|kwk-ipa|gam̓ən|ɡaʔmən
kwk-napa|kwk-ipa|c̓ay̓ux̌ʷ|tʼsaʔyuχʷ
kwk-napa|kwk-ipa|wəq̓ʷɛʔs|wəqʼʷɛʔs

# Artificial data to test capitalization of kwk BOAS->Umista
kwk-boas|kwk-umista|TAtap!Aʟa|TAtap̓Atła
# A real word, capitalized
kwk-boas|kwk-umista|G·āyaxalisē|Gayax̱alisi
# This case is not activated because it does not currently work
#kwk-boas|kwk-umista|TᴇAtᴇapʟ!Aʟa|TA̱ʼAta̱ʼaptʼłAtła
5 changes: 3 additions & 2 deletions g2p/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def test_convert(self):
out_lang,
word_to_convert,
tok_option,
reference_string,
)
error_count += 1

Expand All @@ -152,6 +153,7 @@ def test_convert(self):
out_lang,
word_to_convert,
tok_option,
reference_string,
) = first_failed_test
output_string = self.runner.invoke(
convert,
Expand All @@ -160,8 +162,7 @@ def test_convert(self):
self.assertEqual(
output_string,
reference_string.strip(),
f"{in_lang}->{out_lang} mapping error "
"for '{word_to_convert}'.\n"
f"{in_lang}->{out_lang} mapping error for '{word_to_convert}'.\n"
"Look for warnings in the log for any more mapping errors",
)

Expand Down
4 changes: 4 additions & 0 deletions g2p/tests/test_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ def test_case_sensitive(self):
self.assertEqual(transducer_case_sensitive("a").output_string, "a")
self.assertEqual(transducer("A").output_string, "b")

def test_case_equivalencies(self):
    """A case equivalency whose two sides differ in length must be rejected."""
    rules = [{"in": "a", "out": "b"}]
    unequal_length = {"a": "AA"}
    with self.assertRaises(exceptions.MalformedMapping):
        Mapping(rules=rules, case_equivalencies=unequal_length)

def test_escape_special(self):
mapping = Mapping(rules=[{"in": r"\d", "out": "digit"}])
mapping_escaped = Mapping(
Expand Down
40 changes: 40 additions & 0 deletions g2p/tests/test_transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from unittest import TestCase, main

from g2p.exceptions import MalformedMapping
from g2p.mappings import Mapping
from g2p.tests.public import PUBLIC_DIR
from g2p.transducer import CompositeTransducer, Transducer, normalize_edges
Expand Down Expand Up @@ -218,6 +219,45 @@ def test_deletion(self):
self.assertEqual(self.test_deletion_transducer_csv("a").output_string, "")
self.assertEqual(self.test_deletion_transducer_json("a").output_string, "")

def test_case_preservation(self):
    """Exercise preserve_case=True, including custom case equivalencies."""

    def build_rules():
        # A fresh list of fresh dicts per Mapping, so nothing is shared
        # between the two Mapping constructions below.
        return [
            {"in": "'a", "out": "b"},
            {"in": "e\u0301", "out": "f"},
            {"in": "tl", "out": "λ"},
        ]

    transducer = Transducer(
        Mapping(
            rules=build_rules(),
            case_sensitive=False,
            preserve_case=True,
            norm_form="NFC",
            case_equivalencies={"λ": "\u2144"},
        )
    )
    expectations = [
        ("'a", "b"),
        ("'A", "B"),
        ("E\u0301", "F"),
        ("e\u0301", "f"),
        # Heiltsuk: \u03BB (λ) should be capitalized as \u2144 (⅄)
        ("TLaba", "\u2144aba"),
        ("tlaba", "λaba"),
        # Arguable, but the case is changed as soon as any input
        # character is differently cased
        ("Tlaba", "\u2144aba"),
    ]
    for source, expected in expectations:
        self.assertEqual(transducer(source).output_string, expected)

    # Case equivalencies of unequal length cause indexing errors in the
    # current implementation, so they must be refused up front.
    with self.assertRaises(MalformedMapping):
        Mapping(
            rules=build_rules(),
            case_sensitive=False,
            preserve_case=True,
            norm_form="NFC",
            case_equivalencies={"λ": "\u2144\u2144\u2144"},
        )

    # preserve_case=True is only valid together with case_sensitive=False
    with self.assertRaises(MalformedMapping):
        _ = Mapping(rules=[], case_sensitive=True, preserve_case=True)

def test_normalize_edges(self):
# Remove non-deletion edges with the same index as deletions
bad_edges = [
Expand Down
59 changes: 57 additions & 2 deletions g2p/transducer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ class Transducer:
def __init__(self, mapping: Mapping):
self.mapping = mapping
self.case_sensitive = mapping.case_sensitive
self.preserve_case = mapping.preserve_case
self.norm_form = mapping.norm_form
self.out_delimiter = mapping.out_delimiter
self._index_match_pattern = re.compile(r"(?<={)\d+(?=})")
Expand All @@ -428,7 +429,7 @@ def __init__(self, mapping: Mapping):
def __repr__(self):
return f"{self.__class__} between {self.mapping.in_lang} and {self.mapping.out_lang}"

def __call__(self, to_convert: str, index: bool = False, debugger: bool = False):
def __call__(self, to_convert: str):
"""The basic method to transduce an input. A proxy for self.apply_rules.
Args:
Expand All @@ -439,7 +440,11 @@ def __call__(self, to_convert: str, index: bool = False, debugger: bool = False)
and output characters and their corresponding edges representing the indices
of the transformation.
"""
return self.apply_rules(to_convert)
tg = self.apply_rules(to_convert)
if self.preserve_case:
return preserve_case(tg, self.mapping.case_equivalencies)
else:
return tg

@staticmethod
def _pua_to_index(string: str) -> int:
Expand Down Expand Up @@ -1257,3 +1262,53 @@ def check(self, tg: TransductionGraph, shallow=False, display_warnings=False):
else:
return False
return result


def preserve_case(
    tg: TransductionGraph, case_equivalencies: Dict[str, str] = None
) -> TransductionGraph:
    """Re-case the output of a transduction so that it follows the input's case.

    The graph's ``substring_alignments()`` are walked pair by pair; the first
    element of each pair is treated as the input substring and the second as
    the output substring. Each output substring is re-cased according to the
    case of its input substring, and the concatenated result is written back
    to ``tg.output_string``.

    Args:
        tg: the transduction graph to update; it is modified in place.
        case_equivalencies: maps a lower-case form to its upper-case form for
            pairs that ``str.upper()``/``str.lower()`` do not handle.
            ``None`` (the default) means no custom equivalencies.

    Returns:
        The same ``tg`` object, with ``output_string`` replaced.
    """
    if case_equivalencies is None:
        case_equivalencies = {}
    # Upper-case form -> lower-case form. Also used for membership tests so
    # that they are dict lookups rather than scans over ``.values()``.
    reverse_case_equivalencies = {v: k for k, v in case_equivalencies.items()}
    pieces = []
    for item in tg.substring_alignments():
        in_sub = item[0]
        out_sub = item[1]
        any_in_upper = any(x.isupper() for x in in_sub)
        any_in_lower = any(x.islower() for x in in_sub)
        any_out_upper = any(x.isupper() for x in out_sub)
        any_out_lower = any(x.islower() for x in out_sub)
        out_is_lower_equivalent = out_sub in case_equivalencies
        out_is_upper_equivalent = out_sub in reverse_case_equivalencies
        # Pass the output through untouched only if it is truly un-caseable:
        # no Unicode case AND not on either side of a custom equivalency.
        # (Checking just the lower-case side would make upper-case
        # equivalents without Unicode case, e.g. U+2144, impossible to
        # lower-case below.)
        if not (
            out_is_lower_equivalent
            or out_is_upper_equivalent
            or any_out_upper
            or any_out_lower
        ):
            pieces.append(out_sub)
        # lower case using case equivalencies if they exist
        elif (
            any_in_lower or in_sub in case_equivalencies
        ) and out_is_upper_equivalent:
            pieces.append(reverse_case_equivalencies[out_sub])
        # upper case using case equivalencies if they exist
        elif (
            any_in_upper or in_sub in reverse_case_equivalencies
        ) and out_is_lower_equivalent:
            pieces.append(case_equivalencies[out_sub])
        # change to upper if required
        elif any_in_upper and any_out_lower:
            pieces.append(out_sub.upper())
        # change to lower if required
        elif any_in_lower and any_out_upper:
            pieces.append(out_sub.lower())
        # nothing to change: keep the output as produced
        else:
            pieces.append(out_sub)
    tg.output_string = "".join(pieces)
    return tg

0 comments on commit daf33b8

Please sign in to comment.