diff --git a/gpt_engineer/ai.py b/gpt_engineer/ai.py
index db58d943f4..26ca81ae97 100644
--- a/gpt_engineer/ai.py
+++ b/gpt_engineer/ai.py
@@ -37,7 +37,9 @@ class TokenUsage:
 
 
 class AI:
-    def __init__(self, model_name="gpt-4", temperature=0.1, azure_endpoint=""):
+    def __init__(
+        self, model_name="gpt-4", temperature=0.1, azure_endpoint="", openai_api_key=""
+    ):
         """
         Initialize the AI class.
 
@@ -53,7 +55,9 @@ def __init__(self, model_name="gpt-4", temperature=0.1, azure_endpoint=""):
         self.model_name = (
             fallback_model(model_name) if azure_endpoint == "" else model_name
         )
-        self.llm = create_chat_model(self, self.model_name, self.temperature)
+        self.llm = create_chat_model(
+            self, self.model_name, self.temperature, openai_api_key
+        )
         self.tokenizer = get_tokenizer(self.model_name)
         logger.debug(f"Using model {self.model_name} with llm {self.llm}")
 
@@ -339,7 +343,7 @@ def fallback_model(model: str) -> str:
         return "gpt-3.5-turbo"
 
 
-def create_chat_model(self, model: str, temperature) -> BaseChatModel:
+def create_chat_model(self, model: str, temperature, openai_api_key="") -> BaseChatModel:
     """
     Create a chat model with the specified model name and temperature.
 
@@ -349,6 +353,8 @@ def create_chat_model(self, model: str, temperature) -> BaseChatModel:
         The name of the model to create.
     temperature : float
         The temperature to use for the model.
+    openai_api_key : str
+        The OpenAI API key to use for authentication.
 
     Returns
     -------
@@ -362,6 +368,7 @@ def create_chat_model(self, model: str, temperature) -> BaseChatModel:
             deployment_name=model,
             openai_api_type="azure",
             streaming=True,
+            openai_api_key=openai_api_key,
         )
     # Fetch available models from OpenAI API
     supported = [model["id"] for model in openai.Model.list()["data"]]
@@ -374,6 +381,7 @@ def create_chat_model(self, model: str, temperature) -> BaseChatModel:
         temperature=temperature,
         streaming=True,
         client=openai.ChatCompletion,
+        openai_api_key=openai_api_key,
     )
 
 
diff --git a/gpt_engineer/chat_to_files.py b/gpt_engineer/chat_to_files.py
index 4f8d3d963b..ca8d6766ff 100644
--- a/gpt_engineer/chat_to_files.py
+++ b/gpt_engineer/chat_to_files.py
@@ -45,7 +45,7 @@ def parse_chat(chat) -> List[Tuple[str, str]]:
 
     # Get all the text before the first ``` block
     readme = chat.split("```")[0]
-    files.append(("README.md", readme))
+    files.append(("LAST_MODIFICATION_README.md", readme))
 
     # Return the files
     return files
@@ -69,6 +69,24 @@ def to_files(chat, workspace):
         workspace[file_name] = file_content
 
 
+def to_files_(chat, dbs):
+    """
+    Parse the chat; save the raw output to project_metadata and the files to the workspace.
+
+    Parameters
+    ----------
+    chat : str
+        The chat to parse.
+    dbs : DBs
+        The databases; extracted files are written to dbs.workspace.
+    """
+    dbs.project_metadata["all_output.txt"] = chat
+
+    files = parse_chat(chat)
+    for file_name, file_content in files:
+        dbs.workspace[file_name] = file_content
+
+
 def overwrite_files(chat, dbs):
     """
     Replace the AI files with the older local files.
@@ -82,12 +100,14 @@ def overwrite_files(chat, dbs):
     replace_files : dict
         A dictionary mapping file names to file paths of the local files.
     """
-    dbs.workspace["all_output.txt"] = chat
+    dbs.project_metadata[
+        "all_output.txt"
+    ] = chat
 
     files = parse_chat(chat)
     for file_name, file_content in files:
-        if file_name == "README.md":
-            dbs.workspace["LAST_MODIFICATION_README.md"] = file_content
+        if file_name == "LAST_MODIFICATION_README.md":
+            dbs.project_metadata["LAST_MODIFICATION_README.md"] = file_content
         else:
             dbs.workspace[file_name] = file_content
 
diff --git a/gpt_engineer/db.py b/gpt_engineer/db.py
index 6664a83c2c..427d00d704 100644
--- a/gpt_engineer/db.py
+++ b/gpt_engineer/db.py
@@ -124,6 +124,7 @@ class DBs:
     input: DB
     workspace: DB
     archive: DB
+    project_metadata: DB
 
 
 def archive(dbs: DBs) -> None:
diff --git a/gpt_engineer/main.py b/gpt_engineer/main.py
index 7b411dc808..4d78203a47 100644
--- a/gpt_engineer/main.py
+++ b/gpt_engineer/main.py
@@ -77,6 +77,7 @@ def main(
             Path(__file__).parent / "preprompts"
         ),  # Loads preprompts from the preprompts directory
         archive=DB(archive_path),
+        project_metadata=DB(base_metadata_path),
     )
 
     if steps_config not in [
diff --git a/gpt_engineer/steps.py b/gpt_engineer/steps.py
index 4f669a2412..b0be9425a8 100644
--- a/gpt_engineer/steps.py
+++ b/gpt_engineer/steps.py
@@ -1,4 +1,5 @@
 import inspect
+import os
 import re
 import subprocess
 
@@ -14,6 +15,7 @@
     get_code_strings,
     overwrite_files,
     to_files,
+    to_files_,
 )
 from gpt_engineer.db import DBs
 from gpt_engineer.file_selector import FILE_LIST_NAME, ask_for_files
@@ -81,7 +83,14 @@ def curr_fn() -> str:
 
 
 def simple_gen(ai: AI, dbs: DBs) -> List[Message]:
     """Run the AI on the main prompt and save the results"""
     messages = ai.start(setup_sys_prompt(dbs), get_prompt(dbs), step_name=curr_fn())
-    to_files(messages[-1].content.strip(), dbs.workspace)
+
+    SERVICE_MODE = os.environ.get("SERVICE_MODE", False)
+
+    if SERVICE_MODE:
+        to_files_(messages[-1].content.strip(), dbs)
+    else:
+        to_files(messages[-1].content.strip(), dbs.workspace)
+
     return messages
 
 
@@ -273,7 +282,8 @@ def gen_entrypoint(ai: AI, dbs: DBs) -> List[dict]:
             "Do not use placeholders, use example values (like . for a folder argument) "
             "if necessary.\n"
         ),
-        user="Information about the codebase:\n\n" + dbs.workspace["all_output.txt"],
+        user="Information about the codebase:\n\n"
+        + dbs.project_metadata["all_output.txt"],
         step_name=curr_fn(),
     )
     print()
@@ -357,7 +367,7 @@ def improve_existing_code(ai: AI, dbs: DBs):
 
         to sent the formatted prompt to the LLM.
     """
-    files_info = get_code_strings(dbs.input)  # this only has file names not paths
+    files_info = get_code_strings(dbs.input)
 
     messages = [
         ai.fsystem(setup_sys_prompt_existing_code(dbs)),
@@ -441,7 +451,7 @@ class Config(str, Enum):
     Config.SIMPLE: [
         simple_gen,
         gen_entrypoint,
-        execute_entrypoint,
+        # execute_entrypoint,
     ],
     Config.TDD: [
         gen_spec,