Skip to content

Commit

Permalink
Improve arg typing, linting, config safety
Browse files Browse the repository at this point in the history
  • Loading branch information
DeflateAwning committed Dec 8, 2023
1 parent 9b6eee8 commit fdd8ae5
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 23 deletions.
6 changes: 3 additions & 3 deletions camel/agents/role_playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import copy
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Dict, List, Optional, Sequence, Tuple, Union

from camel.agents import (
ChatAgent,
Expand Down Expand Up @@ -167,8 +167,8 @@ def __init__(
else:
self.critic = None

def init_chat(self, phase_type: PhaseType = None,
placeholders=None, phase_prompt=None):
def init_chat(self, phase_type: Union[PhaseType, None],
placeholders, phase_prompt: str):
r"""Initializes the chat by resetting both the assistant and user
agents, and sending the system messages again to the agents using
chat messages. Returns the assistant's introductory message and the
Expand Down
54 changes: 35 additions & 19 deletions chatdev/chat_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import shutil
import time
from datetime import datetime
from typing import Union

from camel.agents import RolePlaying
from camel.configs import ChatGPTConfig
Expand All @@ -14,21 +15,31 @@
from chatdev.utils import log_and_print_online, now


def check_bool(s):
return s.lower() == "true"

def check_bool(s: Union[str, bool]) -> bool:
""" Normalizes a string or bool to a bool value.
String must be either "True" or "False" (case insensitive).
"""
if isinstance(s, bool):
return s
else:
if s.lower() == "true":
return True
elif s.lower() == "false":
return False
else:
raise ValueError(f"Cannot convert '{s}' in config to bool")

class ChatChain:

def __init__(self,
config_path: str = None,
config_phase_path: str = None,
config_role_path: str = None,
task_prompt: str = None,
project_name: str = None,
org_name: str = None,
config_path: str,
config_phase_path: str,
config_role_path: str,
task_prompt: str,
project_name: Union[str, None] = None,
org_name: Union[str, None] = None,
model_type: ModelType = ModelType.GPT_3_5_TURBO,
code_path: str = None) -> None:
code_path: Union[str, None] = None):
"""
Args:
Expand All @@ -38,6 +49,8 @@ def __init__(self,
task_prompt: the user input prompt for software
project_name: the user input name for software
org_name: the organization name of the human user
model_type: the model type for chatbot
code_path: the path to the code files, if working incrementally
"""

# load config file
Expand Down Expand Up @@ -70,6 +83,10 @@ def __init__(self,
incremental_develop=check_bool(self.config["incremental_develop"]))
self.chat_env = ChatEnv(self.chat_env_config)

if not check_bool(self.config["incremental_develop"]):
if self.code_path:
raise RuntimeError("code_path is given, but Phase Config specifies incremental_develop=False. code_path will be ignored.")

# the user input prompt will be self-improved (if set "self_improve": "True" in ChatChainConfig.json)
# the self-improvement is done in self.preprocess
self.task_prompt_raw = task_prompt
Expand Down Expand Up @@ -168,16 +185,12 @@ def get_logfilepath(self):
Returns:
start_time: time for starting making the software
log_filepath: path to the log
"""
start_time = now()
filepath = os.path.dirname(__file__)
# root = "/".join(filepath.split("/")[:-1])
root = os.path.dirname(filepath)
# directory = root + "/WareHouse/"
directory = os.path.join(root, "WareHouse")
log_filepath = os.path.join(directory,
"{}.log".format("_".join([self.project_name, self.org_name, start_time])))
log_filepath = os.path.join(directory, f"{self.project_name}_{self.org_name}_{self.start_time}.log")
return start_time, log_filepath

def pre_processing(self):
Expand All @@ -195,9 +208,9 @@ def pre_processing(self):
# logs with error trials are left in WareHouse/
if os.path.isfile(file_path) and not filename.endswith(".py") and not filename.endswith(".log"):
os.remove(file_path)
print("{} Removed.".format(file_path))
print(f"{file_path} Removed.")

software_path = os.path.join(directory, "_".join([self.project_name, self.org_name, self.start_time]))
software_path = os.path.join(directory, f"{self.project_name}_{self.org_name}_{self.start_time}")
self.chat_env.set_directory(software_path)

# copy config files to software path
Expand All @@ -207,6 +220,9 @@ def pre_processing(self):

# copy code files to software path in incremental_develop mode
if check_bool(self.config["incremental_develop"]):
if not self.code_path:
raise RuntimeError("code_path is not given, but working in incremental_develop mode.")

for root, dirs, files in os.walk(self.code_path):
relative_path = os.path.relpath(root, self.code_path)
target_dir = os.path.join(software_path, 'base', relative_path)
Expand All @@ -218,7 +234,7 @@ def pre_processing(self):
self.chat_env._load_from_hardware(os.path.join(software_path, 'base'))

# write task prompt to software
with open(os.path.join(software_path, self.project_name + ".prompt"), "w") as f:
with open(os.path.join(software_path, f"{self.project_name}.prompt"), "w") as f:
f.write(self.task_prompt_raw)

preprocess_msg = "**[Preprocessing]**\n\n"
Expand Down Expand Up @@ -306,7 +322,7 @@ def post_processing(self):
time.sleep(1)

shutil.move(self.log_filepath,
os.path.join(root + "/WareHouse", "_".join([self.project_name, self.org_name, self.start_time]),
os.path.join(root, "WareHouse", f"{self.project_name}_{self.org_name}_{self.start_time}",
os.path.basename(self.log_filepath)))

# @staticmethod
Expand Down
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_config(company: str) -> Tuple[str, str, str]:
help="Name of software (your software will be generated in WareHouse/name_org_timestamp)")
parser.add_argument('--model', type=str, default="GPT_3_5_TURBO", choices=['GPT_3_5_TURBO', 'GPT_4', 'GPT_4_32K'],
help="GPT Model, choose from {'GPT_3_5_TURBO','GPT_4','GPT_4_32K'}")
parser.add_argument('--path', type=str, default="",
parser.add_argument('--path', type=str, default=None,
help="Your file directory. If given, ChatDev will build upon your software in the Incremental mode.")
args = parser.parse_args()

Expand Down

0 comments on commit fdd8ae5

Please sign in to comment.