5
5
"""
6
6
7
7
import difflib
import operator
10
10
11
-
12
- def get_errors_for_diff_length (corrected_text , origin_text ):
11
+ def get_errors (corrected_text , origin_text ):
13
12
"""Get errors between corrected text and origin text"""
14
- new_corrected_text = ""
15
13
errors = []
16
- i , j = 0 , 0
17
14
unk_tokens = [' ' , '“' , '”' , '‘' , '’' , '琊' , '\n ' , '…' , '擤' , '\t ' , '玕' , '' ]
18
15
19
- while i < len (origin_text ) and j < len (corrected_text ):
20
- if origin_text [i ] in unk_tokens :
21
- new_corrected_text += origin_text [i ]
22
- i += 1
23
- elif corrected_text [j ] in unk_tokens :
24
- new_corrected_text += corrected_text [j ]
25
- j += 1
26
- # Deal with Chinese characters
27
- elif is_chinese_char (origin_text [i ]) and is_chinese_char (corrected_text [j ]):
28
- # If the two characters are the same, then the two pointers move forward together
29
- if origin_text [i ] == corrected_text [j ]:
16
+ s = difflib .SequenceMatcher (None , origin_text , corrected_text )
17
+ new_corrected_text = ""
18
+ for tag , i1 , i2 , j1 , j2 in s .get_opcodes ():
19
+ if tag == 'replace' :
20
+ for i , j in zip (range (i1 , i2 ), range (j1 , j2 )):
21
+ if origin_text [i ] not in unk_tokens and corrected_text [j ] not in unk_tokens :
22
+ errors .append ((origin_text [i ], corrected_text [j ], i ))
30
23
new_corrected_text += corrected_text [j ]
31
- i += 1
32
- j += 1
33
- else :
34
- # Check for insertion errors
35
- if j + 1 < len (corrected_text ) and origin_text [i ] == corrected_text [j + 1 ]:
36
- errors .append (('' , corrected_text [j ], j ))
37
- new_corrected_text += corrected_text [j ]
38
- j += 1
39
- # Check for deletion errors
40
- elif i + 1 < len (origin_text ) and origin_text [i + 1 ] == corrected_text [j ]:
24
+ elif tag == 'delete' :
25
+ for i in range (i1 , i2 ):
26
+ if origin_text [i ] not in unk_tokens :
41
27
errors .append ((origin_text [i ], '' , i ))
42
- i += 1
43
- # Check for replacement errors
44
- else :
45
- errors .append ((origin_text [i ], corrected_text [j ], i ))
46
- new_corrected_text += corrected_text [j ]
47
- i += 1
48
- j += 1
49
- else :
50
- new_corrected_text += origin_text [i ]
51
- if origin_text [i ] == corrected_text [j ]:
52
- j += 1
53
- i += 1
54
- errors = sorted (errors , key = operator .itemgetter (2 ))
55
- return corrected_text , errors
56
-
57
-
58
- def get_errors_for_same_length (corrected_text , origin_text ):
59
- """Get new corrected text and errors between corrected text and origin text"""
60
- errors = []
61
- unk_tokens = [' ' , '“' , '”' , '‘' , '’' , '琊' , '\n ' , '…' , '擤' , '\t ' , '玕' , '' ]
28
+ new_corrected_text += origin_text [i ]
29
+ elif tag == 'insert' :
30
+ for j in range (j1 , j2 ):
31
+ if corrected_text [j ] not in unk_tokens :
32
+ errors .append (('' , corrected_text [j ], j ))
33
+ new_corrected_text += corrected_text [j ]
34
+ elif tag == 'equal' :
35
+ new_corrected_text += origin_text [i1 :i2 ]
62
36
63
- for i , ori_char in enumerate (origin_text ):
64
- if i >= len (corrected_text ):
65
- continue
66
- if ori_char in unk_tokens :
67
- # deal with unk word
68
- corrected_text = corrected_text [:i ] + ori_char + corrected_text [i :]
69
- continue
70
- if ori_char != corrected_text [i ]:
71
- if not is_chinese_char (ori_char ):
72
- # pass not chinese char
73
- corrected_text = corrected_text [:i ] + ori_char + corrected_text [i + 1 :]
74
- continue
75
- if not is_chinese_char (corrected_text [i ]):
76
- corrected_text = corrected_text [:i ] + corrected_text [i + 1 :]
77
- continue
78
- errors .append ((ori_char , corrected_text [i ], i ))
79
- errors = sorted (errors , key = operator .itemgetter (2 ))
80
- return corrected_text , errors
37
+ errors = sorted (errors , key = lambda x : x [2 ])
38
+ return new_corrected_text , errors
81
39
82
40
83
41
if __name__ == '__main__' :
@@ -100,5 +58,6 @@ def get_errors_for_same_length(corrected_text, origin_text):
100
58
('我喜欢吃鸡,公鸡、母鸡、白切鸡、乌鸡、紫燕鸡' , '我喜欢吃鸡,公鸡、母鸡、切鸡、乌鸡、紫燕鸡' ), # 少字
101
59
]
102
60
for pair in sentence_pairs :
103
- new_corrected_text , errors = get_errors_for_same_length (pair [0 ], pair [1 ])
61
+ new_corrected_text , errors = get_errors (pair [0 ], pair [1 ])
104
62
print (f"{ new_corrected_text } { errors } " )
63
+ print ('--' * 42 + '\n ' )
0 commit comments