Skip to content

Commit

Permalink
Merge pull request #4 from rminchev1/feature/service_mode
Browse files Browse the repository at this point in the history
Add OpenAI API key parameters and refactor file handling
  • Loading branch information
rminchev1 authored Oct 13, 2023
2 parents 4395a3d + 665f06b commit 39e4207
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 11 deletions.
14 changes: 11 additions & 3 deletions gpt_engineer/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class TokenUsage:


class AI:
def __init__(self, model_name="gpt-4", temperature=0.1, azure_endpoint=""):
def __init__(
self, model_name="gpt-4", temperature=0.1, azure_endpoint="", openai_api_key=""
):
"""
Initialize the AI class.
Expand All @@ -53,7 +55,9 @@ def __init__(self, model_name="gpt-4", temperature=0.1, azure_endpoint=""):
self.model_name = (
fallback_model(model_name) if azure_endpoint == "" else model_name
)
self.llm = create_chat_model(self, self.model_name, self.temperature)
self.llm = create_chat_model(
self, self.model_name, self.temperature, openai_api_key
)
self.tokenizer = get_tokenizer(self.model_name)
logger.debug(f"Using model {self.model_name} with llm {self.llm}")

Expand Down Expand Up @@ -339,7 +343,7 @@ def fallback_model(model: str) -> str:
return "gpt-3.5-turbo"


def create_chat_model(self, model: str, temperature) -> BaseChatModel:
def create_chat_model(self, model: str, temperature, openai_api_key="") -> BaseChatModel:
"""
Create a chat model with the specified model name and temperature.
Expand All @@ -349,6 +353,8 @@ def create_chat_model(self, model: str, temperature) -> BaseChatModel:
The name of the model to create.
temperature : float
The temperature to use for the model.
openai_api_key : str
OpenAI API key
Returns
-------
Expand All @@ -362,6 +368,7 @@ def create_chat_model(self, model: str, temperature) -> BaseChatModel:
deployment_name=model,
openai_api_type="azure",
streaming=True,
openai_api_key=openai_api_key,
)
# Fetch available models from OpenAI API
supported = [model["id"] for model in openai.Model.list()["data"]]
Expand All @@ -374,6 +381,7 @@ def create_chat_model(self, model: str, temperature) -> BaseChatModel:
temperature=temperature,
streaming=True,
client=openai.ChatCompletion,
openai_api_key=openai_api_key,
)


Expand Down
28 changes: 24 additions & 4 deletions gpt_engineer/chat_to_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def parse_chat(chat) -> List[Tuple[str, str]]:

# Get all the text before the first ``` block
readme = chat.split("```")[0]
files.append(("README.md", readme))
files.append(("LAST_MODIFICATION_README.md", readme))

# Return the files
return files
Expand All @@ -69,6 +69,24 @@ def to_files(chat, workspace):
workspace[file_name] = file_content


def to_files_(chat, dbs):
"""
Parse the chat and add all extracted files to the workspace.
Parameters
----------
chat : str
The chat to parse.
workspace : dict
The workspace to add the files to.
"""
dbs.project_metadata["all_output.txt"] = chat

files = parse_chat(chat)
for file_name, file_content in files:
dbs.workspace[file_name] = file_content


def overwrite_files(chat, dbs):
"""
Replace the AI files with the older local files.
Expand All @@ -82,12 +100,14 @@ def overwrite_files(chat, dbs):
replace_files : dict
A dictionary mapping file names to file paths of the local files.
"""
dbs.workspace["all_output.txt"] = chat
dbs.project_metadata[
"all_output.txt"
] = chat # files_info = get_code_strings(dbs.project_metadata)

files = parse_chat(chat)
for file_name, file_content in files:
if file_name == "README.md":
dbs.workspace["LAST_MODIFICATION_README.md"] = file_content
if file_name == "LAST_MODIFICATION_README.md":
dbs.project_metadata["LAST_MODIFICATION_README.md"] = file_content
else:
dbs.workspace[file_name] = file_content

Expand Down
1 change: 1 addition & 0 deletions gpt_engineer/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ class DBs:
input: DB
workspace: DB
archive: DB
project_metadata: DB


def archive(dbs: DBs) -> None:
Expand Down
1 change: 1 addition & 0 deletions gpt_engineer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def main(
Path(__file__).parent / "preprompts"
), # Loads preprompts from the preprompts directory
archive=DB(archive_path),
project_metadata=DB(base_metadata_path),
)

if steps_config not in [
Expand Down
18 changes: 14 additions & 4 deletions gpt_engineer/steps.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import os
import re
import subprocess

Expand All @@ -14,6 +15,7 @@
get_code_strings,
overwrite_files,
to_files,
to_files_,
)
from gpt_engineer.db import DBs
from gpt_engineer.file_selector import FILE_LIST_NAME, ask_for_files
Expand Down Expand Up @@ -81,7 +83,14 @@ def curr_fn() -> str:
def simple_gen(ai: AI, dbs: DBs) -> List[Message]:
"""Run the AI on the main prompt and save the results"""
messages = ai.start(setup_sys_prompt(dbs), get_prompt(dbs), step_name=curr_fn())
to_files(messages[-1].content.strip(), dbs.workspace)

SERVICE_MODE = os.environ.get("SERVICE_MODE", False)

if SERVICE_MODE:
to_files_(messages[-1].content.strip(), dbs)
else:
to_files(messages[-1].content.strip(), dbs.workspace)

return messages


Expand Down Expand Up @@ -273,7 +282,8 @@ def gen_entrypoint(ai: AI, dbs: DBs) -> List[dict]:
"Do not use placeholders, use example values (like . for a folder argument) "
"if necessary.\n"
),
user="Information about the codebase:\n\n" + dbs.workspace["all_output.txt"],
user="Information about the codebase:\n\n"
+ dbs.project_metadata["all_output.txt"],
step_name=curr_fn(),
)
print()
Expand Down Expand Up @@ -357,7 +367,7 @@ def improve_existing_code(ai: AI, dbs: DBs):
to sent the formatted prompt to the LLM.
"""

files_info = get_code_strings(dbs.input) # this only has file names not paths
files_info = get_code_strings(dbs.input)

messages = [
ai.fsystem(setup_sys_prompt_existing_code(dbs)),
Expand Down Expand Up @@ -441,7 +451,7 @@ class Config(str, Enum):
Config.SIMPLE: [
simple_gen,
gen_entrypoint,
execute_entrypoint,
# execute_entrypoint,
],
Config.TDD: [
gen_spec,
Expand Down

0 comments on commit 39e4207

Please sign in to comment.