Skip to content

Commit

Permalink
update eval fn
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Oct 13, 2024
1 parent 49fc55d commit 32192a5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
8 changes: 5 additions & 3 deletions examples/evaluate_models/evaluate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,16 @@ def main(args):
model_type='chatglm',
peft_name="shibing624/chatglm3-6b-csc-chinese-lora")
if args.data == 'sighan':
eval_model_batch(m.correct_batch)
eval_model_batch(m.correct_batch, prompt_template_name='vicuna')
# Sentence Level: acc:0.5564, precision:0.5574, recall:0.4917, f1:0.5225, cost time:1572.49 s, total num: 1100
#
elif args.data == 'ec_law':
eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"))
eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"),
prompt_template_name='vicuna')
#
elif args.data == 'mcsc':
eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"))
eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"),
prompt_template_name='vicuna')
#
elif args.model == 'qwen1.5b':
from pycorrector.gpt.gpt_corrector import GptCorrector
Expand Down
8 changes: 4 additions & 4 deletions pycorrector/utils/evaluate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
sighan_2015_path = os.path.join(pwd_path, '../data/sighan2015_test.tsv')


def eval_model_single(correct_fn, input_tsv_file=sighan_2015_path, verbose=True):
def eval_model_single(correct_fn, input_tsv_file=sighan_2015_path, verbose=True, **kwargs):
"""
SIGHAN句级评估结果,设定需要纠错为正样本,无需纠错为负样本
Args:
Expand Down Expand Up @@ -40,7 +40,7 @@ def eval_model_single(correct_fn, input_tsv_file=sighan_2015_path, verbose=True)
src = parts[0]
tgt = parts[1]

r = correct_fn(src)
r = correct_fn(src, **kwargs)
tgt_pred, pred_detail = r['target'], r['errors']
if verbose:
print()
Expand Down Expand Up @@ -80,7 +80,7 @@ def eval_model_single(correct_fn, input_tsv_file=sighan_2015_path, verbose=True)
return acc, precision, recall, f1


def eval_model_batch(correct_fn, input_tsv_file=sighan_2015_path, verbose=True):
def eval_model_batch(correct_fn, input_tsv_file=sighan_2015_path, verbose=True, **kwargs):
"""
SIGHAN句级评估结果,设定需要纠错为正样本,无需纠错为负样本
Args:
Expand Down Expand Up @@ -113,7 +113,7 @@ def eval_model_batch(correct_fn, input_tsv_file=sighan_2015_path, verbose=True):
srcs.append(src)
tgts.append(tgt)

res = correct_fn(srcs)
res = correct_fn(srcs, **kwargs)
for each_res, src, tgt in zip(res, srcs, tgts):
pred_detail = ''
if isinstance(each_res, str):
Expand Down

0 comments on commit 32192a5

Please sign in to comment.