-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
思凡 叶
committed
Dec 20, 2024
1 parent
9c74538
commit f5ddb42
Showing
268 changed files
with
19,003 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
include README.md | ||
include README.zh.md | ||
include LICENSE | ||
recursive-include LICENSES * | ||
recursive-include darkit *.cu | ||
recursive-include darkit *.cpp | ||
recursive-include darkit/core/web/build * | ||
recursive-include darkit/tmp * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__version__ = "0.1.10" |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import os | ||
import signal | ||
import click | ||
from darkit import __version__ | ||
from .src.pid import read_pid, save_pid, remove_pid_file | ||
|
||
|
||
@click.group() | ||
@click.version_option(version=__version__, prog_name="DarwinKit") | ||
def cli(): | ||
pass | ||
|
||
|
||
try: | ||
from darkit.lm.command import command as lm_command | ||
|
||
cli.add_command(lm_command) | ||
except ImportError as e: | ||
print("lm command module not found", e) | ||
|
||
|
||
@cli.command("start") | ||
@click.option("--port", type=int, default=8000, help="Web 服务端口") | ||
@click.option("--daemon", "-D", is_flag=True, help="是否以守护进程启动") | ||
def start_server(port: int, daemon: bool): | ||
""" | ||
开启 WEB 服务 | ||
""" | ||
from darkit.core.utils.server import start_uvicorn | ||
|
||
if read_pid(): | ||
click.echo("服务已在运行。") | ||
return | ||
|
||
p = start_uvicorn(port, daemon) | ||
if daemon: | ||
p.start() | ||
save_pid(p.pid) | ||
print(f"服务已在后台以守护进程启动,端口: {port}") | ||
os._exit(0) | ||
else: | ||
print(f"服务已启动,端口: {port}") | ||
p.start() | ||
p.join() | ||
|
||
|
||
@cli.command("stop") | ||
def stop_server(): | ||
""" | ||
停止 WEB 服务 | ||
""" | ||
pid = read_pid() | ||
if pid is None: | ||
click.echo("服务未在运行或 PID 文件不存在。") | ||
return | ||
|
||
try: | ||
os.kill(pid, signal.SIGTERM) # 发送终止信号 | ||
remove_pid_file() # 移除 PID 文件 | ||
click.echo(f"服务已停止,PID: {pid}") | ||
except ProcessLookupError: | ||
click.echo(f"没有找到 PID 为 {pid} 的进程。") | ||
remove_pid_file() # 移除 PID 文件以防错误 | ||
except Exception as e: | ||
click.echo(f"停止服务时发生错误: {e}") | ||
|
||
|
||
if __name__ == "__main__": | ||
cli() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import click | ||
from dataclasses import MISSING | ||
from typing import Callable | ||
|
||
|
||
def dataclass_options(options: dict, default=None) -> Callable: | ||
def decorator(func: Callable) -> Callable: | ||
# 自动生成 click 选项 | ||
commands = [] | ||
for name, field in options.items(): | ||
option_default = default.__dict__[name] if default else field["default"] | ||
if option_default is MISSING: | ||
option_default = None | ||
|
||
if isinstance(field["type"], list): | ||
option_type = click.Choice(field["type"]) | ||
else: | ||
option_type = eval(field["type"]) | ||
|
||
# 添加 click 选项 | ||
commands.append( | ||
click.option( | ||
f"--{name}", | ||
default=option_default, | ||
type=option_type, | ||
show_default=True, | ||
) | ||
) | ||
|
||
def wrapped_command(*args, **kwargs): | ||
nonlocal func | ||
# 获取 kwargs 中存在与 _fields 的参数 | ||
relevant_kwargs = { | ||
k: v | ||
for k, v in kwargs.items() | ||
if k in [name for name, _ in options.items()] and v is not None | ||
} | ||
|
||
# 把 config 添加到 args 的最后 | ||
args = args + (relevant_kwargs,) | ||
return func(*args, **kwargs) | ||
|
||
# 保存原始函数的参数 | ||
if hasattr(func, "__doc__"): | ||
wrapped_command.__doc__ = func.__doc__ | ||
if hasattr(func, "__click_params__"): | ||
wrapped_command.__click_params__ = func.__click_params__ | ||
|
||
for option in commands: | ||
wrapped_command = option(wrapped_command) | ||
|
||
return wrapped_command | ||
|
||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
CLI_NAME = "darkit" | ||
|
||
|
||
def dict_to_cmd_args(d: dict) -> str: | ||
return " ".join([f"--{k} {v}" for k, v in d.items() if v not in [None, ""]]) | ||
|
||
|
||
def gen_train_command(type: str, model: str, mconf: dict, tconf: dict): | ||
mconf_args = dict_to_cmd_args(mconf) | ||
tconf_args = dict_to_cmd_args(tconf) | ||
return f"{CLI_NAME} {type} train {model} {mconf_args} {tconf_args}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from dataclasses import is_dataclass, fields, MISSING | ||
from enum import Enum | ||
from typing import Type, Literal, Optional, Union, get_origin, get_args | ||
|
||
|
||
def get_option_definer(conf_cls: Type, conf_comment: dict): | ||
""" | ||
生成配置选项字典, 将 py 数据类型转换为 json 可序列化的数据类型。 | ||
Args: | ||
conf_cls (Type): 需要是 dataclass 或者有 to_options 方法。to_options 方法返回一个字典,结构与 options 相同。 | ||
conf_comment (dict): 配置字段的注释信息。 | ||
Returns: | ||
dict: 配置选项字典。 | ||
Raises: | ||
ValueError: 如果 conf_cls 不是 dataclass 且没有 to_options 方法。 | ||
""" | ||
has_method = hasattr(conf_cls, "to_options") and callable( | ||
getattr(conf_cls, "to_options") | ||
) | ||
if has_method: | ||
return conf_cls.to_options() | ||
elif is_dataclass(conf_cls): | ||
_fields = fields(conf_cls) | ||
|
||
options = dict() | ||
|
||
for field in _fields: | ||
option_default = field.default | ||
if option_default is MISSING: | ||
option_default = None | ||
|
||
option_type = field.type | ||
required = True | ||
if field.name == "device": | ||
option_type = Literal["cuda", "cpu"] | ||
if isinstance(option_type, type): | ||
if issubclass(option_type, Enum): # type: ignore | ||
all_values = [option.value for option in option_type] | ||
option_type = all_values | ||
else: | ||
option_type = option_type.__name__ | ||
elif get_origin(option_type) is Union: | ||
option_type = "str" | ||
elif get_origin(option_type) is Literal: | ||
option_type = get_args(option_type) | ||
elif get_origin(option_type) is Optional: | ||
required = False | ||
option_type = get_args(option_type)[0] | ||
else: | ||
required = False | ||
option_type = "str" | ||
|
||
comment_dict = conf_comment.get(field.name, {}) | ||
description = comment_dict.get("description") | ||
range = comment_dict.get("range") | ||
comment = f"{description} {range}" if description else None | ||
options[field.name] = { | ||
"default": option_default, | ||
"type": option_type, | ||
"required": required, | ||
"comment": comment, | ||
} | ||
|
||
return options | ||
else: | ||
raise ValueError( | ||
f"{conf_cls.__name__} should be a dataclass or have a to_options method." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import os | ||
from pathlib import Path | ||
from darkit.core.utils import DSPIKE_LLM_HOME | ||
|
||
PID_FILE = Path(os.path.expanduser(DSPIKE_LLM_HOME)) / ".server.pid" | ||
|
||
|
||
def save_pid(pid): | ||
with open(PID_FILE, "w") as f: | ||
f.write(str(pid)) | ||
|
||
|
||
def read_pid(): | ||
try: | ||
with open(PID_FILE, "r") as f: | ||
return int(f.read().strip()) | ||
except FileNotFoundError: | ||
return None | ||
|
||
|
||
def remove_pid_file(): | ||
if os.path.exists(PID_FILE): | ||
os.remove(PID_FILE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# DarwinKit Server | ||
|
||
## Develop | ||
- Start the fastapi server | ||
```bash | ||
uvicorn darkit.server.main:app --reload --host 0.0.0.0 | ||
``` | ||
- Start the svelte web | ||
```bash | ||
cd web | ||
npm run dev | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from . import utils | ||
from .trainer import Trainer, TrainerConfig, LogFieldnames | ||
from .predicter import Predicter | ||
|
||
|
||
__all__ = ( | ||
"utils", | ||
"Trainer", | ||
"Predicter", | ||
"TrainerConfig", | ||
"LogFieldnames", | ||
) |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import re | ||
import torch | ||
import inspect | ||
|
||
banned_methods = [ | ||
"register_buffer", | ||
] | ||
|
||
|
||
def get_module_functions(module: torch.nn.Module): | ||
""" | ||
Get all the functions of a module. | ||
""" | ||
src_list = [] | ||
for name, func in inspect.getmembers(module, inspect.isfunction): | ||
print("func", name, func) | ||
return src_list | ||
|
||
|
||
def get_module_impl(module): | ||
""" | ||
获取 module 中所有的 self 函数 | ||
从 __init__ 和 forward 开始递归获取模块中涉及的所有函数的信息。 | ||
""" | ||
src_list = [] | ||
src_list += get_module_func_recursive(module, "__init__") | ||
src_list += get_module_func_recursive(module, "forward") | ||
return src_list | ||
|
||
|
||
def get_module_func_recursive(module, func_name) -> list[dict[str, str]]: | ||
""" | ||
获取 module 中指定函数的信息, 然后递归获取该函数中调用的所有 self 函数的信息。 | ||
Args: | ||
module (object): 要分析的模块对象。 | ||
func_name (str): 要获取的函数名称。 | ||
Returns: | ||
list: 包含函数实现信息的字典列表。 | ||
""" | ||
assert hasattr(module, func_name), f"Module does not have function: {func_name}" | ||
src_list = [] | ||
func_info = get_module_func(module, func_name) | ||
src_list.append(func_info) | ||
sub_method_list = extract_member_methods(func_info["body"]) | ||
for sub_method in sub_method_list: | ||
if not hasattr(module, sub_method): | ||
continue | ||
func = getattr(module, sub_method) | ||
if not (callable(func) and inspect.ismethod(func)): | ||
continue | ||
if sub_method in banned_methods: | ||
continue | ||
src_list = src_list + get_module_func_recursive(module, sub_method) | ||
return src_list | ||
|
||
|
||
def get_module_func(module, func_name: str): | ||
""" | ||
获取给定模块中指定函数的函数名称、函数签名和函数主体。 | ||
""" | ||
func_body: str = inspect.getsource(getattr(module, func_name)).strip() | ||
|
||
return {"name": func_name, "body": func_body} | ||
|
||
|
||
def extract_member_methods(func_body): | ||
""" | ||
从给定的函数体中提取调用了的成员方法名称。 | ||
此函数使用正则表达式查找所有出现的 “self.”,后跟提供的函数体字符串中的方法名称。它返回唯一方法名称的列表。 | ||
Args: | ||
func_body (str): The body of the function as a string. | ||
Returns: | ||
list: A list of unique member method names found in the function body. | ||
""" | ||
# Regular expression to match 'self.' followed by a method name | ||
pattern = r"self\.(\w+)\(" | ||
# Find all occurrences of the pattern | ||
matches = re.findall(pattern, func_body) | ||
# Return unique method names | ||
return list(set(matches)) |
Oops, something went wrong.