From 627f404dc0154e0671e30276e523dfc9850670fe Mon Sep 17 00:00:00 2001 From: Dian Chen Date: Fri, 4 Mar 2022 20:09:04 +0800 Subject: [PATCH 1/6] Fix: Solve the misalignment problem on UNK tokens. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit for the last example: `'你说:“怎么办?”我怎么知道?'`. before this commit: `[[('怎', '「', 4, 5), ('么', '怎', 5, 6), ('办', '么', 6, 7), ('?', '办', 7, 8), ('我', '?', 9, 10), ('怎', '」', 10, 11), ('么', '我', 11, 12), ('知', '怎', 12, 13), ('道', '么', 13, 14), ('?', '知', 14, 15)]]` after this commit: `[[]]` --- pycorrector/macbert/infer.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/pycorrector/macbert/infer.py b/pycorrector/macbert/infer.py index 7d1f639d..0ac7d19d 100644 --- a/pycorrector/macbert/infer.py +++ b/pycorrector/macbert/infer.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- """ -@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com) +@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com), okcd00(okcd00@qq.com) @description: """ import sys -import operator import torch +import operator +from glob import glob from transformers import BertTokenizer sys.path.append('../..') @@ -25,6 +26,15 @@ def __init__(self, ckpt_path='output/macbert4csc/epoch=09-val_loss=0.01.ckpt', logger.debug("device: {}".format(device)) self.tokenizer = BertTokenizer.from_pretrained(vocab_path) cfg.merge_from_file(cfg_path) + + if os.path.isdir(ckpt_path): + # ckpt_path is allowed to be a path to a directory. + # e.g., "output/macbert4csc/epoch=09-val_loss=0.01.ckpt" or + # "output/macbert4csc/" both are OK. + # automatically select the model file with the lowest loss. 
+ ckpt_path = sorted(glob(f"{ckpt_path}/*.ckpt"), + key=lambda x: float(x[:-5].split('=')[-1]))[-1] + if 'macbert4csc' in cfg_path: self.model = MacBert4Csc.load_from_checkpoint(checkpoint_path=ckpt_path, cfg=cfg, @@ -80,7 +90,7 @@ def get_errors(corrected_text, origin_text): for i, ori_char in enumerate(origin_text): if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']: # add unk word - corrected_text = corrected_text[:i] + ori_char + corrected_text[i:] + corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:] continue if i >= len(corrected_text): continue From e96bbac8578cb21c2afee2f39a50024066affe3f Mon Sep 17 00:00:00 2001 From: Dian Chen Date: Fri, 4 Mar 2022 20:47:21 +0800 Subject: [PATCH 2/6] Update: handle the issues of blanks and UNKs. The handling of inserted spaces and unk characters should be addressed separately. --- pycorrector/macbert/infer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pycorrector/macbert/infer.py b/pycorrector/macbert/infer.py index 0ac7d19d..9a21f291 100644 --- a/pycorrector/macbert/infer.py +++ b/pycorrector/macbert/infer.py @@ -88,9 +88,13 @@ def predict_with_error_detail(self, sentence_list): def get_errors(corrected_text, origin_text): sub_details = [] for i, ori_char in enumerate(origin_text): - if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']: + if ori_char == " ": + # add blank word + _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i:] + continue + if ori_char in ['“', '”', '‘', '’', '\n', '…', '—']: # add unk word - corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:] + _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i + 1:] continue if i >= len(corrected_text): continue @@ -121,6 +125,7 @@ def get_errors(corrected_text, origin_text): inputs = [ '它的本领是呼风唤雨,因此能灭火防灾。狎鱼后面是獬豸。獬豸通常头上长着独角,有时又被称为独角羊。它很聪彗,而且明辨是非,象征着大公无私,又能镇压斜恶。', '老是较书。', + '少先队 员因该 为老人让 坐', '感谢等五分以后,碰到一位很棒的奴生跟我可聊。', 
'遇到一位很棒的奴生跟我聊天。', '遇到一位很美的女生跟我疗天。', From 826b46fa4a5ede7709f1c4e9b996d8f475bef697 Mon Sep 17 00:00:00 2001 From: Dian Chen Date: Fri, 4 Mar 2022 20:55:54 +0800 Subject: [PATCH 3/6] Update: consider different situations in predict() --- pycorrector/macbert/infer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pycorrector/macbert/infer.py b/pycorrector/macbert/infer.py index 9a21f291..208b6972 100644 --- a/pycorrector/macbert/infer.py +++ b/pycorrector/macbert/infer.py @@ -85,12 +85,13 @@ def predict_with_error_detail(self, sentence_list): sentence_list = [sentence_list] corrected_texts = self.model.predict(sentence_list) - def get_errors(corrected_text, origin_text): + def get_errors(corrected_text, origin_text, blank_cleaned=False): + # blank_cleaned means if the blanks in texts are cleaned in predict(). sub_details = [] for i, ori_char in enumerate(origin_text): if ori_char == " ": # add blank word - _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i:] + _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i if blank_cleaned else i + 1:] continue if ori_char in ['“', '”', '‘', '’', '\n', '…', '—']: # add unk word From 39b71da50f7b96c4395a1233f6d9f8f437c87e1a Mon Sep 17 00:00:00 2001 From: Dian Chen Date: Fri, 4 Mar 2022 21:10:27 +0800 Subject: [PATCH 4/6] Fix: local var names. Sorry for the faulty cherry-pick, I'll re-test this on pycorrector, not on my modified one. 
--- pycorrector/macbert/infer.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/pycorrector/macbert/infer.py b/pycorrector/macbert/infer.py index 208b6972..4070b580 100644 --- a/pycorrector/macbert/infer.py +++ b/pycorrector/macbert/infer.py @@ -85,28 +85,29 @@ def predict_with_error_detail(self, sentence_list): sentence_list = [sentence_list] corrected_texts = self.model.predict(sentence_list) - def get_errors(corrected_text, origin_text, blank_cleaned=False): - # blank_cleaned means if the blanks in texts are cleaned in predict(). + def get_errors(_corrected_text, _origin_text): sub_details = [] - for i, ori_char in enumerate(origin_text): + for i, ori_char in enumerate(_origin_text): if ori_char == " ": # add blank word - _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i if blank_cleaned else i + 1:] + _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i:] continue if ori_char in ['“', '”', '‘', '’', '\n', '…', '—']: # add unk word _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i + 1:] continue - if i >= len(corrected_text): + if i >= len(_corrected_text): continue - if ori_char != corrected_text[i]: - if ori_char.lower() == corrected_text[i]: + if ori_char != _corrected_text[i]: + # print(ori_char, corrected_text[i]) + if (ori_char.lower() == _corrected_text[i]) or _corrected_text[i] == '֍': # pass english upper char - corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:] + _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i + 1:] continue - sub_details.append((ori_char, corrected_text[i], i, i + 1)) + sub_details.append((ori_char, _corrected_text[i], i, i + 1)) + # print(_corrected_text) sub_details = sorted(sub_details, key=operator.itemgetter(2)) - return corrected_text, sub_details + return _corrected_text, sub_details for corrected_text, text in zip(corrected_texts, sentence_list): corrected_text, sub_details = 
get_errors(corrected_text, text) From 3ca9c40803c840dd61edca013dbffa286d4ef034 Mon Sep 17 00:00:00 2001 From: Dian Chen Date: Fri, 4 Mar 2022 21:24:08 +0800 Subject: [PATCH 5/6] Update: consider if blanks are (not) cleaned. --- pycorrector/macbert/infer.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/pycorrector/macbert/infer.py b/pycorrector/macbert/infer.py index 4070b580..0af20adf 100644 --- a/pycorrector/macbert/infer.py +++ b/pycorrector/macbert/infer.py @@ -6,7 +6,6 @@ import sys import torch import operator -from glob import glob from transformers import BertTokenizer sys.path.append('../..') @@ -27,14 +26,6 @@ def __init__(self, ckpt_path='output/macbert4csc/epoch=09-val_loss=0.01.ckpt', self.tokenizer = BertTokenizer.from_pretrained(vocab_path) cfg.merge_from_file(cfg_path) - if os.path.isdir(ckpt_path): - # ckpt_path is allowed to be a path to a directory. - # e.g., "output/macbert4csc/epoch=09-val_loss=0.01.ckpt" or - # "output/macbert4csc/" both are OK. - # automatically select the model file with the lowest loss. 
- ckpt_path = sorted(glob(f"{ckpt_path}/*.ckpt"), - key=lambda x: float(x[:-5].split('=')[-1]))[-1] - if 'macbert4csc' in cfg_path: self.model = MacBert4Csc.load_from_checkpoint(checkpoint_path=ckpt_path, cfg=cfg, @@ -85,12 +76,12 @@ def predict_with_error_detail(self, sentence_list): sentence_list = [sentence_list] corrected_texts = self.model.predict(sentence_list) - def get_errors(_corrected_text, _origin_text): + def get_errors(_corrected_text, _origin_text, blanks_cleaned=False): sub_details = [] for i, ori_char in enumerate(_origin_text): if ori_char == " ": # add blank word - _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i:] + _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i if blanks_cleaned else i + 1:] continue if ori_char in ['“', '”', '‘', '’', '\n', '…', '—']: # add unk word From 46233befa4beb57b4f10225341b330880e3b30ac Mon Sep 17 00:00:00 2001 From: Dian Chen Date: Fri, 4 Mar 2022 21:55:18 +0800 Subject: [PATCH 6/6] Update: different concat methods on blanks, enters and OOV. --- pycorrector/macbert/infer.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pycorrector/macbert/infer.py b/pycorrector/macbert/infer.py index 0af20adf..fef15aec 100644 --- a/pycorrector/macbert/infer.py +++ b/pycorrector/macbert/infer.py @@ -76,14 +76,23 @@ def predict_with_error_detail(self, sentence_list): sentence_list = [sentence_list] corrected_texts = self.model.predict(sentence_list) - def get_errors(_corrected_text, _origin_text, blanks_cleaned=False): + def get_errors(_corrected_text, _origin_text): sub_details = [] + + # Flags, we found that blanks are remained but enters are cleaned. 
+ blanks_cleaned = False + enter_cleaned = True + for i, ori_char in enumerate(_origin_text): if ori_char == " ": # add blank word _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i if blanks_cleaned else i + 1:] continue - if ori_char in ['“', '”', '‘', '’', '\n', '…', '—']: + if ori_char == "\n": + # add enter word + _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i if enter_cleaned else i + 1:] + continue + if ori_char in ['“', '”', '‘', '’', '琊', '…', '—', '擤']: # add unk word _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i + 1:] continue