Skip to content

Commit

Permalink
JVM: Fix exception description in prompt (#489)
Browse files Browse the repository at this point in the history
This PR fixes the exceptions information shown in the JVM prompt to
allow LLM to read correct exception information that is needed to catch
in the generated harnesses. This PR also improves the JVM prompt by
adding exceptions details of object constructors / methods that create
the needed object for method invocation that is originally suggested in
#488. Because of its similarity, changes suggested in #488 is also added
in the PR.

---------

Signed-off-by: Arthur Chan <[email protected]>
  • Loading branch information
arthurscchan authored Jul 16, 2024
1 parent 6954e1e commit 9a5b247
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
12 changes: 9 additions & 3 deletions llm_toolkit/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def __init__(self,
self._template_dir = template_dir
self.benchmark = benchmark
self.project_url = self._find_project_url(self.benchmark.project)
self.exceptions = set(self.benchmark.exceptions)

# Load templates.
self.base_template_file = self._find_template(template_dir, 'jvm_base.txt')
Expand Down Expand Up @@ -622,9 +623,12 @@ def _format_target_method(self, signature: str) -> str:

def _format_exceptions(self) -> str:
"""Formats the exception thrown from this method or constructor."""
if self.benchmark.exceptions:
return '<exceptions>' + '\n'.join(
self.benchmark.exceptions) + '</exceptions>'
if self.exceptions:
exception_str_list = [
f'<exception>{exp}</exception>' for exp in self.exceptions
]
return '<exceptions>\n' + '\n'.join(
exception_str_list) + '\n</exceptions>'

return ''

Expand Down Expand Up @@ -797,6 +801,7 @@ def _format_constructors(self) -> str:
constructor_sig = ctr.get('function_signature')
if constructor_sig:
constructors.append(f'<signature>{constructor_sig}</signature>')
self.exceptions.update(ctr.get('exceptions', []))

if constructors:
ctr_str = '\n'.join(constructors)
Expand All @@ -810,6 +815,7 @@ def _format_constructors(self) -> str:
function_sig = func.get('function_signature')
if not function_sig:
continue
self.exceptions.update(func.get('exceptions', []))
if is_static:
functions.append(f'<item><signature>{function_sig}</signature></item>')
else:
Expand Down
1 change: 1 addition & 0 deletions prompts/template_xml/jvm_requirement.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<item>Please avoid using any multithreading or multi-processing approach.</item>
<item>Please add import statements for necessary classes, except for classes in the java.lang package.</item>
<item>You must create the object before calling the target method.</item>
<item>You must catch java.lang.UnsupportOperationException.</item>
<item>Please use {HARNESS_NAME} as the Java class name.</item>
<item>{STATIC_OR_INSTANCE}</item>
<item>Do not create new variables with the same names as existing variables.
Expand Down

0 comments on commit 9a5b247

Please sign in to comment.