From 32192a52a0f2269c442d2f753a0c6240288ec99c Mon Sep 17 00:00:00 2001 From: shibing624 Date: Sun, 13 Oct 2024 13:58:53 +0800 Subject: [PATCH] update eval fn --- examples/evaluate_models/evaluate_models.py | 8 +++++--- pycorrector/utils/evaluate_utils.py | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/evaluate_models/evaluate_models.py b/examples/evaluate_models/evaluate_models.py index 32c4f4e7..1dff3ac1 100644 --- a/examples/evaluate_models/evaluate_models.py +++ b/examples/evaluate_models/evaluate_models.py @@ -84,14 +84,16 @@ def main(args): model_type='chatglm', peft_name="shibing624/chatglm3-6b-csc-chinese-lora") if args.data == 'sighan': - eval_model_batch(m.correct_batch) + eval_model_batch(m.correct_batch, prompt_template_name='vicuna') # Sentence Level: acc:0.5564, precision:0.5574, recall:0.4917, f1:0.5225, cost time:1572.49 s, total num: 1100 # elif args.data == 'ec_law': - eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv")) + eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"), + prompt_template_name='vicuna') # elif args.data == 'mcsc': - eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv")) + eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"), + prompt_template_name='vicuna') # elif args.model == 'qwen1.5b': from pycorrector.gpt.gpt_corrector import GptCorrector diff --git a/pycorrector/utils/evaluate_utils.py b/pycorrector/utils/evaluate_utils.py index 0d8c52f2..3524a484 100644 --- a/pycorrector/utils/evaluate_utils.py +++ b/pycorrector/utils/evaluate_utils.py @@ -12,7 +12,7 @@ sighan_2015_path = os.path.join(pwd_path, '../data/sighan2015_test.tsv') -def eval_model_single(correct_fn, input_tsv_file=sighan_2015_path, verbose=True): +def eval_model_single(correct_fn, input_tsv_file=sighan_2015_path, verbose=True, **kwargs): """ SIGHAN句级评估结果,设定需要纠错为正样本,无需纠错为负样本 Args: @@ -40,7 +40,7 @@ def eval_model_single(correct_fn, input_tsv_file=sighan_2015_path, verbose=True) src = parts[0] tgt = parts[1] - r = correct_fn(src) + r = correct_fn(src, **kwargs) tgt_pred, pred_detail = r['target'], r['errors'] if verbose: print() @@ -80,7 +80,7 @@ def eval_model_single(correct_fn, input_tsv_file=sighan_2015_path, verbose=True) return acc, precision, recall, f1 -def eval_model_batch(correct_fn, input_tsv_file=sighan_2015_path, verbose=True): +def eval_model_batch(correct_fn, input_tsv_file=sighan_2015_path, verbose=True, **kwargs): """ SIGHAN句级评估结果,设定需要纠错为正样本,无需纠错为负样本 Args: @@ -113,7 +113,7 @@ def eval_model_batch(correct_fn, input_tsv_file=sighan_2015_path, verbose=True): srcs.append(src) tgts.append(tgt) - res = correct_fn(srcs) + res = correct_fn(srcs, **kwargs) for each_res, src, tgt in zip(res, srcs, tgts): pred_detail = '' if isinstance(each_res, str):