Skip to content

Commit

Permalink
Feat/swanlab.config (#248)
Browse files Browse the repository at this point in the history
* feat: better config

---------

Co-authored-by: KAAANG <[email protected]>
  • Loading branch information
Zeyi-Lin and SAKURA-CAT authored Jan 22, 2024
1 parent 91d1e85 commit 7956977
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 50 deletions.
1 change: 1 addition & 0 deletions swanlab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
init,
log,
finish,
config,
)
from .utils import get_package_version

Expand Down
1 change: 1 addition & 0 deletions swanlab/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
init,
log,
finish,
config,
)
2 changes: 1 addition & 1 deletion swanlab/data/run/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@Description:
在此处导出SwanLabRun类,一次实验运行应该只有一个SwanLabRun实例
"""
from .main import SwanLabRun
from .main import SwanLabRun, SwanLabConfig


def register(*args, **kwargs) -> SwanLabRun:
Expand Down
303 changes: 257 additions & 46 deletions swanlab/data/run/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,52 +32,282 @@
import argparse


class SwanConfig(Mapping):
def need_inited(func):
    """Decorator that rejects calls made before the owning object is initialised.

    The decorated callable must be a method whose instance exposes an
    ``_inited`` attribute; the call is forwarded only when it is truthy.

    Parameters
    ----------
    func : callable
        The method to guard.

    Returns
    -------
    callable
        A wrapper that raises ``RuntimeError`` while ``self._inited`` is falsy
        and otherwise returns ``func(self, *args, **kwargs)``.
    """
    from functools import wraps  # local import keeps this block self-contained

    @wraps(func)  # keep the guarded method's __name__ / __doc__ for introspection
    def wrapper(self, *args, **kwargs):
        if not self._inited:
            raise RuntimeError("You must call swanlab.init() before using swanlab.log")
        return func(self, *args, **kwargs)

    return wrapper


class SwanLabConfig(Mapping):
"""
The SwanLabConfig class enables attribute-style access to config values, e.g. `run.config.lr`.
"""

def __init__(self, config: dict, settings: SwanDataSettings):
self.__settings = settings
"""配置保存路径,yaml格式"""
self.__config = config
"""config就是外界传入的config,实际上外界访问的就是这个config中的内容"""
self.__save()
# 配置字典
__config = dict()

def __iter__(self):
return iter(self.__config)
# 运行时设置
__settings = dict()

def __len__(self):
return len(self.__config)
@property
def _inited(self):
    """Whether a save path has been recorded in the runtime settings."""
    save_path = self.__settings.get("save_path")
    return save_path is not None

def __init__(self, config: dict, settings: SwanDataSettings = None):
    """
    Build the config object.

    When ``settings`` is not None the instance is being created through
    swanlab.init; otherwise it is being created through swanlab.config.

    Parameters
    ----------
    config : dict
        Initial configuration; it is validated before being merged in.
    settings : SwanDataSettings, optional
        Runtime settings; its ``config_path`` becomes the save path.
    """
    validated = self.__check_config(config)
    self.__config.update(validated)
    if settings is None:
        self.__settings["save_path"] = None
    else:
        self.__settings["save_path"] = settings.config_path
    # Only write to disk once a save path is known.
    if self._inited:
        self.__save()

def __check_config(self, config: dict) -> dict:
    """
    Validate ``config`` and return it in a serialisable form.

    An ``argparse.Namespace`` is converted to a dict first. The result is
    passed through ``json_serializable`` and then trial-dumped with
    ``yaml.dump`` to confirm it can be written out.

    Parameters
    ----------
    config : dict
        The configuration to check; ``None`` yields an empty dict.

    Returns
    -------
    dict
        The converted configuration.

    Raises
    ------
    TypeError
        If the conversion or the trial dump fails.
    """
    if config is None:
        return {}
    try:
        if isinstance(config, argparse.Namespace):
            config = vars(config)
        config = json_serializable(dict(config))
        # Trial dump only: surfaces anything that still cannot be serialised.
        yaml.dump(config)
    except Exception as e:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit propagate; chain the cause so the real failure stays visible.
        raise TypeError(f"config: {config} is not a valid dict, which can be json serialized") from e
    return config

def __check_private(self, name: str):
    """
    Reject names that are private or that collide with the public helpers.

    Parameters
    ----------
    name : str
        The attribute or key name to check.

    Raises
    ------
    AttributeError
        If ``name`` starts with ``__`` or ``_SwanLabConfig__``, or is one of
        ``set``, ``get`` or ``pop``.
    """
    reserved = ("set", "get", "pop")
    swanlog.debug(f"Check private attribute: {name}")
    looks_private = name.startswith(("__", "_SwanLabConfig__"))
    if looks_private or name in reserved:
        raise AttributeError("You can not get private attribute")

@need_inited
def __setattr__(self, name: str, value: Any) -> None:
# Only private attributes, i.e. those defined in __init__, may be modified
if name.startswith("_" + self.__class__.__name__ + "__"):
# Private attribute: modification is allowed
self.__dict__[name] = value
else:
# Not a private attribute: modification is rejected
raise AttributeError("SwanConfig object is read-only, attributes cannot be modified")
"""
Custom attribute setter. If the name is not private, the config dict is updated as well and saved.
Attributes may be set with dot notation, but private attributes may not be set:
```python
run.config.lr = 0.01 # allowed
run.config._lr = 0.01 # allowed
run.config.__lr = 0.01 # not allowed
```
Note that setting a class attribute does not trigger this method.
"""
# Check whether the name is a private attribute
self.__check_private(name)
# Store the attribute on the instance
self.__dict__[name] = value
# Mirror the value into the config dict
self.__config[name] = value
self.__save()

@need_inited
def __setitem__(self, name: str, value: Any) -> None:
    """
    Set a configuration item with dict-style access and save the config.

    Private names are rejected:
    ```python
    run.config["lr"] = 0.01  # allowed
    run.config["_lr"] = 0.01  # allowed
    run.config["__lr"] = 0.01  # not allowed
    ```

    Parameters
    ----------
    name : str
        Name of the configuration item.
    value : Any
        Value of the configuration item.

    Raises
    ------
    AttributeError
        If ``name`` is rejected by the private-name check.
    """
    self.__check_private(name)
    self.__config[name] = value
    # __setattr__ keeps a copy in the instance __dict__, and normal attribute
    # lookup finds that copy first; refresh it so `config.name` is not stale.
    if name in self.__dict__:
        self.__dict__[name] = value
    self.__save()

@need_inited
def set(self, name: str, value: Any) -> None:
    """
    Explicitly set the value of a configuration item and save it. For example:
    ```python
    run.config.set("lr", 0.01)  # Allowed
    run.config.set("_lr", 0.01)  # Allowed
    run.config.set("__lr", 0.01)  # Not allowed
    ```

    Parameters
    ----------
    name : str
        Name of the configuration item.
    value : Any
        Value of the configuration item.

    Raises
    ------
    AttributeError
        If the attribute name is private, an exception is raised.
    """
    self.__check_private(name)
    self.__config[name] = value
    # __setattr__ keeps a copy in the instance __dict__, and normal attribute
    # lookup finds that copy first; refresh it so `config.name` is not stale.
    if name in self.__dict__:
        self.__dict__[name] = value
    self.__save()

@need_inited
def pop(self, name: str) -> bool:
    """
    Delete a configuration item; if the item does not exist, skip.

    Parameters
    ----------
    name : str
        Name of the configuration item.

    Returns
    -------
    bool
        True if deletion is successful, False otherwise.
    """
    try:
        del self.__config[name]
    except KeyError:
        return False
    # __setattr__ keeps a copy in the instance __dict__; drop it so that
    # `config.name` no longer returns the deleted value.
    self.__dict__.pop(name, None)
    self.__save()
    return True

@need_inited
def get(self, name: str):
    """
    Look up a configuration item by name.

    Parameters
    ----------
    name : str
        Name of the configuration item.

    Returns
    -------
    Any
        The stored value.

    Raises
    ------
    AttributeError
        If no item with that name exists.
    """
    try:
        found = self.__config[name]
    except KeyError:
        raise AttributeError(f"You have not retrieved '{name}' in the config of the current experiment")
    return found

@need_inited
def __getattr__(self, name: str):
    """
    Fallback for dot access: when normal attribute lookup fails, read the
    value from the config dict.

    Raises
    ------
    AttributeError
        If the config dict has no such key.
    """
    try:
        found = self.__config[name]
    except KeyError:
        raise AttributeError(f"You have not get '{name}' in the config of the current experiment")
    return found

@need_inited
def __getitem__(self, name: str):
"""
Get the value of a configuration item with dict-style access.
"""
try:
return self.__config[name]
except KeyError:
# NOTE(review): a missing key raises AttributeError, not KeyError; the Mapping
# mixin __contains__ only catches KeyError, so `key in config` may raise - confirm intent
raise AttributeError(f"You have not set {name} in the config of the current experiment")
raise AttributeError(f"You have not get '{name}' in the config of the current experiment")

@need_inited
def __delattr__(self, name: str) -> bool:
    """
    Delete a configuration item via ``del config.name``; a missing item is skipped.

    Parameters
    ----------
    name : str
        Name of the configuration item.

    Returns
    -------
    bool
        True if the item was deleted, False if it did not exist.
    """
    try:
        del self.__config[name]
    except KeyError:
        return False
    # __setattr__ keeps a copy in the instance __dict__; drop it so that
    # `config.name` no longer returns the deleted value.
    self.__dict__.pop(name, None)
    # Persist the deletion so the file on disk matches, as pop() does.
    self.__save()
    return True

@need_inited
def __delitem__(self, name: str) -> bool:
    """
    Delete a configuration item via ``del config[name]``; a missing item is skipped.

    Parameters
    ----------
    name : str
        Name of the configuration item.

    Returns
    -------
    bool
        True if the item was deleted, False if it did not exist.
    """
    try:
        del self.__config[name]
    except KeyError:
        return False
    # __setattr__ keeps a copy in the instance __dict__; drop it so that
    # `config.name` no longer returns the deleted value.
    self.__dict__.pop(name, None)
    # Persist the deletion so the file on disk matches, as pop() does.
    self.__save()
    return True

def __save(self):
"""
Write the config to the save path with yaml.dump; the config is not validated here.
"""
with get_a_lock(self.__settings.config_path, "w") as f:
swanlog.debug("Save config to {}".format(self.__settings.get("save_path")))
with get_a_lock(self.__settings.get("save_path"), "w") as f:
# Wrap every value as {"desc": None, "value": value}: value is the original value, desc is None
# This is so the web UI can display the config; desc is meant to describe the value
config = {key: {"desc": None, "value": value} for key, value in self.__config.items()}
yaml.dump(config, f)

def __iter__(self):
    """Iterate over the keys of the config dict."""
    stored = self.__config
    return iter(stored)

def __len__(self):
    """Return the number of configuration items."""
    stored = self.__config
    return len(stored)

def __str__(self):
    """Return the string form of the underlying config dict."""
    rendered = str(self.__config)
    return rendered


class SwanLabRun:
"""
Expand Down Expand Up @@ -128,18 +358,18 @@ def __init__(
# 初始化日志等级
level = self.__check_log_level(log_level)
swanlog.setLevel(level)

# ---------------------------------- 初始化配置 ----------------------------------
# 给外部1个config
self.__config = SwanLabConfig(config, self.__settings)
# ---------------------------------- 注册实验 ----------------------------------
# 校验配置格式
config = self.__check_config(config)
# 校验描述格式
description = self.__check_description(description)
self.__exp = self.__register_exp(
experiment_name,
description,
suffix,
)
# 给外部1个config
self.__config = SwanConfig(config, settings=self.__settings)
# 实验状态标记,如果status不为0,则无法再次调用log方法
self.__status = 0

Expand Down Expand Up @@ -316,33 +546,14 @@ def __check_description(self, description: str) -> str:
swanlog.warning("The description has been truncated automatically.")
return desc

def __check_config(self, config: dict) -> dict:
    """Check that the experiment config is valid and return it as a dict.

    An ``argparse.Namespace`` is converted to a dict first; the result is
    passed through ``json_serializable`` and trial-serialised with
    ``ujson.dumps``.

    Raises
    ------
    TypeError
        If the conversion or the trial serialisation fails.
    """
    if config is None:
        return {}
    try:
        if isinstance(config, argparse.Namespace):
            config = vars(config)
        config = json_serializable(dict(config))
        # Trial serialisation only; the encoded string is not needed.
        ujson.dumps(config)
    except Exception as e:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit propagate; chain the cause so the real failure stays visible.
        raise TypeError(f"config: {config} is not a valid dict, which can be json serialized") from e
    return config

def __record_exp_config(self):
"""创建实验配置目录 files
- 创建 files 目录
- 将实验环境写入 files/swanlab-metadata.json 中
- 将实验依赖写入 files/requirements.txt 中
"""
files_dir = self.__settings.files_dir
requirements_path = self.__settings.requirements_path
metadata_path = self.__settings.metadata_path

# 在实验目录下创建 files 目录,用于存储实验配置信息
if not os.path.exists(files_dir):
os.makedirs(files_dir)
# 将实验依赖存入 requirements.txt
with open(requirements_path, "w") as f:
f.write(get_requirements())
Expand Down
Loading

0 comments on commit 7956977

Please sign in to comment.