Remove references to PromptBuilder (#289)
* Move the PromptBuilder build_prompt call into DefaultAgentCompletion.
* Update the return type.
* Fix the prompt return.
* Remove references to the PromptBuilder class.
* Remove references to PromptBuilder.py.
* Increment version.txt.
* Add better error handling for prompt building.
* Fix the import sequence.
* Add tests.
* Increase the timeout for the subprocess runtime.
EmbeddedDevops1 authored Feb 20, 2025
1 parent c1d7386 commit 53f72c7
Showing 12 changed files with 473 additions and 325 deletions.
8 changes: 4 additions & 4 deletions cover_agent/AgentCompletionABC.py
@@ -184,9 +184,9 @@ def adapt_test_command_for_a_single_test_via_ai(
        Returns:
            Tuple[str, int, int, str]:
                A 4-element tuple containing:
-                   - The AI-generated modified command line (string),
-                   - The input token count (int),
-                   - The output token count (int),
+                   - The AI-generated single-test command line (string), or None upon failure.
+                   - The input token count (int).
+                   - The output token count (int).
+                   - The final constructed prompt (string).
        """
        pass
16 changes: 7 additions & 9 deletions cover_agent/CoverAgent.py
@@ -7,7 +7,6 @@
from typing import List

from cover_agent.CustomLogger import CustomLogger
-from cover_agent.PromptBuilder import PromptBuilder, adapt_test_command_for_a_single_test_via_ai
from cover_agent.UnitTestGenerator import UnitTestGenerator
from cover_agent.UnitTestValidator import UnitTestValidator
from cover_agent.UnitTestDB import UnitTestDB
@@ -41,12 +40,9 @@ def __init__(self, args, agent_completion: AgentCompletionABC = None):
        if agent_completion:
            self.agent_completion = agent_completion
        else:
-           # Default to using the DefaultAgentCompletion object with the PromptBuilder and AICaller
+           # Default to using the DefaultAgentCompletion object with AICaller
            self.ai_caller = AICaller(model=args.model, api_base=args.api_base, max_tokens=8192)
-           self.prompt_builder = PromptBuilder()
-           self.agent_completion = DefaultAgentCompletion(
-               builder=self.prompt_builder, caller=self.ai_caller
-           )
+           self.agent_completion = DefaultAgentCompletion(caller=self.ai_caller)

self.test_gen = UnitTestGenerator(
source_file_path=args.source_file_path,
@@ -99,8 +95,10 @@ def parse_command_to_run_only_a_single_test(self, args):
f"Failed to adapt test command for running a single test: {test_command}"
)
else:
new_command_line = adapt_test_command_for_a_single_test_via_ai(
args, test_file_relative_path, test_command
new_command_line = self.agent_completion.adapt_test_command_for_a_single_test_via_ai(
test_file_relative_path=test_file_relative_path,
test_command=test_command,
project_root_dir=self.args.test_command_dir,
)
if new_command_line:
args.test_command_original = test_command
@@ -227,7 +225,7 @@ def run_test_gen(
test_result = self.test_validator.validate_test(generated_test)

# Insert the test result into the database
test_result["prompt"] = self.test_gen.prompt["user"]
test_result["prompt"] = self.test_gen.prompt
self.test_db.insert_attempt(test_result)
except AttributeError as e:
self.logger.error(
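After this change, wiring the completion agent no longer involves a PromptBuilder. A minimal sketch of the new construction, mirroring the AICaller arguments in the diff above (the model name and api_base values below are placeholders, not real configuration):

    from cover_agent.AICaller import AICaller
    from cover_agent.DefaultAgentCompletion import DefaultAgentCompletion

    # Placeholder model/api_base; max_tokens matches the value used in CoverAgent.
    ai_caller = AICaller(model="gpt-4o", api_base=None, max_tokens=8192)
    agent_completion = DefaultAgentCompletion(caller=ai_caller)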
98 changes: 78 additions & 20 deletions cover_agent/DefaultAgentCompletion.py
@@ -1,27 +1,71 @@
from cover_agent.AgentCompletionABC import AgentCompletionABC
-from cover_agent.PromptBuilder import PromptBuilder
from cover_agent.AICaller import AICaller
+from cover_agent.CustomLogger import CustomLogger
+from cover_agent.settings.config_loader import get_settings
+from cover_agent.utils import load_yaml

+from jinja2 import Environment, StrictUndefined
from typing import Tuple


class DefaultAgentCompletion(AgentCompletionABC):
    """
    A default implementation of AgentCompletionABC that relies on TOML-based
-   prompt templates for each method. It uses a PromptBuilder to construct the
+   prompt templates for each method. It uses _build_prompt() to construct the
    prompt from the appropriate TOML file, then calls an AI model via AICaller
    to get the response.
    """

-   def __init__(self, builder: PromptBuilder, caller: AICaller):
+   def __init__(self, caller: AICaller):
        """
        Initializes the DefaultAgentCompletion.

        Args:
-           builder (PromptBuilder): A utility class for building prompts from TOML templates.
            caller (AICaller): A class responsible for sending the prompt to an AI model and returning the response.
        """
-       self.builder = builder
        self.caller = caller
+       self.logger = CustomLogger.get_logger(__name__)

+   def _build_prompt(self, file: str, **kwargs) -> dict:
+       """
+       Internal helper that builds {"system": ..., "user": ...} for the model
+       by loading Jinja2 templates from TOML-based settings.
+
+       The `file` argument corresponds to the name/key in your TOML file,
+       e.g. "analyze_test_against_context". All other variables are passed
+       in via **kwargs. The TOML's system/user templates may reference these
+       variables using Jinja2 syntax, e.g. {{ language }} or {{ test_file_content }}.
+
+       Raises:
+           ValueError: If the TOML config does not contain valid 'system' and 'user' keys.
+           RuntimeError: If an error occurs while rendering the templates.
+       """
+       environment = Environment(undefined=StrictUndefined)
+
+       try:
+           # 1. Fetch the prompt config from your TOML-based settings
+           settings = get_settings().get(file)
+           if not settings or not hasattr(settings, "system") or not hasattr(settings, "user"):
+               msg = f"Could not find valid system/user prompt settings for: {file}"
+               self.logger.error(msg)
+               raise ValueError(msg)
+
+           # 2. Render system & user templates with the passed-in kwargs
+           system_prompt = environment.from_string(settings.system).render(**kwargs)
+           user_prompt = environment.from_string(settings.user).render(**kwargs)
+
+       except ValueError:
+           # Re-raise the ValueError above so callers can catch it if needed.
+           raise
+       except Exception as e:
+           # Any other rendering or environment errors are re-raised as RuntimeError.
+           error_msg = f"Error rendering prompt for '{file}': {e}"
+           self.logger.error(error_msg)
+           raise RuntimeError(error_msg) from e
+
+       return {"system": system_prompt, "user": user_prompt}

def generate_tests(
self,
@@ -64,7 +108,7 @@ def generate_tests(
- The output token count (int),
- The final constructed prompt sent to the AI (str).
"""
-       prompt = self.builder.build_prompt(
+       prompt = self._build_prompt(
file="test_generation_prompt",
source_file_name=source_file_name,
max_tests=max_tests,
@@ -79,7 +123,7 @@ def generate_tests(
failed_tests_section=failed_tests_section,
)
response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
-       return response, prompt_tokens, completion_tokens, prompt
+       return response, prompt_tokens, completion_tokens, prompt["user"]

def analyze_test_failure(
self,
@@ -117,7 +161,7 @@ def analyze_test_failure(
- The output token count (int),
- The final constructed prompt (str).
"""
-       prompt = self.builder.build_prompt(
+       prompt = self._build_prompt(
file="analyze_test_run_failure",
source_file_name=source_file_name,
source_file=source_file,
@@ -127,7 +171,7 @@ def analyze_test_failure(
test_file_name=test_file_name,
)
response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
-       return response, prompt_tokens, completion_tokens, prompt
+       return response, prompt_tokens, completion_tokens, prompt["user"]

def analyze_test_insert_line(
self,
@@ -157,15 +201,15 @@ def analyze_test_insert_line(
- The output token count (int),
- The final constructed prompt (str).
"""
-       prompt = self.builder.build_prompt(
+       prompt = self._build_prompt(
file="analyze_suite_test_insert_line",
language=language,
test_file_numbered=test_file_numbered,
test_file_name=test_file_name,
additional_instructions_text=additional_instructions_text,
)
response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
-       return response, prompt_tokens, completion_tokens, prompt
+       return response, prompt_tokens, completion_tokens, prompt["user"]

def analyze_test_against_context(
self,
@@ -199,15 +243,15 @@ def analyze_test_against_context(
- The output token count (int),
- The final constructed prompt (str).
"""
-       prompt = self.builder.build_prompt(
+       prompt = self._build_prompt(
file="analyze_test_against_context",
language=language,
test_file_content=test_file_content,
test_file_name_rel=test_file_name_rel,
context_files_names_rel=context_files_names_rel,
)
response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
-       return response, prompt_tokens, completion_tokens, prompt
+       return response, prompt_tokens, completion_tokens, prompt["user"]

def analyze_suite_test_headers_indentation(
self,
@@ -234,14 +278,14 @@ def analyze_suite_test_headers_indentation(
- The output token count (int),
- The final constructed prompt (str).
"""
-       prompt = self.builder.build_prompt(
+       prompt = self._build_prompt(
file="analyze_suite_test_headers_indentation",
language=language,
test_file_name=test_file_name,
test_file=test_file,
)
response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
-       return response, prompt_tokens, completion_tokens, prompt
+       return response, prompt_tokens, completion_tokens, prompt["user"]

def adapt_test_command_for_a_single_test_via_ai(
self,
@@ -263,16 +307,30 @@ def adapt_test_command_for_a_single_test_via_ai(
Returns:
Tuple[str, int, int, str]: A 4-element tuple containing:
-               - The AI-generated single-test command (str, often YAML),
+               - The new single-test command string (or None if error),
- The input token count (int),
- The output token count (int),
- The final constructed prompt (str).
"""
-       prompt = self.builder.build_prompt(
-           file="adapt_test_command_for_a_single_test_via_ai",
+       prompt = self._build_prompt(
+           "adapt_test_command_for_a_single_test_via_ai",
            test_file_relative_path=test_file_relative_path,
            test_command=test_command,
            project_root_dir=project_root_dir,
        )
-       response, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
-       return response, prompt_tokens, completion_tokens, prompt
+
+       # Call the model
+       response_str, prompt_tokens, completion_tokens = self.caller.call_model(prompt)
+
+       # Now parse the response_str as YAML, and extract "new_command_line".
+       new_command_line = None
+       try:
+           response_yaml = load_yaml(response_str)
+           if "new_command_line" in response_yaml:
+               new_command_line = response_yaml["new_command_line"].strip()
+       except Exception as e:
+           self.logger.error(
+               f"Failed parsing YAML for adapt_test_command. response_yaml: {response_str}. Error: {e}"
+           )
+
+       return new_command_line, prompt_tokens, completion_tokens, prompt["user"]
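
For illustration, the new parsing step reduces a model response to a single command line roughly as follows. This sketch substitutes PyYAML's safe_load for the repository's load_yaml helper, and the response text is invented for the example:

    import yaml  # PyYAML, standing in for cover_agent.utils.load_yaml

    response_str = "new_command_line: pytest tests/test_calculator.py::test_add"

    new_command_line = None
    try:
        response_yaml = yaml.safe_load(response_str)
        if response_yaml and "new_command_line" in response_yaml:
            new_command_line = response_yaml["new_command_line"].strip()
    except yaml.YAMLError as e:
        print(f"Failed parsing YAML: {e}")

    print(new_command_line)  # pytest tests/test_calculator.py::test_add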