Skip to content

Commit

Permalink
添加env、log命令
Browse files Browse the repository at this point in the history
  • Loading branch information
iokk3732 committed Apr 25, 2024
1 parent 249505e commit 80c0670
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 44 deletions.
48 changes: 16 additions & 32 deletions airda/agent/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,51 +16,35 @@ class LogLevel(Enum):
@singleton
class DataAgentEnv(Env):
# 向量化模型名称
embeddings_model_name: str | None
EMBEDDINGS_MODEL_NAME: str | None

# mongo环境变量
mongodb_uri: str | None
mongodb_db_name: str | None
mongodb_username: str | None
mongodb_password: str | None
MONGODB_URI: str | None
MONGODB_DB_NAME: str | None
MONGODB_USERNAME: str | None
MONGODB_PASSWORD: str | None

# openai配置
openai_api_key: str | None
model_name: str

# 知识库缓存路径
knowledge_path: str | None

# vanus 代理 appid
application_id: str | None
OPENAI_API_KEY: str | None
MODEL_NAME: str

# 最大embedding最大并发数
max_works: str | None

log_level: str | None
MAX_WORKS: str | None

def init(self):
self.embeddings_model_name = os.getenv(
self.EMBEDDINGS_MODEL_NAME = os.getenv(
"EMBEDDINGS_MODEL_NAME", "infgrad/stella-large-zh-v2"
)

# mongo环境变量
self.mongodb_uri = os.getenv("MONGODB_URI")
self.mongodb_db_name = os.getenv("MONGODB_DB_NAME")
self.mongodb_username = os.getenv("MONGODB_USERNAME")
self.mongodb_password = os.getenv("MONGODB_PASSWORD")
self.MONGODB_URI = os.getenv("MONGODB_URI")
self.MONGODB_DB_NAME = os.getenv("MONGODB_DB_NAME")
self.MONGODB_USERNAME = os.getenv("MONGODB_USERNAME")
self.MONGODB_PASSWORD = os.getenv("MONGODB_PASSWORD")

# openai配置
self.openai_api_key = os.getenv("OPENAI_KEY")
self.model_name = os.getenv("model_name")

# 知识库缓存路径
self.knowledge_path = os.getenv("KNOWLEDGE_PATH")

# vanus 代理 appid
self.application_id = os.getenv("APPID")
self.OPENAI_API_KEY = os.getenv("OPENAI_KEY")
self.MODEL_NAME = os.getenv("MODEL_NAME")

# 最大embedding最大并发数
self.max_works = os.getenv("MAX_WORKERS", "4")

self.log_level = os.getenv("LOG_LEVEL", LogLevel.INFO.value)
self.MAX_WORKS = os.getenv("MAX_WORKERS", "4")
2 changes: 1 addition & 1 deletion airda/agent/process_pool/DataAgentProcessPool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class DataAgentProcessPool:

def __init__(self):
self.pool_executor = concurrent.futures.ProcessPoolExecutor(
max_workers=int(DataAgentEnv().max_works)
max_workers=int(DataAgentEnv().MAX_WORKS)
)

def submit(self, fn, /, *args, **kwargs):
Expand Down
8 changes: 4 additions & 4 deletions airda/agent/storage/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def __init__(self, context: Context):
super().load(StorageKey.DATASOURCE, DatasourceRepository)

def init_storage(self):
uri = DataAgentEnv().mongodb_uri
db_name = DataAgentEnv().mongodb_db_name
username = DataAgentEnv().mongodb_username
password = DataAgentEnv().mongodb_password
uri = DataAgentEnv().MONGODB_URI
db_name = DataAgentEnv().MONGODB_DB_NAME
username = DataAgentEnv().MONGODB_USERNAME
password = DataAgentEnv().MONGODB_PASSWORD
if username and password:
self.client = MongoClient(uri, username=username, password=password)
self.database = self.client[db_name]
Expand Down
33 changes: 26 additions & 7 deletions airda/cli/startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import click
import yaml
from prompt_toolkit import HTML, PromptSession, print_formatted_text
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.styles import Style

Expand Down Expand Up @@ -39,10 +40,11 @@
env_path = config_path + "/" + ".env"
log_path = config_path + "/" + "log_config.yml"
DataAgentEnv(env_path)
log_config = {}
try:
with open(log_path, "r") as f:
config = yaml.safe_load(f)
logging.config.dictConfig(config)
log_config = yaml.safe_load(f)
logging.config.dictConfig(log_config)
except Exception:
pass

Expand Down Expand Up @@ -290,37 +292,54 @@ def delete(name: str):


@main.group()
def load():
def env():
pass


@load.command()
@env.command()
@click.option(
"-p",
"--path",
type=str,
required=True,
help=".env文件路径",
)
def env(path: str):
def load(path: str):
import shutil
if os.path.exists(path):
shutil.copy(path, env_path)


@load.command()
@env.command()
def ls():
import json
print_formatted_text(FormattedText([('class:json', json.dumps(DataAgentEnv().__dict__, indent=4))]))


@main.group()
def log():
pass


@log.command()
@click.option(
"-p",
"--path",
type=str,
required=True,
help="log_config.yml文件路径",
)
def log(path: str):
def load(path: str):
import shutil
if os.path.exists(path):
shutil.copy(path, log_path)


@log.command()
def ls():
import json
print_formatted_text(FormattedText([('class:json', json.dumps(log_config, indent=4))]))


if __name__ == "__main__":
main()

0 comments on commit 80c0670

Please sign in to comment.