You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
In tools/collect_code_preds.py, the function `_clean_up_code` is imported via `from opencompass.datasets.humanevalx import _clean_up_code`.
The definition of the function requires three arguments (file: opencompass/datasets/humanevalx.py#L172), but the call in tools/collect_code_preds.py passes only two (`pred` and `lang`):
def _clean_up_code(text: str, language_type: str, reference) -> str:
    """Cleans up the generated code."""
I think the reference should come from the "gold" field of the generated prediction results.
Therefore, I tried updating the code as follows:
@@ -96,7 +96,7 @@ def collect_preds(filename: str):
if not osp.exists(osp.realpath(filename)) and not osp.exists(
osp.realpath(partial_filename)):
print(f'No predictions found for {filename}')
- return FAILED, None, None
+ return FAILED, None, None, None
else:
if osp.exists(osp.realpath(filename)):
preds = mmengine.load(filename)
@@ -106,10 +106,14 @@ def collect_preds(filename: str):
ori_prompt_strs = [
preds[str(i)]['origin_prompt'] for i in range(len(preds))
]
+ gold_strs = [
+ preds[str(i)]['gold'] for i in range(len(preds))
+ ]
else:
filename = partial_filename
pred_strs = []
ori_prompt_strs = []
+ gold_strs = []
i = 1
while osp.exists(osp.realpath(filename)):
preds = mmengine.load(filename)
@@ -121,7 +125,11 @@ def collect_preds(filename: str):
ori_prompt_strs += [
preds[str(i)]['origin_prompt'] for i in range(len(preds))
]
- return SUCCEED, ori_prompt_strs, pred_strs
+ gold_strs += [
+ preds[str(i)]['gold'] for i in range(len(preds))
+ ]
+
+ return SUCCEED, ori_prompt_strs, pred_strs, gold_strs
def main():
@@ -153,16 +161,21 @@ def main():
filename = get_infer_output_path(
model, dataset, osp.join(cfg.work_dir, 'predictions'))
- succeed, ori_prompt_strs, pred_strs = collect_preds(filename)
+ succeed, ori_prompt_strs, pred_strs, gold_strs = collect_preds(
+ filename)
if not succeed:
continue
+ lang = None
# infer the language type
for k, v in _LANGUAGE_NAME_DICT.items():
if k in dataset_abbr:
lang = k
task = v
break
+ if lang is None:
+ logger.warning(f'Not a HumanEvalX dataset: {dataset_abbr}')
+ continue
# special postprocess for GPT
if model_abbr in [
@@ -187,8 +200,8 @@ def main():
else:
predictions = [{
'task_id': f'{task}/{i}',
- 'generation': _clean_up_code(pred, lang),
- } for i, pred in enumerate(pred_strs)]
+ 'generation': _clean_up_code(pred, lang, gold),
+ } for i, (pred, gold) in enumerate(zip(pred_strs, gold_strs))]
# save processed results if not exists
result_file_path = os.path.join(cfg['work_dir'], 'humanevalx',
The complete updated code is as follows:
import argparse
import json
import os
import os.path as osp
import re

import mmengine
from mmengine import Config
from mmengine.utils import mkdir_or_exist

from opencompass.datasets.humanevalx import _clean_up_code
from opencompass.utils import (dataset_abbr_from_cfg, get_infer_output_path,
                               get_logger, model_abbr_from_cfg)
def parse_args():
    """Parse the command-line arguments of the prediction collector.

    Returns:
        argparse.Namespace: Namespace with two attributes:
            ``config`` -- path of the config file (positional, required);
            ``reuse`` -- ``None`` when ``-r/--reuse`` is absent, ``'latest'``
            when the flag is given without a value, otherwise the value
            passed on the command line (e.g. a timestamp).
    """
    parser = argparse.ArgumentParser(
        description='Collect Humanevalx dataset predictions.')
    parser.add_argument('config', help='Config file path')
    parser.add_argument(
        '-r',
        '--reuse',
        nargs='?',
        type=str,
        const='latest',
        help='Reuse previous outputs & results, and run any '
        'missing jobs presented in the config. If its '
        'argument is not specified, the latest results in '
        'the work_dir will be reused. The argument should '
        'also be a specific timestamp, e.g. 20230516_144254')
    return parser.parse_args()


# Maps the language key searched for in a dataset abbreviation to the
# prefix used when building each prediction's ``task_id``.
_LANGUAGE_NAME_DICT = {
    'cpp': 'CPP',
    'go': 'Go',
    'java': 'Java',
    'js': 'JavaScript',
    'python': 'Python',
    'rust': 'Rust',
}
# Status flags returned as the first element of ``collect_preds``' result.
FAILED = 0
SUCCEED = 1

# Model abbreviations whose completions go through
# ``wizardcoder_postprocess`` instead of the generic clean-up.
_WIZARDCODER_MODELS = (
    'WizardCoder-1B-V1.0',
    'WizardCoder-3B-V1.0',
    'WizardCoder-15B-V1.0',
    'WizardCoder-Python-13B-V1.0',
    'WizardCoder-Python-34B-V1.0',
)


def _extract_first_code_block(text: str) -> str:
    """Return the content of the first ``` fenced block found in ``text``.

    The caller must have checked that ``text`` contains at least one fence.
    """
    blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
    if not blocks:
        # Only an opening fence: fall back to everything after it.
        return text.split('```')[1]
    block = blocks[0]
    if not block.startswith('\n'):
        # The fence carries a language tag (e.g. ```python); drop that line.
        # ``find`` returns -1 when there is no newline, so the slice then
        # starts at 0 and keeps the whole block.
        block = block[block.find('\n') + 1:]
    return block


def gpt_python_postprocess(ori_prompt: str, text: str) -> str:
    """Better answer postprocessor for better instruction-aligned models like
    GPT.

    Args:
        ori_prompt: The prompt that was sent to the model.
        text: The raw completion returned by the model.

    Returns:
        The completion reduced to an indented function body: the first fenced
        block is extracted if present, a signature repeated from the prompt
        is removed, the first line is padded towards a 4-space indent and
        everything after the first run of two blank lines is dropped.
    """
    if '```' in text:
        text = _extract_first_code_block(text)

    match_ori = re.search(r'def(.*?)\(', ori_prompt)
    match = re.search(r'def(.*?)\(', text)
    # Strip the signature only when the completion repeats the prompt's own
    # one. ``match_ori`` is None for prompts without a ``def``; guarding it
    # avoids an AttributeError in that case.
    if match and match_ori and match.group() == match_ori.group():
        text = re.sub('def(.*?)\n', '', text, count=1)

    # Pad the first line up to four leading spaces, based on where the first
    # non-space character sits within the first five characters.
    for c_index, c in enumerate(text[:5]):
        if c != ' ':
            text = ' ' * (4 - c_index) + text
            break

    return text.split('\n\n\n')[0]


def wizardcoder_postprocess(text: str) -> str:
    """Postprocess for WizardCoder Models.

    Args:
        text: The raw completion returned by the model.

    Returns:
        The content of the first fenced code block if the completion has a
        fence; otherwise the completion with its first ``Here ...`` line
        removed, if there is one.
    """
    if '```' in text:
        return _extract_first_code_block(text)
    if re.search(r'Here(.*?)\n', text):
        text = re.sub('Here(.*?)\n', '', text, count=1)
    return text


def _unpack_preds(preds):
    """Split a loaded prediction mapping into three parallel lists.

    ``preds`` is indexed by the string form of consecutive integers starting
    at ``'0'``. Returns ``(origin_prompts, predictions, golds)``.
    """
    entries = [preds[str(idx)] for idx in range(len(preds))]
    return (
        [entry['origin_prompt'] for entry in entries],
        [entry['prediction'] for entry in entries],
        [entry['gold'] for entry in entries],
    )


def collect_preds(filename: str):
    """Load the predictions stored in ``filename`` or in its partial shards.

    If ``filename`` itself does not exist, the shards ``<root>_0<ext>``,
    ``<root>_1<ext>``, ... are read in order until the first missing one.

    Args:
        filename: Path of the prediction file.

    Returns:
        ``(SUCCEED, ori_prompt_strs, pred_strs, gold_strs)`` on success, or
        ``(FAILED, None, None, None)`` when neither the file nor its first
        shard exists.
    """
    root, ext = osp.splitext(filename)

    if osp.exists(osp.realpath(filename)):
        shard_files = [filename]
    else:
        # in case the prediction is partial
        shard_files = []
        shard_file = root + '_0' + ext
        while osp.exists(osp.realpath(shard_file)):
            shard_files.append(shard_file)
            shard_file = root + f'_{len(shard_files)}' + ext

    if not shard_files:
        print(f'No predictions found for {filename}')
        return FAILED, None, None, None

    ori_prompt_strs, pred_strs, gold_strs = [], [], []
    for shard_file in shard_files:
        prompts, predictions, golds = _unpack_preds(mmengine.load(shard_file))
        ori_prompt_strs += prompts
        pred_strs += predictions
        gold_strs += golds

    return SUCCEED, ori_prompt_strs, pred_strs, gold_strs


def main():
    """Post-process stored predictions and write them as JSON lines.

    For every (model, dataset) pair of the config, the predictions found
    under ``<work_dir>/predictions`` are cleaned up and written to
    ``<work_dir>/humanevalx/<model_abbr>/humanevalx_<lang>.json``, unless
    that file already exists.

    Raises:
        AssertionError: If ``-r/--reuse`` was not given.
    """
    args = parse_args()
    # initialize logger
    logger = get_logger(log_level='INFO')
    cfg = Config.fromfile(args.config)
    cfg.setdefault('work_dir', './outputs/default/')

    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # the exception type and message are unchanged.
    if not args.reuse:
        raise AssertionError('Please provide the experienment work dir.')

    if args.reuse == 'latest':
        if not os.path.exists(cfg.work_dir) or not os.listdir(cfg.work_dir):
            logger.warning('No previous results to reuse!')
            # Nothing to collect; returning here avoids using an unbound
            # ``dir_time_str`` below.
            return
        dir_time_str = sorted(os.listdir(cfg.work_dir))[-1]
    else:
        dir_time_str = args.reuse
    logger.info(f'Reusing experiements from {dir_time_str}')
    # update "actual" work_dir
    cfg['work_dir'] = osp.join(cfg.work_dir, dir_time_str)

    for model in cfg.models:
        model_abbr = model_abbr_from_cfg(model)
        for dataset in cfg.datasets:
            dataset_abbr = dataset_abbr_from_cfg(dataset)
            filename = get_infer_output_path(
                model, dataset, osp.join(cfg.work_dir, 'predictions'))

            succeed, ori_prompt_strs, pred_strs, gold_strs = collect_preds(
                filename)
            if not succeed:
                continue

            # infer the language type
            lang = None
            for k, v in _LANGUAGE_NAME_DICT.items():
                if k in dataset_abbr:
                    lang = k
                    task = v
                    break
            if lang is None:
                logger.warning(f'Not a HumanEvalX dataset: {dataset_abbr}')
                continue

            # special postprocess for GPT
            if model_abbr in _WIZARDCODER_MODELS:
                predictions = [{
                    'task_id': f'{task}/{i}',
                    'generation': wizardcoder_postprocess(pred),
                } for i, pred in enumerate(pred_strs)]
            elif 'CodeLlama' not in model_abbr and lang == 'python':
                predictions = [{
                    'task_id': f'{task}/{i}',
                    'generation': gpt_python_postprocess(ori_prompt, pred),
                } for i, (ori_prompt, pred) in enumerate(
                    zip(ori_prompt_strs, pred_strs))]
            else:
                # ``_clean_up_code`` takes the reference as third argument;
                # it comes from the ``gold`` field of each prediction.
                predictions = [{
                    'task_id': f'{task}/{i}',
                    'generation': _clean_up_code(pred, lang, gold),
                } for i, (pred, gold) in enumerate(zip(pred_strs, gold_strs))]

            # save processed results if not exists
            result_file_path = os.path.join(cfg['work_dir'], 'humanevalx',
                                            model_abbr,
                                            f'humanevalx_{lang}.json')
            if osp.exists(result_file_path):
                logger.info(
                    f'File exists for {model_abbr}, skip copy from predictions.'  # noqa
                )
            else:
                mkdir_or_exist(osp.split(result_file_path)[0])
                with open(result_file_path, 'w') as f:
                    for pred in predictions:
                        f.write(json.dumps(pred) + '\n')


if __name__ == '__main__':
    main()
The text was updated successfully, but these errors were encountered:
Missing one required positional argument in tools/collect_code_preds.py#L190
In file "tools/collect_code_preds.py#L190", the function `_clean_up_code` is imported via `from opencompass.datasets.humanevalx import _clean_up_code`.
The definition of the function requires three arguments (file: opencompass/datasets/humanevalx.py#L172), but the call site passes only two.
I think the reference should come from the "gold" field of the generated prediction results.
Therefore, I tried updating the code as follows:
The complete updated code is as follows:
The text was updated successfully, but these errors were encountered: