Add OpenAI API key parameter and refactor file handling #4

Merged · 4 commits · Oct 13, 2023
14 changes: 11 additions & 3 deletions gpt_engineer/ai.py
@@ -37,7 +37,9 @@ class TokenUsage:
 
 
 class AI:
-    def __init__(self, model_name="gpt-4", temperature=0.1, azure_endpoint=""):
+    def __init__(
+        self, model_name="gpt-4", temperature=0.1, azure_endpoint="", openai_api_key=""
+    ):
         """
         Initialize the AI class.
 
@@ -53,7 +55,9 @@ def __init__(self, model_name="gpt-4", temperature=0.1, azure_endpoint=""):
         self.model_name = (
             fallback_model(model_name) if azure_endpoint == "" else model_name
         )
-        self.llm = create_chat_model(self, self.model_name, self.temperature)
+        self.llm = create_chat_model(
+            self, self.model_name, self.temperature, openai_api_key
+        )
         self.tokenizer = get_tokenizer(self.model_name)
         logger.debug(f"Using model {self.model_name} with llm {self.llm}")
 
@@ -339,7 +343,7 @@ def fallback_model(model: str) -> str:
     return "gpt-3.5-turbo"
 
 
-def create_chat_model(self, model: str, temperature) -> BaseChatModel:
+def create_chat_model(self, model: str, temperature, openai_api_key="") -> BaseChatModel:
     """
     Create a chat model with the specified model name and temperature.
 
@@ -349,6 +353,8 @@ def create_chat_model(self, model: str, temperature) -> BaseChatModel:
         The name of the model to create.
     temperature : float
         The temperature to use for the model.
+    openai_api_key : str
+        OpenAI API key
 
     Returns
     -------
@@ -362,6 +368,7 @@ def create_chat_model(self, model: str, temperature) -> BaseChatModel:
             deployment_name=model,
             openai_api_type="azure",
             streaming=True,
+            openai_api_key=openai_api_key,
         )
     # Fetch available models from OpenAI API
     supported = [model["id"] for model in openai.Model.list()["data"]]
@@ -374,6 +381,7 @@ def create_chat_model(self, model: str, temperature) -> BaseChatModel:
         temperature=temperature,
         streaming=True,
         client=openai.ChatCompletion,
+        openai_api_key=openai_api_key,
     )
 
 
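For context, a minimal usage sketch of the new parameter (the placeholder key and the standalone script around it are illustrative, not part of this PR). Note that on the non-Azure path `create_chat_model` still calls `openai.Model.list()`, which reads the module-level `openai.api_key`, so the `OPENAI_API_KEY` environment variable may still be needed for the availability check.

```python
from gpt_engineer.ai import AI

# Illustrative only: pass the key explicitly instead of relying on the
# OPENAI_API_KEY environment variable. The empty-string default appears
# intended to preserve the old environment-based lookup.
ai = AI(
    model_name="gpt-4",
    temperature=0.1,
    openai_api_key="sk-...",  # placeholder, not a real key
)
```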
28 changes: 24 additions & 4 deletions gpt_engineer/chat_to_files.py
@@ -45,7 +45,7 @@ def parse_chat(chat) -> List[Tuple[str, str]]:
 
     # Get all the text before the first ``` block
     readme = chat.split("```")[0]
-    files.append(("README.md", readme))
+    files.append(("LAST_MODIFICATION_README.md", readme))
 
     # Return the files
     return files
@@ -69,6 +69,24 @@ def to_files(chat, workspace):
         workspace[file_name] = file_content
 
 
+def to_files_(chat, dbs):
+    """
+    Parse the chat, storing the raw transcript in project_metadata and the files in the workspace.
+
+    Parameters
+    ----------
+    chat : str
+        The chat to parse.
+    dbs : DBs
+        The databases to write the results to.
+    """
+    dbs.project_metadata["all_output.txt"] = chat
+
+    files = parse_chat(chat)
+    for file_name, file_content in files:
+        dbs.workspace[file_name] = file_content
+
+
 def overwrite_files(chat, dbs):
     """
     Replace the AI files with the older local files.
@@ -82,12 +100,14 @@ def overwrite_files(chat, dbs):
     replace_files : dict
         A dictionary mapping file names to file paths of the local files.
     """
-    dbs.workspace["all_output.txt"] = chat
+    dbs.project_metadata[
+        "all_output.txt"
+    ] = chat  # files_info = get_code_strings(dbs.project_metadata)
 
     files = parse_chat(chat)
     for file_name, file_content in files:
-        if file_name == "README.md":
-            dbs.workspace["LAST_MODIFICATION_README.md"] = file_content
+        if file_name == "LAST_MODIFICATION_README.md":
+            dbs.project_metadata["LAST_MODIFICATION_README.md"] = file_content
         else:
             dbs.workspace[file_name] = file_content
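To make the new routing concrete, here is a small sketch of `parse_chat` plus `to_files_`, using a `SimpleNamespace` with plain dicts in place of the real `DB` objects (which persist to disk); the sample chat follows the filename-before-code-fence convention that `parse_chat` matches on.

```python
from types import SimpleNamespace

from gpt_engineer.chat_to_files import parse_chat, to_files_

chat = (
    "Here is the plan.\n"
    "\n"
    "main.py\n"
    "```python\n"
    "print('hello')\n"
    "```\n"
)

# After this PR the prose before the first fenced block is returned as
# LAST_MODIFICATION_README.md alongside ("main.py", "print('hello')\n").
print(parse_chat(chat))

# Dict-backed stand-in for DBs, for demo purposes only.
dbs = SimpleNamespace(workspace={}, project_metadata={})
to_files_(chat, dbs)
assert "all_output.txt" in dbs.project_metadata  # raw transcript -> metadata
assert "main.py" in dbs.workspace                # parsed files -> workspace
```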
1 change: 1 addition & 0 deletions gpt_engineer/db.py
@@ -124,6 +124,7 @@ class DBs:
     input: DB
     workspace: DB
     archive: DB
+    project_metadata: DB
 
 
 def archive(dbs: DBs) -> None:
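`DB` is a thin dict-like wrapper over a directory (assignment writes a file, indexing reads one), so the new field is just one more directory next to the workspace. A hedged construction sketch, assuming the remaining `DBs` fields (`memory`, `logs`, `preprompts`) as defined in db.py and using throwaway paths:

```python
from pathlib import Path
from tempfile import mkdtemp

from gpt_engineer.db import DB, DBs

base = Path(mkdtemp())  # illustrative base directory

dbs = DBs(
    memory=DB(base / "memory"),
    logs=DB(base / "logs"),
    preprompts=DB(base / "preprompts"),
    input=DB(base / "input"),
    workspace=DB(base / "workspace"),
    archive=DB(base / "archive"),
    project_metadata=DB(base / "metadata"),
)

# The raw chat transcript now lives beside, not inside, the generated code.
dbs.project_metadata["all_output.txt"] = "raw chat transcript"
print(dbs.project_metadata["all_output.txt"])
```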
1 change: 1 addition & 0 deletions gpt_engineer/main.py
@@ -77,6 +77,7 @@ def main(
             Path(__file__).parent / "preprompts"
         ),  # Loads preprompts from the preprompts directory
         archive=DB(archive_path),
+        project_metadata=DB(base_metadata_path),
     )
 
     if steps_config not in [
18 changes: 14 additions & 4 deletions gpt_engineer/steps.py
@@ -1,4 +1,5 @@
 import inspect
+import os
 import re
 import subprocess
 
@@ -14,6 +15,7 @@
     get_code_strings,
     overwrite_files,
     to_files,
+    to_files_,
 )
 from gpt_engineer.db import DBs
 from gpt_engineer.file_selector import FILE_LIST_NAME, ask_for_files
@@ -81,7 +83,14 @@ def curr_fn() -> str:
 def simple_gen(ai: AI, dbs: DBs) -> List[Message]:
     """Run the AI on the main prompt and save the results"""
     messages = ai.start(setup_sys_prompt(dbs), get_prompt(dbs), step_name=curr_fn())
-    to_files(messages[-1].content.strip(), dbs.workspace)
+
+    SERVICE_MODE = os.environ.get("SERVICE_MODE", False)
+
+    if SERVICE_MODE:
+        to_files_(messages[-1].content.strip(), dbs)
+    else:
+        to_files(messages[-1].content.strip(), dbs.workspace)
+
     return messages
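One caveat worth noting: `os.environ.get("SERVICE_MODE", False)` returns the raw string whenever the variable is set, so `SERVICE_MODE=0` or `SERVICE_MODE=false` would still be truthy and enable the branch; only an unset variable yields `False`. A stricter parse (a sketch, not part of this PR) could look like:

```python
import os

def service_mode_enabled() -> bool:
    # Only explicit affirmative values count as "on"; an unset variable,
    # "", "0", or "false" all stay off.
    return os.environ.get("SERVICE_MODE", "").strip().lower() in {"1", "true", "yes"}
```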


@@ -273,7 +282,8 @@ def gen_entrypoint(ai: AI, dbs: DBs) -> List[dict]:
             "Do not use placeholders, use example values (like . for a folder argument) "
             "if necessary.\n"
         ),
-        user="Information about the codebase:\n\n" + dbs.workspace["all_output.txt"],
+        user="Information about the codebase:\n\n"
+        + dbs.project_metadata["all_output.txt"],
         step_name=curr_fn(),
     )
     print()
@@ -357,7 +367,7 @@ def improve_existing_code(ai: AI, dbs: DBs):
     to sent the formatted prompt to the LLM.
     """
 
-    files_info = get_code_strings(dbs.input)  # this only has file names not paths
+    files_info = get_code_strings(dbs.input)
 
     messages = [
         ai.fsystem(setup_sys_prompt_existing_code(dbs)),
@@ -441,7 +451,7 @@ class Config(str, Enum):
     Config.SIMPLE: [
         simple_gen,
         gen_entrypoint,
-        execute_entrypoint,
+        # execute_entrypoint,
     ],
     Config.TDD: [
         gen_spec,