Commit 937c4b5

update get errors.

1 parent e1f6db9 commit 937c4b5

5 files changed  (+32 -73 lines changed)

pycorrector/gpt/gpt_corrector.py  (+2 -2)

@@ -13,7 +13,7 @@
 sys.path.append('../..')
 from pycorrector.utils.tokenizer import split_text_into_sentences_by_length
 from pycorrector.gpt.gpt_model import GptModel
-from pycorrector.utils.error_utils import get_errors_for_diff_length
+from pycorrector.utils.error_utils import get_errors


 class GptCorrector(GptModel):
@@ -87,7 +87,7 @@ def correct_batch(
         new_corrected_sentences = []
         corrected_details = []
         for idx, corrected_sent in enumerate(corrected_sentences):
-            new_corrected_sent, sub_details = get_errors_for_diff_length(corrected_sent, sentences[idx])
+            new_corrected_sent, sub_details = get_errors(corrected_sent, sentences[idx])
             new_corrected_sentences.append(new_corrected_sent)
             corrected_details.append(sub_details)
         return [{'source': s, 'target': c, 'errors': e} for s, c, e in
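
For reference, all four correctors now post-process their model output through the shared get_errors() helper, as in the minimal sketch below. The no-argument GptCorrector() constructor and the sample sentence are illustrative assumptions, not part of this commit.

# Illustrative sketch only: how the errors produced by get_errors() surface
# through correct_batch() after this change. Assumes the default GptCorrector
# model weights can be loaded in the current environment.
from pycorrector.gpt.gpt_corrector import GptCorrector

m = GptCorrector()
results = m.correct_batch(['少先队员因该为老人让坐'])
# Each item is a dict: {'source': ..., 'target': ..., 'errors': [(wrong_char, right_char, index), ...]}
print(results[0])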

pycorrector/macbert/macbert_corrector.py  (+2 -2)

@@ -15,7 +15,7 @@

 sys.path.append('../..')
 from pycorrector.utils.tokenizer import split_text_into_sentences_by_length
-from pycorrector.utils.error_utils import get_errors_for_same_length
+from pycorrector.utils.error_utils import get_errors

 device = torch.device("mps" if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
                       else "cuda" if torch.cuda.is_available() else "cpu")
@@ -109,7 +109,7 @@ def correct_batch(
         new_corrected_sentences = []
         corrected_details = []
         for idx, corrected_sent in enumerate(corrected_sentences):
-            new_corrected_sent, sub_details = get_errors_for_same_length(corrected_sent, sentences[idx])
+            new_corrected_sent, sub_details = get_errors(corrected_sent, sentences[idx])
             new_corrected_sentences.append(new_corrected_sent)
             corrected_details.append(sub_details)
         return [{'source': s, 'target': c, 'errors': e} for s, c, e in

pycorrector/seq2seq/conv_seq2seq_corrector.py  (+2 -2)

@@ -17,7 +17,7 @@
 from pycorrector.utils.tokenizer import split_text_into_sentences_by_length
 from pycorrector.utils.get_file import get_file
 from pycorrector.detector import USER_DATA_DIR
-from pycorrector.utils.error_utils import get_errors_for_diff_length
+from pycorrector.utils.error_utils import get_errors

 device = torch.device("mps" if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
                       else "cuda" if torch.cuda.is_available() else "cpu")
@@ -85,7 +85,7 @@ def correct_batch(self, sentences: List[str], max_length: int = 128, silent: boo
         new_corrected_sentences = []
         corrected_details = []
         for idx, corrected_sent in enumerate(corrected_sentences):
-            new_corrected_sent, sub_details = get_errors_for_diff_length(corrected_sent, sentences[idx])
+            new_corrected_sent, sub_details = get_errors(corrected_sent, sentences[idx])
             new_corrected_sentences.append(new_corrected_sent)
             corrected_details.append(sub_details)
         return [{'source': s, 'target': c, 'errors': e} for s, c, e in

pycorrector/t5/t5_corrector.py  (+2 -2)

@@ -15,7 +15,7 @@

 sys.path.append('../..')
 from pycorrector.utils.tokenizer import split_text_into_sentences_by_length
-from pycorrector.utils.error_utils import get_errors_for_same_length
+from pycorrector.utils.error_utils import get_errors

 device = torch.device("mps" if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
                       else "cuda" if torch.cuda.is_available() else "cpu")
@@ -83,7 +83,7 @@ def correct_batch(self, sentences: List[str], max_length: int = 128, batch_size:
         new_corrected_sentences = []
         corrected_details = []
         for idx, corrected_sent in enumerate(corrected_sentences):
-            new_corrected_sent, sub_details = get_errors_for_same_length(corrected_sent, sentences[idx])
+            new_corrected_sent, sub_details = get_errors(corrected_sent, sentences[idx])
             new_corrected_sentences.append(new_corrected_sent)
             corrected_details.append(sub_details)
         return [{'source': s, 'target': c, 'errors': e} for s, c, e in

pycorrector/utils/error_utils.py  (+24 -65)

@@ -5,79 +5,37 @@
 """

 import operator
+import difflib

-from pycorrector.utils.text_utils import is_chinese_char

-
-def get_errors_for_diff_length(corrected_text, origin_text):
+def get_errors(corrected_text, origin_text):
     """Get errors between corrected text and origin text"""
-    new_corrected_text = ""
     errors = []
-    i, j = 0, 0
     unk_tokens = [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '']

-    while i < len(origin_text) and j < len(corrected_text):
-        if origin_text[i] in unk_tokens:
-            new_corrected_text += origin_text[i]
-            i += 1
-        elif corrected_text[j] in unk_tokens:
-            new_corrected_text += corrected_text[j]
-            j += 1
-        # Deal with Chinese characters
-        elif is_chinese_char(origin_text[i]) and is_chinese_char(corrected_text[j]):
-            # If the two characters are the same, then the two pointers move forward together
-            if origin_text[i] == corrected_text[j]:
+    s = difflib.SequenceMatcher(None, origin_text, corrected_text)
+    new_corrected_text = ""
+    for tag, i1, i2, j1, j2 in s.get_opcodes():
+        if tag == 'replace':
+            for i, j in zip(range(i1, i2), range(j1, j2)):
+                if origin_text[i] not in unk_tokens and corrected_text[j] not in unk_tokens:
+                    errors.append((origin_text[i], corrected_text[j], i))
                 new_corrected_text += corrected_text[j]
-                i += 1
-                j += 1
-            else:
-                # Check for insertion errors
-                if j + 1 < len(corrected_text) and origin_text[i] == corrected_text[j + 1]:
-                    errors.append(('', corrected_text[j], j))
-                    new_corrected_text += corrected_text[j]
-                    j += 1
-                # Check for deletion errors
-                elif i + 1 < len(origin_text) and origin_text[i + 1] == corrected_text[j]:
+        elif tag == 'delete':
+            for i in range(i1, i2):
+                if origin_text[i] not in unk_tokens:
                     errors.append((origin_text[i], '', i))
-                    i += 1
-                # Check for replacement errors
-                else:
-                    errors.append((origin_text[i], corrected_text[j], i))
-                    new_corrected_text += corrected_text[j]
-                    i += 1
-                    j += 1
-        else:
-            new_corrected_text += origin_text[i]
-            if origin_text[i] == corrected_text[j]:
-                j += 1
-            i += 1
-    errors = sorted(errors, key=operator.itemgetter(2))
-    return corrected_text, errors
-
-
-def get_errors_for_same_length(corrected_text, origin_text):
-    """Get new corrected text and errors between corrected text and origin text"""
-    errors = []
-    unk_tokens = [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '']
+                new_corrected_text += origin_text[i]
+        elif tag == 'insert':
+            for j in range(j1, j2):
+                if corrected_text[j] not in unk_tokens:
+                    errors.append(('', corrected_text[j], j))
+                new_corrected_text += corrected_text[j]
+        elif tag == 'equal':
+            new_corrected_text += origin_text[i1:i2]

-    for i, ori_char in enumerate(origin_text):
-        if i >= len(corrected_text):
-            continue
-        if ori_char in unk_tokens:
-            # deal with unk word
-            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
-            continue
-        if ori_char != corrected_text[i]:
-            if not is_chinese_char(ori_char):
-                # pass not chinese char
-                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
-                continue
-            if not is_chinese_char(corrected_text[i]):
-                corrected_text = corrected_text[:i] + corrected_text[i + 1:]
-                continue
-            errors.append((ori_char, corrected_text[i], i))
-    errors = sorted(errors, key=operator.itemgetter(2))
-    return corrected_text, errors
+    errors = sorted(errors, key=lambda x: x[2])
+    return new_corrected_text, errors


 if __name__ == '__main__':
@@ -100,5 +58,6 @@ def get_errors_for_same_length(corrected_text, origin_text):
         ('我喜欢吃鸡,公鸡、母鸡、白切鸡、乌鸡、紫燕鸡', '我喜欢吃鸡,公鸡、母鸡、切鸡、乌鸡、紫燕鸡'),  # 少字
     ]
     for pair in sentence_pairs:
-        new_corrected_text, errors = get_errors_for_same_length(pair[0], pair[1])
+        new_corrected_text, errors = get_errors(pair[0], pair[1])
         print(f"{new_corrected_text} {errors}")
+        print('--' * 42 + '\n')
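
As a quick sanity check of the new difflib-based implementation, here is a minimal sketch (assuming the package is importable, e.g. when run from the repo root) using the 少字 pair from the __main__ block above; the expected output shown in the comments is an assumption, not part of this commit.

# Minimal sketch, not part of the commit: SequenceMatcher opcodes drive get_errors().
import difflib

from pycorrector.utils.error_utils import get_errors

origin = '我喜欢吃鸡,公鸡、母鸡、切鸡、乌鸡、紫燕鸡'      # missing 白 (少字)
corrected = '我喜欢吃鸡,公鸡、母鸡、白切鸡、乌鸡、紫燕鸡'

# The opcodes describe how to turn origin into corrected; here a single 'insert' of 白.
print(difflib.SequenceMatcher(None, origin, corrected).get_opcodes())

new_corrected_text, errors = get_errors(corrected, origin)
print(new_corrected_text)
print(errors)  # expected to contain an insertion error such as ('', '白', 12)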
