diff --git a/.docstr.yaml b/.docstr.yaml
new file mode 100644
index 00000000..2fdfcc9f
--- /dev/null
+++ b/.docstr.yaml
@@ -0,0 +1,21 @@
+paths: ./src/lumo/
+badge: ./images # Path
+exclude:
+
+verbose: 3 # int (0-4)
+skip_magic: True # Boolean
+skip_file_doc: True # Boolean
+skip_property: True
+skip_init: True # Boolean
+#skip_class_def: True # Boolean
+skip_private: True # Boolean
+follow_links: True # Boolean
+accept_empty: True # Boolean
+ignore_names_file: .*/test # regex
+#fail_under: 90 # int
+#percentage_only: True # Boolean
+ignore_patterns:
+ callbacks:
+ - ".*"
+ exphook:
+ - ".*"
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 397f93f8..393a6ab3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,4 +165,5 @@ lumo_temp
wandb
.lumorc.json
-docs/
\ No newline at end of file
+.lumorc.public.json
+docs/
diff --git a/README.md b/README.md
index 13fa0003..8d3e8517 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,7 @@
[![PyPI version](https://badge.fury.io/py/lumo.svg)](https://badge.fury.io/py/lumo)
![Python-Test](https://github.com/pytorch-lumo/lumo/actions/workflows/python-test.yml/badge.svg)
[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lightning/blob/master/LICENSE)
+![Python-doc](./images/docstr_coverage_badge.svg)
`lumo` is a lightweight library that helps you structure your experiment code and record your experiment results, especially in the field of deep learning.
diff --git a/examples/more/screen_str.py b/examples/more/screen_str.py
deleted file mode 100644
index e3f09e22..00000000
--- a/examples/more/screen_str.py
+++ /dev/null
@@ -1,20 +0,0 @@
-"""
-
-"""
-
-import sys
-sys.path.insert(0,"../../")
-import time
-
-from lumo.utils import ScreenStr
-s = "\rLong Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text"
-print(ScreenStr(s,leftoffset=10),end="")
-for i in range(100):
- time.sleep(0.2)
-
-
-from lumo.contrib.data import AutoCollate
-from torch.utils.data.dataloader import DataLoader
-import torch
-device = torch.device('cuda:1')
-DataLoader(...,collate_fn=AutoCollate(device))
diff --git a/examples/quick_start.py b/examples/quick_start.py
index c77e66d9..3a751cd4 100644
--- a/examples/quick_start.py
+++ b/examples/quick_start.py
@@ -18,10 +18,10 @@ def add(x):
db = (
DatasetBuilder()
- .add_input("xs", torch.arange(-2500, 2500, dtype=torch.float).unsqueeze(1))
- .add_input("ys", torch.arange(-2500, 2500, dtype=torch.float), transform=add)
- .add_output("xs", "xs")
- .add_output("ys", "ys")
+ .add_input("xs", torch.arange(-2500, 2500, dtype=torch.float).unsqueeze(1))
+ .add_input("ys", torch.arange(-2500, 2500, dtype=torch.float), transform=add)
+ .add_output("xs", "xs")
+ .add_output("ys", "ys")
)
loader = db.DataLoader(batch_size=params.batch_size, shuffle=True)
@@ -50,12 +50,11 @@ def add(x):
meter.mean.loss = loss
meter.sum.c = 1
record.record(meter)
- logger.inline(record.avg())
+ logger.inline(record.agg())
logger.newline()
test_data = -torch.arange(1000, dtype=torch.float).unsqueeze(1)
res = model(test_data)
-
test_ys = test_data * 2
print((res.long() - test_ys.long()).sum())
diff --git a/images/docstr_coverage_badge.svg b/images/docstr_coverage_badge.svg
new file mode 100644
index 00000000..ede19575
--- /dev/null
+++ b/images/docstr_coverage_badge.svg
@@ -0,0 +1,20 @@
+
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 0322e2f1..e0133509 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ omit = [
'src/lumo/cli/*',
'src/lumo/vis/*',
'src/lumo/decorators/*',
+ 'src/lumo/exp/agent.py',
'src/lumo/analyse/*',
'src/lumo/sketch/*',
'src/lumo/core/record_backend/*',
@@ -73,6 +74,24 @@ exclude_lines = [
"if 0:",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
+    # some tricky objects that cannot be tested well.
+    "except ImportError",
+    "pass",
+    "return None",
+    "break",
+
+ # Hard to test
+ "class RecordAbort",
+ "class GitCommit",
+ "def summary_experiment",
+ "def plot",
+ "if torch.cuda.is_available()", # ignore cuda
+ "if is_dist():", # ignore distribution
+
+    # Deprecated methods:
+ "def add_input_transform",
+ "def add_output_transform",
+ "raise StopIteration"
]
[tool.coverage.html]
diff --git a/requirements.txt b/requirements.txt
index 2329b936..d7286f3a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,4 +14,5 @@ dbrecord
packaging
pandas
hydra-core
-tensorboard
\ No newline at end of file
+tensorboard
+filelock
\ No newline at end of file
diff --git a/setup.py b/setup.py
index f8f61652..a0d9f4c6 100644
--- a/setup.py
+++ b/setup.py
@@ -30,5 +30,8 @@ def extract_version():
keywords='lumo',
packages=find_packages('src'),
entry_points={
+ 'console_scripts': [
+ 'lumo = lumo.cli:main',
+ ]
},
)
diff --git a/src/lumo/__init__.py b/src/lumo/__init__.py
index 957eb07f..5893105d 100644
--- a/src/lumo/__init__.py
+++ b/src/lumo/__init__.py
@@ -1,7 +1,7 @@
"""
"""
-__version__ = "0.14.6"
+__version__ = "0.15.0"
from .core import Params, ParamsType, MetricType, Meter, Record, TrainStage, BaseParams
diff --git a/src/lumo/analyse/collect.py b/src/lumo/analyse/collect.py
index 24193827..7cd89207 100644
--- a/src/lumo/analyse/collect.py
+++ b/src/lumo/analyse/collect.py
@@ -68,14 +68,14 @@ def flatten_metric(df, *keys: str):
def collect_table_rows(metric_root=None) -> pd.DataFrame:
"""Collect all table_row into a pandas.DataFrame"""
- res = []
+ res = {}
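+    # keyed by test_name so newly collected rows overwrite stale duplicates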
logger = Logger()
exp_map = list_all_metrics(metric_root)
for k, rows in exp_map.items():
# append existing row metrics
global_dic = PDict(os.path.join(metricroot(), f'{k}.dict.sqlite'))
- for row in global_dic.values():
- res.append(row)
+ for test_name, row in global_dic.items():
+ res[test_name] = row
if len(rows) == 0:
continue
@@ -94,10 +94,10 @@ def collect_table_rows(metric_root=None) -> pd.DataFrame:
continue
global_dic[test_name] = row
shutil.move(row_fn, os.path.join(os.path.dirname(row_fn), f'.{test_name}.pkl'))
- res.append(row)
+ res[test_name] = row
global_dic.flush()
- return pd.DataFrame(res)
+ return pd.DataFrame(list(res.values()))
def replac(df: pd.DataFrame):
diff --git a/src/lumo/analyse/condition.py b/src/lumo/analyse/condition.py
index 10d21d45..3f9a2292 100644
--- a/src/lumo/analyse/condition.py
+++ b/src/lumo/analyse/condition.py
@@ -7,14 +7,17 @@
def in_(ser, value):
+ """pandas operation"""
return ser.apply(lambda x: x in value)
def not_in_(ser, value):
+ """pandas operation"""
return ser.apply(lambda x: x not in value)
def first(ser, value):
+ """pandas operation"""
return ser.duplicated(value) == False
@@ -101,16 +104,19 @@ def __repr__(self):
return f'{self.name} {self.op} {self.value}'
def in_(self, lis):
+ """condition of `in` operation"""
self.op = 'in'
self.value = set(lis)
return self
def not_in_(self, lis):
+ """condition of `.duplicated(value) == False` operation"""
self.op = 'notin'
self.value = set(lis)
return self
def first(self, value):
+ """condition of `not in` operation"""
self.op = 'first'
self.value = value
return self
@@ -121,7 +127,8 @@ def first(self, value):
def filter_by_condition(df: DataFrame, *condition: Compare) -> DataFrame:
"""
-        A simplified pipeline, only for equality/inequality filtering plus extra in/not_in support
+        A simplified pipeline for equality/inequality filtering, with extra support for in/not_in:
+
+            fdf = filter_by_condition(df, C['condition'] == True, C['c2'] > 1, ...)
+
Args:
        df: pandas.DataFrame instance
*condition: list of `~Compare` instance
diff --git a/src/lumo/cli/__init__.py b/src/lumo/cli/__init__.py
new file mode 100644
index 00000000..714545ea
--- /dev/null
+++ b/src/lumo/cli/__init__.py
@@ -0,0 +1,56 @@
+import fire
+
+
+def rerun(test_name, **kwarg):
+ """
+    Rerun a test by its name.
+
+    Usage:
+        lumo rerun <test_name>
+        lumo rerun <test_name> --device=0
+
+    Args:
+        test_name: name of the test to rerun.
+        **kwarg: extra arguments forwarded to the new run as `--key=value` flags.
+ """
+ from lumo.exp.finder import retrieval_experiment
+ exp = retrieval_experiment(test_name)
+ if exp is not None:
+ exp.rerun([f'--{k}={v}' for k, v in kwarg.items()])
+ else:
+ exit(1)
+
+
+def note(test_name, description):
+ """
+    Add a note to a test:
+        lumo note <test_name> <description>
+
+    Args:
+        test_name: name of the test to annotate.
+        description: the note text to attach.
+ """
+ print(f"Adding note '{description}' to {test_name}")
+
+
+def server(port=8080):
+ """
+
+ Args:
+ port:
+
+ Returns:
+
+ """
+ print(f"Starting server on port {port}")
+
+
+def main():
+ fire.Fire({
+ 'rerun': rerun,
+ 'note': note,
+ 'server': server,
+ })
diff --git a/src/lumo/cli/__main__.py b/src/lumo/cli/__main__.py
new file mode 100644
index 00000000..b678b238
--- /dev/null
+++ b/src/lumo/cli/__main__.py
@@ -0,0 +1,5 @@
+import fire
+from lumo.cli import main
+
+if __name__ == '__main__':
+ main()
diff --git a/src/lumo/contrib/module/memoty_bank.py b/src/lumo/contrib/module/memoty_bank.py
index f3da0c58..d080c87e 100644
--- a/src/lumo/contrib/module/memoty_bank.py
+++ b/src/lumo/contrib/module/memoty_bank.py
@@ -1,25 +1,56 @@
-import torch
from accelerate.utils import gather
+from lumo.proc.dist import is_dist
+import torch.distributed
from torch import nn
+import torch
+
+
+class StorageBank(nn.Module):
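+    """A bank of fixed-size buffers written in place by index: `scatter` gathers
+    values (and indices) across processes, then writes them at the given positions."""
+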
+ def __init__(self):
+ super().__init__()
+ self.offset = 0
+ self.sizes = {}
+
+ def register(self, name, dim, k, dtype=None):
+ if dim <= 0:
+ bank = (torch.ones(k) + float('inf')).to(dtype=dtype)
+ else:
+ bank = (torch.ones(k, dim) + float('inf')).to(dtype=dtype)
+
+ self.register_buffer(name, bank)
+ self.sizes[name] = bank.shape[0]
+
+ def __getitem__(self, item):
+ return self.__getattr__(item)
+
+ @torch.no_grad()
+ def scatter(self, name, value, index):
+ value = value.detach()
+ value = gather(value)
+ if isinstance(index, torch.Tensor):
+ index = gather(index)
+ self[name][index] = value
class MemoryBank(nn.Module):
- def __init__(self, k):
+ def __init__(self):
super().__init__()
self.offset = 0
- self.k = k
self.offsets = {}
self.sizes = {}
- def register(self, name, dim):
+ def register(self, name, dim, k, dtype=None):
if dim <= 0:
- bank = torch.rand(self.k)
+ bank = torch.rand(k, dtype=dtype)
else:
- bank = torch.rand(self.k, dim)
+ bank = torch.rand(k, dim, dtype=dtype)
self.register_buffer(name, bank)
self.offsets[name] = 0
- self.sizes[name] = bank.shape[0]
+ self.sizes[name] = k
+
+ def __setitem__(self, key, value):
+ self.__setattr__(key, value)
def __getitem__(self, item):
return self.__getattr__(item)
@@ -39,3 +70,56 @@ def push(self, name, value):
value = value[:batch_size]
self[name][ptr:ptr + batch_size] = value
self.offsets[name] = (ptr + batch_size) % k
+
+
+@torch.no_grad()
+def batch_shuffle_ddp(x):
+ """
+ Batch shuffle, for making use of BatchNorm.
+ *** Only support DistributedDataParallel (DDP) model. ***
+ """
+ if not is_dist():
+ return x, torch.arange(len(x))
+ # gather from all gpus
+ batch_size_this = x.shape[0]
+ x_gather = gather(x)
+ batch_size_all = x_gather.shape[0]
+
+ num_gpus = batch_size_all // batch_size_this
+
+ # random shuffle index
+ idx_shuffle = torch.randperm(batch_size_all).cuda()
+
+ # broadcast to all gpus
+ torch.distributed.broadcast(idx_shuffle, src=0)
+
+ # index for restoring
+ idx_unshuffle = torch.argsort(idx_shuffle)
+
+ # shuffled index for this gpu
+ gpu_idx = torch.distributed.get_rank()
+ idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
+
+ return x_gather[idx_this], idx_unshuffle
+
+
+@torch.no_grad()
+def batch_unshuffle_ddp(x, idx_unshuffle):
+ """
+ Undo batch shuffle.
+ *** Only support DistributedDataParallel (DDP) model. ***
+ """
+ if not is_dist():
+ return x
+ # gather from all gpus
+ batch_size_this = x.shape[0]
+ x_gather = gather(x)
+ batch_size_all = x_gather.shape[0]
+
+ num_gpus = batch_size_all // batch_size_this
+
+ # restored index for this gpu
+ gpu_idx = torch.distributed.get_rank()
+ idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
+
+ return x_gather[idx_this]
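+
+
+# A minimal usage sketch (MoCo-style; `encoder_k` is a hypothetical key encoder):
+# shuffle the batch before the key encoder so per-GPU BatchNorm statistics are
+# not trivially aligned, then restore the original order afterwards.
+#
+#   im_k, idx_unshuffle = batch_shuffle_ddp(im_k)
+#   k = encoder_k(im_k)
+#   k = batch_unshuffle_ddp(k, idx_unshuffle)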
diff --git a/src/lumo/core/README.md b/src/lumo/core/README.md
deleted file mode 100644
index 9e42711c..00000000
--- a/src/lumo/core/README.md
+++ /dev/null
@@ -1,10 +0,0 @@
-
- - Data structures
-  - Meter, metric recording
-  - Params, parameter management
-  - Scheduler
-  - Logger, log management
- - Global variables, proc
- - Thread management
- - Decorator management
- - device management
\ No newline at end of file
diff --git a/src/lumo/core/attr.py b/src/lumo/core/attr.py
index 3c0bf610..558e90c5 100644
--- a/src/lumo/core/attr.py
+++ b/src/lumo/core/attr.py
@@ -6,6 +6,12 @@
class Attr(OrderedDict):
+ """
+ A subclass of OrderedDict that allows you to access its elements via dot notation.
+
+ This class overrides the __setattr__, __setitem__, __getattr__, and __getitem__ methods to provide the
+ dot notation functionality.
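+
+    Example (a minimal sketch; keys are split on '.'):
+
+        >>> a = Attr()
+        >>> a['b.c'] = 1   # nested Attr created on demand
+        >>> a.b.c
+        1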
+ """
def __setattr__(self, key: str, value):
set_item_iterative(self, key.split('.'), value)
@@ -30,22 +36,57 @@ def __getitem__(self, key):
return get_item_iterative(self, key.split('.'))
-def safe_update_dict(src: dict, kwargs: dict, assert_type=True):
+def safe_update_dict(src: dict, kwargs: dict, assert_type=False):
+ """
+ Updates the source dictionary with the key-value pairs from the kwargs dictionary in a safe manner.
+
+ This function iterates over the items in the kwargs dictionary and updates the corresponding items in the
+ source dictionary, making sure that the types of the values being updated match the types of the values
+ already in the source dictionary.
+
+ Args:
+ src (dict): The dictionary to update.
+ kwargs (dict): The dictionary containing the new key-value pairs to add to the source dictionary.
+ assert_type (bool): A flag indicating whether to check that the types of the values being updated match
+            the types of the values already in the source dictionary. Defaults to False.
+
+ Returns:
+ dict: The updated source dictionary.
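+
+    Example:
+        A minimal sketch (`Attr` is the OrderedDict subclass defined above):
+
+        >>> src = Attr()
+        >>> src['opt.name'] = 'sgd'
+        >>> safe_update_dict(src, {'opt': {'name': 'adam'}})['opt.name']
+        'adam'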
+ """
for ks, v in walk_dict(kwargs):
try:
old_v = get_item_iterative(src, ks)
- if old_v is None or isinstance(old_v, type(v)):
+ if old_v is None or isinstance(old_v, type(v)) or not assert_type:
set_item_iterative(src, ks, v)
- # print(ks, v)
else:
raise TypeError(ks, type(old_v), type(v))
except KeyError:
set_item_iterative(src, ks, v)
- # print(ks, v)
return src
def walk_dict(dic: dict, root=None):
+ """
+ Recursively walks through a dictionary and yields keys and values in a flattened format.
+
+ Args:
+ - dic (dict): The dictionary to be walked through.
+ - root (list): The root keys to be used in the resulting flattened format. Defaults to None.
+
+ Yields:
+ - A tuple containing a list of keys and a value. The list of keys is composed of the root keys and the current keys in the dictionary, split by '.' if there are any. The value is the corresponding value in the dictionary.
+
+ Example:
+ ```python
+ d = {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}
+ for k, v in walk_dict(d):
+ print(k, v)
+ # Output:
+ # (['a', 'b'], 1)
+ # (['a', 'c', 'd'], 2)
+ # (['e'], 3)
+ ```
+ """
if root is None:
root = []
for k, v in dic.items():
@@ -56,30 +97,56 @@ def walk_dict(dic: dict, root=None):
def set_item_iterative(dic: dict, keys: List[str], value):
+ """
+ Sets the value of a nested key in a dictionary using an iterative approach.
+
+ Args:
+ dic (dict): The dictionary to update.
+ keys (List[str]): A list of keys representing the path to the nested key in the dictionary.
+ value: The value to set for the nested key.
+
+ Raises:
+ ValueError: If a key in the path exists in the dictionary but the corresponding value is not a dictionary.
+
+ """
if len(keys) == 1:
if isinstance(value, dict):
for ks, v in walk_dict(value):
set_item_iterative(dic, [*keys, *ks], v)
else:
- dict.__setitem__(dic, keys[0], value)
+ OrderedDict.__setitem__(dic, keys[0], value)
else:
try:
- nex = dict.__getitem__(dic, keys[0])
+ nex = OrderedDict.__getitem__(dic, keys[0])
if not isinstance(nex, dict):
raise ValueError(keys[0], nex)
# dict.__setitem__(dic, keys[0], nex)
except KeyError:
- nex = dict()
- dict.__setitem__(dic, keys[0], nex)
+ nex = Attr()
+ OrderedDict.__setitem__(dic, keys[0], nex)
set_item_iterative(nex, keys[1:], value)
def get_item_iterative(dic: dict, keys: List[str]):
+ """
+ Gets the value of a nested key in a dictionary using an iterative approach.
+
+ Args:
+ dic (dict): The dictionary to retrieve the value from.
+ keys (List[str]): A list of keys representing the path to the nested key in the dictionary.
+
+ Raises:
+ KeyError: If the nested key does not exist in the dictionary.
+
+ Returns:
+ The value of the nested key in the dictionary.
+
+ """
if len(keys) == 1:
- return dict.__getitem__(dic, keys[0])
+ return OrderedDict.__getitem__(dic, keys[0])
else:
- nex = dict.__getitem__(dic, keys[0])
+ nex = OrderedDict.__getitem__(dic, keys[0])
if isinstance(nex, dict):
return get_item_iterative(nex, keys[1:])
else:
diff --git a/src/lumo/core/disk.py b/src/lumo/core/disk.py
index 5f675d7c..baf19428 100644
--- a/src/lumo/core/disk.py
+++ b/src/lumo/core/disk.py
@@ -1,22 +1,53 @@
import os.path
+import warnings
+
from dbrecord import PList
+
from lumo.proc import path
-from lumo.utils.filelock2 import Lock
from lumo.utils import safe_io as IO
+from lumo.decorators.deprecated import DeprecatedWarning
class Metrics:
"""
- Record metrics at multiple steps. Supported by dbrecord.
+    Records metrics at multiple steps and stages, stored via dbrecord.
+
+ Args:
+        test_path (str): The path to the test directory.
+        persistent (bool): whether `flush` actually writes changes to disk. Defaults to True.
+
+ Attributes:
+ fpath (str): The path to the metric board SQLite file.
+ disk (PList): The PList instance for accessing the metric board SQLite file.
+
+ Methods:
+ append(metric: dict, step: int, stage: str = 'train') -> None:
+ Adds the specified metric, step, and stage to the metric board SQLite file.
+ The metric is a dictionary object that contains the metric name as the key and the metric value as the value.
+ The stage is either 'train' or 'test', and it is set to 'train' by default.
+ This method calls the `flush` method to write the changes to disk.
+
+ flush() -> None:
+ Writes any changes to the metric board SQLite file to disk.
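+
+    Example:
+        A minimal usage sketch:
+
+            metrics = Metrics('path/to/test_dir')
+            metrics.append({'loss': 0.3, 'acc': 0.9}, step=10, stage='train')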
"""
- def __init__(self, test_path: str):
+ def __init__(self, test_path: str, persistent=True):
os.makedirs(test_path, exist_ok=True)
self.fpath = os.path.join(test_path, f'metric_board.sqlite')
self.disk = PList(self.fpath)
- self.lock = Lock(os.path.basename(test_path.rstrip('/')))
+ self.persistent = persistent
def append(self, metric: dict, step, stage='train'):
+ """
+ Adds the specified metric, step, and stage to the metric board SQLite file.
+
+ Args:
+ metric (dict): A dictionary object that contains the metric name as the key and the metric value as the value.
+ step (int): The step number.
+ stage (str, optional): The stage of the metric, either 'train' or 'test'. Defaults to 'train'.
+
+ Returns:
+ None
+ """
self.disk.append({
'metric': metric,
'step': step,
@@ -25,35 +56,84 @@ def append(self, metric: dict, step, stage='train'):
self.disk.flush()
def flush(self):
- self.disk.flush()
+ """
+ Writes any changes to the metric board SQLite file to disk.
+
+ Returns:
+ None
+ """
+ if self.persistent:
+ self.disk.flush()
class TableRow:
"""
- It can be regarded as a serialized dictionary,
- or a certain row in the table, so the same key value will be overwritten.
-
- If you need to record records at different times, please use trainer.metrics
+    TableRow is a serialized dictionary that represents a single row in a table.
+    If the same key is updated, its value will be overwritten.
+    To record metrics at different times, use trainer.metrics instead.
+
+    Args:
+        - fn (str): path of the file that stores the serialized row.
+        - persistent (bool): whether to write the row to disk on flush.
+
+    Attributes:
+        - fpath (str): path of the file that stores the serialized row.
+        - value (dict): dictionary representing the row.
+        - persistent (bool): whether to store in disk.
+
+    Methods:
+        - __enter__(self): context manager method. Does nothing.
+        - __exit__(self, exc_type, exc_val, exc_tb): context manager method. Calls flush method.
+        - flush(self): writes the value of the row to a file.
+        - update_metrics(self, dic: dict, compare=None, flush=False): updates multiple metrics in the row.
+        - update_metric(self, key, value, compare=None, flush=False): updates a single metric in the row.
+        - metric(self): returns the metric dictionary of the row.
+        - update_metric_pair(self, key, value, key2, value2, compare=None, flush=False): updates two metrics in the row.
+        - __getitem__(self, item): returns the value of a key in the row.
"""
- def __init__(self, table, partition, rowkey):
- dirpath = os.path.join(path.metricroot(), table)
- os.makedirs(dirpath, exist_ok=True)
- self.fpath = os.path.join(dirpath, partition, f'{rowkey}.pkl')
- self.key = rowkey
+ def __init__(self, fn, persistent=True):
+ os.makedirs(os.path.dirname(os.path.abspath(fn)), exist_ok=True)
+ self.fpath = fn
self.value = {}
+ self.persistent = persistent
+
# self.disk = PDict(self.fpath)
def __enter__(self):
+ """
+ Does nothing.
+ """
pass
def __exit__(self, exc_type, exc_val, exc_tb):
+ """
+ Calls flush method. Required for using the object as a context manager.
+ """
self.flush()
def flush(self):
- IO.dump_pkl(self.value, self.fpath)
+ """Writes the value of the row to a file."""
+ if self.persistent:
+ IO.dump_pkl(self.value, self.fpath)
def update_metrics(self, dic: dict, compare=None, flush=False):
+ """
+ Updates multiple metrics in the row.
+
+ Args:
+ - dic (dict): dictionary containing key-value pairs to be updated.
+ - compare (str): comparison operator to be used for updating metrics. Only 'max' and 'min' are supported.
+ - flush (bool): if True, writes the value of the row to a file after updating the metrics.
+
+ Returns:
+ - res (dict): dictionary containing key-value pairs that were updated.
+ """
res = {}
[res.update(self.update_metric(k, v, compare)) for k, v in dic.items()]
@@ -62,6 +142,19 @@ def update_metrics(self, dic: dict, compare=None, flush=False):
return res
def update_metric(self, key, value, compare=None, flush=False):
+ """
+ Updates a metric value in the row.
+
+ Args:
+ key (str): The key of the metric.
+ value (float): The value of the metric.
+ compare (str, optional): The comparison operator used to compare the new value with the old one.
+ Either 'max' or 'min'. Default is None.
+ flush (bool, optional): Whether to flush the changes to disk. Default is False.
+
+ Returns:
+ dict: A dictionary containing the updated metric key and value.
+ """
dic = self.metric
older = dic.setdefault(key, None)
@@ -89,9 +182,30 @@ def update_metric(self, key, value, compare=None, flush=False):
@property
def metric(self):
+ """
+ A property that returns the metric values of the row.
+
+ Returns:
+ dict: A dictionary containing the metric values of the row.
+ """
return self.value.setdefault('metric', {})
def update_metric_pair(self, key, value, key2, value2, compare=None, flush=False):
+ """
+ Update a pair of key-value metrics in the metric dictionary.
+
+ Args:
+ key (str): The key of the first metric.
+ value (float): The value of the first metric.
+ key2 (str): The key of the second metric.
+ value2 (float): The value of the second metric.
+ compare (str, optional): The method to compare values. Default is None.
+ Possible values are 'max', 'min'.
+ flush (bool, optional): Whether to flush to disk after updating. Default is False.
+
+ Returns:
+ dict: A dictionary with the old values of the updated metrics.
+ """
dic = self.metric
old = dic.setdefault(key, None)
old2 = dic.setdefault(key2, None)
@@ -119,21 +233,11 @@ def update_metric_pair(self, key, value, key2, value2, compare=None, flush=False
return {key: old, key2: old2}
- def set_params(self, params: dict):
- self.value['params'] = params
- self.flush()
- return params
-
- def update_dict(self, dic: dict, flush=False):
- for k, v in dic.items():
- self.update(k, v)
- if flush:
- self.flush()
-
- def update(self, key, value, flush=True):
- self.value[key] = value
- if flush:
- self.flush()
-
def __getitem__(self, item):
+ """Get the value of a key in the row."""
return self.value[item]
+
+
+DeprecatedWarning(TableRow, '0.15.0', '1.0.0',
+                  'This class is deprecated and will be removed in 1.0.0. '
+ 'Please use Experiment.metric to record your best metric.')
diff --git a/src/lumo/core/enums.py b/src/lumo/core/enums.py
index 3c931c90..aa6dabb1 100644
--- a/src/lumo/core/enums.py
+++ b/src/lumo/core/enums.py
@@ -2,22 +2,49 @@
class TrainStage(enum.Enum):
+ """An enumeration class representing the different stages of training.
+ """
default = 'default'
train = 'train'
test = 'test'
val = 'val'
def is_train(self):
+ """Check if the current stage is the training stage.
+
+ Returns:
+ bool: True if the current stage is the training stage, False otherwise.
+ """
return self.value == 'train'
def is_test(self):
+ """Check if the current stage is the testing stage.
+
+ Returns:
+ bool: True if the current stage is the testing stage, False otherwise.
+ """
return self.value == 'test'
def is_val(self):
+ """Check if the current stage is the validation stage.
+
+ Returns:
+ bool: True if the current stage is the validation stage, False otherwise.
+ """
return self.value == 'val'
@staticmethod
def create_from_str(value):
+ """Create a TrainStage instance from a string.
+
+ If the value is 'eval' or 'evaluate', it will be converted to 'val'.
+
+ Args:
+ value (str): A string representing the stage of training.
+
+ Returns:
+ TrainStage: A TrainStage instance representing the stage of training.
+ """
if value in {'eval', 'evaluate'}:
value = 'val'
return TrainStage(value)
diff --git a/src/lumo/core/interp.py b/src/lumo/core/interp.py
index 100f4454..ad7a2c1a 100644
--- a/src/lumo/core/interp.py
+++ b/src/lumo/core/interp.py
@@ -36,46 +36,55 @@
'PeriodTriangle',
'PeriodLinear',
'PowerDecay',
+ 'PowerDecay2',
'InterpolateList', ]
class Interpolate(BaseParams):
- """
-    ratio changes from 1 to 0
- """
-
- def toggle_constant(self, toggle=True):
- """fix the schedule as the first value"""
- self.constant = toggle
- return self
+ """A class for implementing interpolation schedule of a learning rate."""
@classmethod
def interp(self, *args, **kwargs):
+ """Interpolation method for the schedule. Must be implemented in a subclass."""
raise NotImplementedError()
def __repr__(self):
+ """Return a string representation of the schedule."""
content = ', '.join(["{}={}".format(k, v) for k, v in self.items()])
return "{}({})".format(self.__class__.__name__, content)
def __call__(self, cur):
+ """Return the learning rate at the given step 'cur'."""
raise NotImplementedError()
def get(self, key: DictKeyType, default_value: Any = None) -> Any:
+ """
+ Return the value of the key from the schedule's dictionary, or default_value if the key is not present.
+
+ Args:
+ key: The key of the dictionary.
+ default_value: The default value to return if the key is not found.
+
+ Returns:
+ The value of the key, or default_value if the key is not found.
+ """
return self(key)
def plot(self, num=1000, left=0, right=1000, show=True):
"""
Plot a curve of the schedule.
+
Args:
- num:
- left:
- right:
+ num: The number of points to plot the curve.
+ left: The starting point of the curve.
+ right: The ending point of the curve.
+ show: Whether to display the plot or not.
Returns:
- plt.plot
+ The plot object.
Notes:
- You may need to call `plt.show`() to show it.
+ You may need to call `plt.show()` to show the plot.
"""
from matplotlib import pyplot as plt
@@ -90,21 +99,22 @@ def plot(self, num=1000, left=0, right=1000, show=True):
def scale(self, optimizer, cur):
"""
- Scale the learning rate by current value. 'scale' means not apply the current schedule value directly to
- the learning rate, but multiple the initial learning rate. You can use `schedule.apply()` to apply the schedule
- value directly.
+ Scale the learning rate by the current value.
+
+ 'Scale' means that the current schedule value will not be applied directly to the learning rate, but will be
+ multiplied by the initial learning rate. You can use `schedule.apply()` to apply the schedule value directly.
Notes:
- -------
- When first apply scale function, a `_raw_lr` represent initial lr will be set in each param_group, then, the
- learning rate(store in param_groups with key 'lr') will be calculated by `_raw_lr * schedule(cur)`.
+ When `scale()` is first called, an initial learning rate `_raw_lr` is stored in each `param_group`. Then,
+ the learning rate (stored in `param_groups` with the key 'lr') will be calculated as `_raw_lr *
+ schedule(cur)`.
Args:
- optimizer: A pytorch optimizer instance.
- cur: current step of this schedule.
+ optimizer: A PyTorch optimizer instance.
+ cur: The current step of the schedule.
Returns:
- Current schedule value.
+ The current schedule value.
"""
ratio = self(cur)
for param_group in optimizer.param_groups: # type:dict
@@ -115,14 +125,14 @@ def scale(self, optimizer, cur):
def apply(self, optimizer, cur):
"""
- Apply the learning rate with current schedule value.
+ Apply the learning rate with the current schedule value.
Args:
- optimizer: A pytorch optimizer instance.
- cur: current step of this schedule.
+ optimizer: A PyTorch optimizer instance.
+ cur: The current step of the schedule.
Returns:
-
+ The new learning rate.
"""
new_lr = self(cur)
for param_group in optimizer.param_groups: # type:dict
@@ -132,14 +142,44 @@ def apply(self, optimizer, cur):
class ABCContinuous(Interpolate):
+ """
+ Interpolates a continuous schedule for a value between a start and end point.
+
+ Args:
+ start (float): The starting value of the schedule.
+ end (float): The ending value of the schedule.
+ left (float): The left boundary of the range of values to interpolate over.
+ right (float): The right boundary of the range of values to interpolate over.
+ *args: Additional arguments to pass to the superclass constructor.
+ **kwargs: Additional keyword arguments to pass to the superclass constructor.
+
+ Attributes:
+ left (float): The left boundary of the range of values to interpolate over.
+ right (float): The right boundary of the range of values to interpolate over.
+ start (float): The starting value of the schedule.
+ end (float): The ending value of the schedule.
+ constant (bool): A flag indicating whether the schedule is constant.
+ """
def __call__(self, cur):
+ """Returns the interpolated value of the schedule at a given point."""
if self.constant:
return self.start
return self.interp(cur, start=self.start, end=self.end, left=self.left, right=self.right,
constant=self.constant)
def __init__(self, start=1e-3, end=1e-6, left=0, right=80, *args, **kwargs):
+ """
+ Initializes an instance of ABCContinuous.
+
+ Args:
+ start (float): The starting value of the schedule.
+ end (float): The ending value of the schedule.
+ left (float): The left boundary of the range of values to interpolate over.
+ right (float): The right boundary of the range of values to interpolate over.
+ *args: Additional arguments to pass to the superclass constructor.
+ **kwargs: Additional keyword arguments to pass to the superclass constructor.
+ """
super().__init__()
self.left = left
self.right = right
@@ -149,6 +189,7 @@ def __init__(self, start=1e-3, end=1e-6, left=0, right=80, *args, **kwargs):
@classmethod
def ratio(cls, cur, left, right, constant=False):
+ """Returns the ratio of a given point between the left and right boundaries."""
if constant:
return 0
return (cur - left) / (right - left)
@@ -159,6 +200,7 @@ def get_val(cls, cur, start=1e-3, end=1e-6, left=0, right=80, *args, **kwargs):
return cls.interp(cur=cur, start=start, end=end, left=left, right=right, *args, **kwargs)
def plot(self, num=1000, left=None, right=None, show=True):
+ """Plots the interpolated schedule."""
if left is None:
left = self.left
@@ -173,10 +215,30 @@ def plot(self, num=1000, left=None, right=None, show=True):
class ABCPeriod(Interpolate):
"""
- period
+ A class for generating schedules with a repeating period.
+
+ Attributes:
+ left (float): The left boundary of the schedule.
+ period (float): The period of the schedule.
+ start (float): The start value of the schedule.
+ end (float): The end value of the schedule.
+ constant (bool): A flag indicating if the schedule is constant.
+
"""
def __init__(self, start=0, end=1, left=0, period=1, *args, **kwargs):
+ """
+ Initializes an instance of the `ABCPeriod` class.
+
+ Args:
+ start (float): The start value of the schedule.
+ end (float): The end value of the schedule.
+ left (float): The left boundary of the schedule.
+ period (float): The period of the schedule.
+ *args: Additional positional arguments.
+ **kwargs: Additional keyword arguments.
+
+ """
super().__init__()
self.left = left
self.period = period
@@ -186,6 +248,7 @@ def __init__(self, start=0, end=1, left=0, period=1, *args, **kwargs):
@classmethod
def ratio(cls, cur, left, period, constant=False):
+ """Returns the ratio of time elapsed in the current period."""
if constant:
return 0
if cur < left:
@@ -201,6 +264,7 @@ def get_val(cls, cur, start=0, end=1, left=0, right=1, *args, **kwargs):
left=left, right=right, *args, **kwargs)
def plot(self, num=1000, left=None, n_period=5, show=True):
+ """Plots the schedule between the specified boundaries."""
if left is None:
left = self.left
@@ -211,15 +275,25 @@ def plot(self, num=1000, left=None, n_period=5, show=True):
return super().plot(num, left, right, show)
def __call__(self, cur):
+ """Returns the current schedule value at the given time."""
return self.interp(cur, start=self.start, end=self.end, left=self.left, period=self.period,
constant=self.constant)
class Cos(ABCContinuous):
- """one cycle cosine functoin"""
+ """one cycle cosine functoin
+
+ end -> ,---------
+ /
+ start -> ______/
+ ↑ ↑
+ left right
+
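+    Example (a minimal sketch; `optimizer` is any torch optimizer):
+
+        sched = Cos(start=1.0, end=0.1, left=0, right=100)
+        sched(0)     # -> 1.0
+        sched(100)   # -> 0.1
+        sched.scale(optimizer, step)  # lr = _raw_lr * sched(step)
+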
+ """
@classmethod
def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs):
+ """Interpolation method for the schedule."""
constant = kwargs.get('constant', False)
if constant:
return start
@@ -234,12 +308,21 @@ def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs):
return start * cos_ratio + end * (1 - cos_ratio)
-
class Linear(ABCContinuous):
- """linear schedule"""
+ """linear schedule
+
+ ^
+ end | .*
+ | .*
+ | .*
+ |.*
+ start +----------------->
+ left right
+ """
@classmethod
def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs):
+ """Interpolation method for the schedule."""
constant = kwargs.get('constant', False)
if constant:
return start
@@ -258,6 +341,7 @@ class Exp(ABCContinuous):
@classmethod
def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs):
+ """Interpolation method for the schedule."""
constant = kwargs.get('constant', False)
if constant:
return start
@@ -275,10 +359,24 @@ def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs):
class Log(ABCContinuous):
- """quick to slow"""
+ """
+ quick to slow
+
+ end | *
+ | *
+ | *
+ | *
+ | *
+ | *
+ | *
+ start |*
+ -------------------------------------------
+ left right
+ """
@classmethod
def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs):
+ """Interpolation method for the schedule."""
constant = kwargs.get('constant', False)
if constant:
return start
@@ -297,6 +395,15 @@ def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs):
class Constant(ABCContinuous):
+ """
+ A scheduler representing a constant value
+ |
+ constant |--------------
+ |
+ |________________
+ ... any ...
+ """
+
def __init__(self, value=0.5, *args, **kwargs):
super().__init__(start=value, end=value, left=0, right=1, *args, **kwargs)
self.constant = True
@@ -305,10 +412,18 @@ def __init__(self, value=0.5, *args, **kwargs):
class PeriodCos(ABCPeriod):
"""
periodic cosine schedule
+
+ end -> ,-. ,-. ,-. ,-.
+ / \ / \ / \ / \
+ start -> ______/ \_/ \_/ \_/ \_________
+ ratio 0 1 2 3 .....
+ \----|
+ period
"""
@classmethod
def interp(cls, cur, start=0., end=1., left=0., period=1., *args, **kwargs):
+ """Interpolation method for the schedule."""
constant = kwargs.get('constant', False)
ratio = cls.ratio(cur, left=left, period=period, constant=constant)
@@ -319,10 +434,18 @@ def interp(cls, cur, start=0., end=1., left=0., period=1., *args, **kwargs):
class PeriodHalfCos(ABCPeriod):
"""
half periodic cosine schedule, period is (right-left)
+
+ end -> ,- ,- ,- ,-
+ / / / /
+ start -> ______/ / / /
+ ratio 0 1 2 3 ...
+ \--|
+ period
"""
@classmethod
def interp(cls, cur, start=0., end=1., left=0., period=1., *args, **kwargs):
+ """interp with period halfcos method"""
constant = kwargs.get('constant', False)
ratio = cls.ratio(cur, left=left, period=period, constant=constant)
cos_ratio = 0.5 * (1 + np.cos(ratio * np.pi))
@@ -330,6 +453,17 @@ def interp(cls, cur, start=0., end=1., left=0., period=1., *args, **kwargs):
class PeriodTriangle(ABCPeriod):
+ """
+    An interp class simulating a periodic triangle waveform.
+
+
+ end /\ /\ /\ /\
+ start / \/ \/ \/ \
+ 0 1 2 3 4
+ \--|
+ period
+ """
+
def __init__(self, start=0, end=1, left=0, left_period=1, right_period=1, *args, **kwargs):
super().__init__(start=start, end=end, left=left,
period=(left_period + right_period),
@@ -340,6 +474,7 @@ def __init__(self, start=0, end=1, left=0, left_period=1, right_period=1, *args,
@classmethod
def interp(cls, cur, start=0., end=1., left=0., left_period=0., right_period=1., *args, **kwargs):
+ """Interpolation method for the schedule."""
constant = kwargs.get('constant', False)
ratio = cls.ratio(cur, left=left, period=(left_period + right_period), constant=constant)
@@ -356,10 +491,17 @@ def interp(cls, cur, start=0., end=1., left=0., left_period=0., right_period=1.,
class PeriodLinear(ABCPeriod):
"""
sawtooth wave, like a period line schedule
+ end / / / /
+ / / / /
+ start / / / /
+ 0 1 2 3 ....
+ \----|
+ period
"""
@classmethod
def interp(cls, cur, start=0., end=1., left=0., period=1., *args, **kwargs):
+ """Interpolation method for the schedule."""
constant = kwargs.get('constant', False)
ratio = cls.ratio(cur, left=left, period=period, constant=constant)
return start * (1 - ratio) + end * ratio
@@ -385,6 +527,8 @@ def __call__(self, cur):
class PowerDecay2(Interpolate):
+ """A class for implementing Power Decay Interpolation for a given schedule."""
+
def __init__(self, start, schedules, gammas):
super().__init__()
self.start = start
@@ -393,6 +537,7 @@ def __init__(self, start, schedules, gammas):
@classmethod
def interp(cls, cur, start=0., gammas=None, schedules=None, *args, **kwargs):
+ """Interpolation method for the schedule."""
if schedules is None:
schedules = []
if gammas is None:
@@ -407,10 +552,12 @@ def interp(cls, cur, start=0., gammas=None, schedules=None, *args, **kwargs):
return res
def __call__(self, cur):
- self.interp(cur, start=self.start, gammas=self.gammas, schedules=self.schedules)
+ return self.interp(cur, start=self.start, gammas=self.gammas, schedules=self.schedules)
class InterpolateList(Interpolate):
+ """Concat different interpolation functions"""
+
def __init__(self, schedules: List[Interpolate]):
super().__init__()
self.schedules = schedules
@@ -441,6 +588,7 @@ def __repr__(self):
return '{}({})'.format(self.__class__.__name__, content)
def plot(self, num=1000, left=None, right=None, show=True):
+ """plot"""
if left is None:
left = self.left
diff --git a/src/lumo/core/meter.py b/src/lumo/core/meter.py
index a8ef163f..44bc21bd 100644
--- a/src/lumo/core/meter.py
+++ b/src/lumo/core/meter.py
@@ -4,7 +4,7 @@
from collections import OrderedDict
from collections.abc import ItemsView
from numbers import Number
-from typing import Union, Iterator, Tuple, Mapping, Sequence
+from typing import Iterator, Tuple, Mapping
import numpy as np
import torch
@@ -13,21 +13,59 @@
class Meter:
+ """
+ A class for recording and managing metrics.
+
+ Attributes:
+ _prop (dict): A dictionary to store properties of the meter.
+ _rec (OrderedDict): An ordered dictionary to record the metrics and their values.
+ _avg (dict): A dictionary to store the aggregation method for each metric.
+
+ Methods:
+ sorted() -> 'Meter': Returns a new meter with the metrics sorted by their names.
+ todict() -> OrderedDict: Returns the recorded metrics as an ordered dictionary.
+ update(dic: Mapping) -> 'Meter': Updates the meter with the given dictionary of metrics.
+ serialize() -> OrderedDict: Returns a dictionary representation of the meter.
+ items() -> ItemsView: Returns a view object containing the (metric, value) pairs.
+ keys() -> KeysView: Returns a view object containing the metric names.
+ scalar_items() -> Iterator[Tuple[str, Number]]: Returns an iterator over the (metric, value) pairs with scalar values.
+
+ Properties:
+ sum: Sets the aggregation method to 'sum'.
+ mean: Sets the aggregation method to 'mean'.
+ last: Sets the aggregation method to 'last'.
+ max: Sets the aggregation method to 'max'.
+ min: Sets the aggregation method to 'min'.
+ smean: Sets the aggregation method to 'smean'.
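+
+    Example:
+        A minimal sketch (mirroring examples/quick_start.py):
+
+            meter = Meter()
+            meter.mean.loss = loss_value  # record `loss`, aggregated by mean
+            meter.sum.c = 1               # record `c`, aggregated by sum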
+ """
+
def __init__(self):
self._prop = {}
- self._rec = {}
+ self._rec = OrderedDict()
self._avg = {}
def sorted(self) -> 'Meter':
+ """
+ Returns a new meter with the metrics sorted by their names.
+
+ Returns:
+ A new meter with the metrics sorted by their names.
+ """
m = Meter()
- m._rec = self._rec
- m._avg = OrderedDict()
+
m._prop = self._prop
for k in sorted(self._avg.keys()):
m._avg[k] = self._avg[k]
+ m._rec[k] = self._rec[k]
return m
def todict(self):
+ """
+ Returns the recorded metrics as an ordered dictionary.
+
+ Returns:
+ An ordered dictionary containing the recorded metrics and their values.
+ """
return self._rec
@property
@@ -39,18 +77,50 @@ def _stage(self, value):
self._prop['stage'] = value
def __setattr__(self, key: str, value):
+ """
+ Sets the value of an attribute.
+
+ Args:
+ key (str): The name of the attribute.
+ value: The value to set the attribute to.
+ """
if key.startswith('_'):
super(Meter, self).__setattr__(key, value)
else:
self[key] = value
def __getattr__(self, item):
+ """
+ Returns the value of a metric.
+
+ Args:
+ item: The name of the metric.
+
+ Returns:
+ The value of the metric.
+ """
return self[item]
def __getitem__(self, item):
+ """
+ Returns the value of a metric.
+
+ Args:
+ item: The name of the metric.
+
+ Returns:
+ The value of the metric.
+ """
return self._rec[item]
def __setitem__(self, key, value):
+ """
+ Sets the value of a metric.
+
+ Args:
+ key: The name of the metric.
+ value: The value to set the metric to.
+ """
value = to_ndarray(value)
stg = self._avg.get(key, None)
@@ -83,66 +153,156 @@ def __setitem__(self, key, value):
self._stage = 'default'
def __repr__(self):
+ """
+ Returns a string representation of the meter.
+
+ Returns:
+ A string representation of the meter.
+ """
return ' | '.join([f'{k}: {v}' for k, v in self._rec.items()])
def __iter__(self):
+ """
+ Returns an iterator over the metric names.
+
+ Returns:
+ An iterator over the metric names.
+ """
yield from self.keys()
@property
def sum(self):
+ """
+ Sets the aggregation method to 'sum'.
+
+ Returns:
+ The meter itself.
+ """
self._stage = 'sum'
return self
@property
def mean(self):
+ """
+ Sets the aggregation method to 'mean'.
+
+ Returns:
+ The meter itself.
+ """
self._stage = 'mean'
return self
@property
def last(self):
+ """
+ Sets the aggregation method to 'last'.
+
+ Returns:
+ The meter itself.
+ """
self._stage = 'last'
return self
@property
def max(self):
+ """
+ Sets the aggregation method to 'max'.
+
+ Returns:
+ The meter itself.
+ """
self._stage = 'max'
return self
@property
def min(self):
+ """
+ Sets the aggregation method to 'min'.
+
+ Returns:
+ The meter itself.
+ """
self._stage = 'min'
return self
@property
def smean(self):
+ """
+ Sets the aggregation method to 'smean'.
+
+ Returns:
+ The meter itself.
+ """
self._stage = 'smean'
return self
def update(self, dic: Mapping) -> 'Meter':
+ """
+ Updates the meter with the given dictionary of metrics.
+
+ Args:
+ dic (Mapping): A dictionary containing the metrics and their values.
+
+ Returns:
+ The meter itself.
+ """
for k, v in dic.items():
self[str(k)] = v
return self
def serialize(self) -> OrderedDict:
+ """
+ Returns a dictionary representation of the meter.
+
+ Returns:
+ An ordered dictionary containing the metrics and their string values.
+ """
res = OrderedDict()
for k, v in self.items():
res[k] = f'{v}'
return res
def items(self) -> ItemsView:
+ """
+ Returns a view object containing the (metric, value) pairs.
+
+ Returns:
+ A view object containing the (metric, value) pairs.
+ """
return self._rec.items()
def keys(self):
+ """
+ Returns a view object containing the metric names.
+
+ Returns:
+ A view object containing the metric names.
+ """
return self._rec.keys()
@staticmethod
def from_dict(dic: Mapping):
+ """
+ Returns a new meter with the given dictionary of metrics.
+
+ Args:
+ dic (Mapping): A dictionary containing the metrics and their values.
+
+ Returns:
+ A new meter with the given metrics and values.
+ """
m = Meter()
for k, v in dic.items():
m[k] = v
return m
def scalar_items(self) -> Iterator[Tuple[str, Number]]:
+ """
+ Returns an iterator over the (metric, value) pairs with scalar values.
+
+ Returns:
+ An iterator over the (metric, value) pairs with scalar values.
+ """
for k, v in self.items():
nd = to_ndarray(v)
if is_scalar(nd):
@@ -150,10 +310,46 @@ def scalar_items(self) -> Iterator[Tuple[str, Number]]:
class ReduceItem:
+ """Class that reduces a sequence of values to a single value according to a given method.
+
+ Attributes:
+ SLIDE_WINDOW_SIZE (int): The size of the sliding window used for averaging (default: 100).
+ EXP_WEIGHT (float): The exponential weight used for computing the sliding window offset (default: 0.75).
+
+ Args:
+ item (optional): The initial value (default: None).
+ gb_method (optional): The reduction method (default: None). Can be one of {'slide', 'mean', 'sum', 'max', 'min', 'last'}.
+
+ Raises:
+ AssertionError: If the reduction method is 'min' or 'max' and the input is not a scalar.
+
+ Methods:
+ __repr__(): Returns a string representation of the current reduction value.
+ __str__(): Returns the same string representation as __repr__().
+ update(item): Updates the reduction value with a new item.
+ res: Returns the current reduction value.
+
+ Examples:
+ >>> r = ReduceItem(gb_method='mean')
+ >>> r.update(2)
+ >>> r.update(3)
+ >>> r.res
+ 2.5
+ """
SLIDE_WINDOW_SIZE = 100
EXP_WEIGHT = 0.75
def __init__(self, item=None, gb_method=None):
+ """
+ Initializes a new ReduceItem instance.
+
+ Args:
+ item (optional): The initial value (default: None).
+ gb_method (optional): The reduction method (default: None). Can be one of {'slide', 'mean', 'sum', 'max', 'min', 'last'}.
+
+ Raises:
+ AssertionError: If the reduction method is 'min' or 'max' and the input is not a scalar.
+ """
self.gb_method = gb_method # groupby method
self.acc = []
if item is not None:
@@ -161,11 +357,11 @@ def __init__(self, item=None, gb_method=None):
self.c = len(self.acc)
self.cur = item
if gb_method == 'max':
- self.last = -1e12
+ self._last = -1e12
elif gb_method == 'min':
- self.last = 1e12
+ self._last = 1e12
else:
- self.last = 0
+ self._last = 0
self._res = self.last
@@ -180,10 +376,10 @@ def __init__(self, item=None, gb_method=None):
def __repr__(self):
"""
- simpler but more time-comsuming method could be some math function, not in if-else branch, like
- prec = max(min(8, int(np.ceil(np.log10((1 / (self.offset + 1e-10)))))), 1)
- fmt_str = f'{{:.{prec}f}}'
- return fmt_str.format(res)
+ Returns a string representation of the current reduction value.
+
+ Returns:
+ str: The string representation of the current reduction value.
"""
res = self.res
if self.isscalar:
@@ -207,6 +403,7 @@ def __repr__(self):
__str__ = __repr__
def update(self, item):
+ """Updates the reduction value with a new item."""
self.cur = item
item = detach(item)
@@ -218,7 +415,7 @@ def update(self, item):
self.acc.append(item)
if len(self.acc) > ReduceItem.SLIDE_WINDOW_SIZE:
self.acc.pop(0)
- self.last = self.cur
+ self._last = self.cur
elif avg in {'mean', 'sum'}:
if len(self.acc) == 0:
self.acc.append(0)
@@ -229,10 +426,20 @@ def update(self, item):
elif avg == 'min':
self._res = min(self.cur, self._res)
- self.last = item
+ self._last = item
+
+ @property
+ def last(self):
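+        """The most recently recorded value (or the method-dependent initial sentinel)."""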
+ return self._last
@property
def res(self):
+ """
+ Returns the current reduction value.
+
+ Returns:
+ float: The current reduction value.
+ """
avg = self.gb_method
if avg == 'slide':
return np.mean(self.acc)
diff --git a/src/lumo/core/params.py b/src/lumo/core/params.py
index 745877cd..c3384f9a 100644
--- a/src/lumo/core/params.py
+++ b/src/lumo/core/params.py
@@ -2,17 +2,17 @@
import os.path
import sys
import textwrap
+from pathlib import Path
from pprint import pformat
-from typing import Any, List, NewType
+from typing import Any, List, NewType, MutableMapping
import fire
from joblib import hash
-from pathlib import Path
from omegaconf import DictConfig, OmegaConf, DictKeyType
from omegaconf._utils import _ensure_container
-from .attr import safe_update_dict, set_item_iterative
-from .raises import BoundCheckError, NewParamWarning
+from .attr import safe_update_dict, set_item_iterative
+from .raises import BoundCheckError
# arange_param = namedtuple('arange_param', ['default', 'left', 'right'], defaults=[None, float('-inf'), float('inf')])
# choice_param = namedtuple('choice_param', ['default', 'choices'], defaults=[None, []])
@@ -21,6 +21,14 @@
class Arange:
+ """A class representing a range of numeric values with a default, left and right boundaries.
+
+ Attributes:
+ default: The default value of the range. Defaults to None.
+        left: The left boundary of the range. Defaults to negative infinity.
+ right: The right boundary of the range. Defaults to positive infinity.
+ """
+
+    def __init__(self, default=None, left=float('-inf'), right=float('inf')):
self.default = default
self.left = left
@@ -31,6 +39,13 @@ def __repr__(self):
class Choices:
+ """A class representing a list of choices with a default value.
+
+ Attributes:
+ default: The default value of the list. Defaults to None.
+ choices: A list of values representing the available choices. Defaults to an empty list.
+ """
+
def __init__(self, default=None, choices=None):
if choices is None:
choices = []
@@ -41,47 +56,35 @@ def __repr__(self):
return f"Choice: [{self.default}], {self.choices}"
-def _get_item(dic, keys: List[str]):
- if len(keys) == 1:
- return DictConfig.__getitem__(dic, keys[0])
- else:
- nex = DictConfig.__getitem__(dic, keys[0])
- if isinstance(nex, (dict, DictConfig)):
- return _get_item(nex, keys[1:])
- else:
- raise KeyError(keys)
-
+def _safe_repr(values: Any) -> str:
+ """Return a formatted string representation of the input values.
-def _set_item(dic, keys: List[str], value):
- if len(keys) == 1:
- if isinstance(value, dict):
- value = dic(value)
- DictConfig.__setitem__(dic, keys[0], value)
- else:
- try:
- nex = _get_item(dic, keys[:1])
- except KeyError:
- nex = DictConfig({})
- DictConfig.__setitem__(dic, keys[0], nex)
- _set_item(nex, keys[1:], value)
+ Args:
+ values: Any type of input values to be formatted.
+ Returns:
+ A string representation of the input values, formatted using `pprint`.
-def _safe_repr(values: Any) -> str:
+    """
return pformat(values)
def _padding_mod(st: str, offset=7, mod=4):
- """
- 123 \\
- 1 \\
- 12312341 \\
- 1231
+ """Pads a string with spaces to a length that is a multiple of a given modulus.
+
Args:
- strs:
- mod:
+ st: The input string to pad.
+ offset: An integer specifying the minimum length of the output string. If the length of the input string is
+ less than this value, spaces will be added to the end of the string to make it the desired length.
+ mod: An integer specifying the modulus. The length of the output string will be a multiple of this value.
Returns:
-
+ A string that is a multiple of the given modulus and has a length of at least `offset`. If the length of the
+ input string is less than `offset`, the output string will be padded with spaces to achieve the minimum length.
+ If the length of the input string is already a multiple of the given modulus, the output string will have the
+ same length as the input string.
"""
size = len(st)
if size < offset:
@@ -95,13 +98,15 @@ def _padding_mod(st: str, offset=7, mod=4):
def safe_param_repr(values: List[tuple], level=1) -> str:
"""
+ Returns a string representation of a list of tuples containing parameter names and their values.
+ The resulting string can be safely included in a function signature or call, as it correctly formats the parameters.
Args:
- values:
- level:
+ values: A list of tuples containing parameter names and their values.
+ level: An integer representing the level of indentation.
Returns:
-
+ A string representation of the input parameters, formatted with correct indentation and line breaks.
"""
res = [(f"{k}={_safe_repr(v)},", anno) for k, v, anno in values]
@@ -112,16 +117,30 @@ def safe_param_repr(values: List[tuple], level=1) -> str:
class BaseParams(DictConfig):
+ """
+ A dictionary-like configuration object that supports parameter constraint validation.
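+
+    Example (a minimal sketch; assumes the `arange` helper mirrors the `choice` helper defined below):
+
+        params = BaseParams()
+        params.lr = params.arange(0.1, 0.0, 1.0)   # value bounded to [0.0, 1.0]
+        params.opt = params.choice('sgd', 'adam')  # value restricted to the choices
+        params.lr = 2.0   # raises BoundCheckError (outside [0.0, 1.0])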
+ """
+
def __init__(self):
+ """
+ Initializes a new instance of the BaseParams class.
+ """
super().__init__({}, flags={'no_deepcopy_set_nodes': True})
- # self._set_flag('no_deepcopy_set_nodes', True)
self.__dict__["_prop"] = {}
def __setattr__(self, key: str, value: Any) -> None:
- if key != '_prop':
- # if isinstance(value, BaseParams):
- # self._prop.setdefault('key_type', {})[key] = type(value)
+ """
+ Sets an attribute value for the specified key.
+
+ Args:
+ key (str): The key of the attribute.
+ value (Any): The value of the attribute.
+
+ Raises:
+ BoundCheckError: If the specified value is not within the specified bounds or choices.
+ """
+ if key != '_prop':
if isinstance(value, (Arange, Choices)):
res = self._prop.get('constrain', {})
res[key] = value
@@ -134,10 +153,18 @@ def __setattr__(self, key: str, value: Any) -> None:
super().__setattr__(key, value)
def __setitem__(self, key: DictKeyType, value: Any) -> None:
- if key != '_prop':
- # if isinstance(value, BaseParams):
- # self._prop.setdefault('key_type', {})[key] = type(value)
+ """
+ Sets a dictionary item value for the specified key.
+
+ Args:
+ key (DictKeyType): The key of the item.
+ value (Any): The value of the item.
+ Raises:
+ BoundCheckError: If the specified value is not within the specified bounds or choices.
+
+ """
+ if key != '_prop':
if isinstance(value, (Arange, Choices)):
self._prop.setdefault('constrain', {})[key] = value
value = value.default
@@ -148,13 +175,31 @@ def __setitem__(self, key: DictKeyType, value: Any) -> None:
super().__setitem__(key, value)
def __getattr__(self, key: str) -> Any:
+ """
+ Gets an attribute value for the specified key.
+
+ Args:
+ key (str): The key of the attribute.
+
+ Returns:
+ Any: The value of the attribute.
+
+ """
res = super().__getattr__(key)
- # key_type = self._prop.setdefault('key_type', {}).get(key, None)
- # if key_type is not None:
- # res = key_type.from_kwargs(**res)
return res
def _check(self, name, value):
+ """
+ Checks if the specified parameter value is within the specified bounds or choices.
+
+ Args:
+ name (str): The name of the parameter.
+ value (Any): The value of the parameter.
+
+ Raises:
+ BoundCheckError: If the specified value is not within the specified bounds or choices.
+
+ """
bound = self._prop['constrain'][name]
if isinstance(bound, Arange) and not (bound.left <= value and value <= bound.right):
raise BoundCheckError(
@@ -163,10 +208,29 @@ def _check(self, name, value):
raise BoundCheckError(f"value of param '{name}' should in values {bound.choices}, but got {value}")
def __getitem__(self, key: DictKeyType) -> Any:
+ """
+ Gets a dictionary item value for the specified key.
+
+ Args:
+ key (DictKeyType): The key of the item.
+
+ Returns:
+ Any: The value of the item.
+
+ """
return super().__getitem__(key)
def __repr__(self):
+ """
+ Returns a string representation of the BaseParams object.
+
+ Returns:
+ str: A string representation of the BaseParams object.
+
+ """
+
def _arg_to_str(k, v):
+ """to str"""
res = self._prop.get('constrain', {}).get(k, None)
if res is not None:
return f'{res}, {type(v).__name__}'
@@ -185,6 +249,13 @@ def _arg_to_str(k, v):
return "{}.Space".format(self.__class__.__name__) + '(\n' + args_str + '\n)'
def copy(self):
+ """
+ Returns a copy of the BaseParams object.
+
+ Returns:
+ BaseParams: A copy of the BaseParams object.
+
+ """
copied = self.__class__()
copied.from_dict(super(BaseParams, self).copy())
return copied
@@ -236,73 +307,155 @@ def choice(self, *choices) -> Choices:
"""
return Choices(choices[0], choices)
- def safe_update(self, dic, assert_type=True):
- self.update(
- safe_update_dict(self.to_dict(), dic, assert_type=assert_type)
- )
+ def safe_update(self, dic, assert_type=False):
+ """
+ Merge `dict` object into the config object, safely updating the values.
+
+ Args:
+ dic: `dict` object to update
+ assert_type: If True, enforce that the type of values in `dic` matches the current config.
+
+ Returns:
+ None
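+
+ Example (a minimal sketch):
+ p = BaseParams.Space(lr=0.1, epochs=10)
+ p.safe_update({'lr': 0.5, 'optim': {'name': 'sgd'}})   # p.lr -> 0.5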
+ """
+ safe_update_dict(self, dic, assert_type=assert_type)
+
+ def from_dict(self, dic: MutableMapping):
+ """
+ Update the config object from a dictionary.
- def from_dict(self, dic: dict):
+ Args:
+ dic: `dict` object to update
+
+ Returns:
+ updated `self` object
+ """
self.safe_update(dic)
return self
def from_kwargs(self, **kwargs):
+ """
+ Update the config object from keyword arguments.
+
+ Args:
+ **kwargs: key-value pairs to update in the config object
+
+ Returns:
+ updated `self` object
+ """
return self.from_dict(kwargs)
def from_json(self, file):
- self.safe_update(json.loads(Path(file).read_text()), assert_type=True)
+ """
+ Update the config object from a JSON file.
+
+ Args:
+ file: path to the JSON file
+
+ Returns:
+ updated `self` object
+ """
+ self.safe_update(json.loads(Path(file).read_text()), assert_type=False)
return self
def from_yaml(self, file):
- self.safe_update(dict(OmegaConf.load(file)), assert_type=True)
+ """
+ Update the config object from a YAML file.
+
+ Args:
+ file: path to the YAML file
+
+ Returns:
+ updated `self` object
+ """
+ self.safe_update(dict(OmegaConf.load(file)), assert_type=False)
return self
def from_args(self, argv: list = None):
+ """
+ Update the config object from command line arguments.
+
+ Args:
+ argv: list of command line arguments (default: None)
+
+ Returns:
+ updated `self` object
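+
+ Example (sketch; fire-style `--key=value` flags, dotted keys create nested nodes):
+ # python main.py --epoch=20 --optim.lr=0.1 --c=config.yml
+ params = Params().from_args()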
+ """
if argv is None:
argv = sys.argv
- def func(**kwargs):
+ def func(*args, **kwargs):
+ """function to process arg list"""
if 'help' in kwargs:
print(self)
exit()
return
-
config = kwargs.get('config')
if config is None:
config = kwargs.get('c')
if config is not None and isinstance(config, str) and os.path.exists(config):
- self.from_yaml(config)
+ if config.endswith('yaml') or config.endswith('yml'):
+ self.from_yaml(config)
+ elif config.endswith('json'):
+ self.from_json(config)
- dic = {}
+ dic = BaseParams()
for k, v in kwargs.items():
set_item_iterative(dic, k.split('.'), v)
- self.safe_update(dic)
+ self.safe_update(dic.to_dict())
fire.Fire(func, command=argv)
return self
def from_hydra(self, config_path, config_name):
+ """load from hydra config mode"""
import hydra
hydra.compose()
@hydra.main(config_path=config_path, config_name=config_name)
def inner(cfg):
+ """inner function"""
return cfg
self.update(inner())
return self
def to_dict(self):
+ """
+ Convert this configuration to a dictionary.
+
+ Returns:
+ dict: The configuration as a dictionary.
+ """
cfg = _ensure_container(self)
container = OmegaConf.to_container(cfg, resolve=False, enum_to_str=True)
return container
def to_json(self, file=None):
+ """
+ Convert this configuration to a JSON string.
+
+ Args:
+ file (str or Path, optional): If specified, the JSON string will be written to a file at the given path.
+
+ Returns:
+ str or None: The JSON string, or None if file is specified.
+ """
info = json.dumps(self.to_dict(), ensure_ascii=False, indent=2)
if file is None:
return info
return Path(file).write_text(info, encoding='utf-8')
def to_yaml(self, file=None):
+ """
+ Convert this configuration to a YAML string.
+
+ Args:
+ file (str or Path, optional): If specified, the YAML string will be written to a file at the given path.
+
+ Returns:
+ str or None: The YAML string, or None if file is specified.
+ """
info = OmegaConf.to_yaml(self)
if file is None:
return info
@@ -310,24 +463,173 @@ def to_yaml(self, file=None):
@classmethod
def Space(cls, **kwargs):
+ """
+ Create a configuration object from keyword arguments.
+
+ Args:
+ **kwargs: The configuration values.
+
+ Returns:
+ BaseParams: The new configuration object.
+ """
return cls().from_dict(kwargs)
def __hash__(self):
+ """
+ Calculate the hash value of this configuration.
+
+ Returns:
+ int: The hash value.
+ """
return int(self.hash(), 16)
def hash(self) -> str:
+ """
+ Calculate the hash value of this configuration.
+
+ Returns:
+ str: The hash value.
+ """
return hash(self.to_dict())
def iparams(self):
+ """Initialization method, mostly used in Trainer"""
pass
@classmethod
def init_from_kwargs(cls, **kwargs):
+ """
+ Create a configuration object from keyword arguments.
+
+ Args:
+ **kwargs: The configuration values.
+
+ Returns:
+ BaseParams: The new configuration object.
+ """
return cls().from_dict(kwargs)
class Params(BaseParams):
+ """A class representing parameters"""
pass
ParamsType = NewType('ParamsType', Params)
+
+
+def safe_update_dict(src: BaseParams, kwargs: dict, assert_type=False):
+ """
+ Updates the source dictionary with the key-value pairs from the kwargs dictionary in a safe manner.
+
+ This function iterates over the items in the kwargs dictionary and updates the corresponding items in the
+ source dictionary, making sure that the types of the values being updated match the types of the values
+ already in the source dictionary.
+
+ Args:
+ src (BaseParams): The config object to update.
+ kwargs (dict): The dictionary containing the new key-value pairs to merge into the source object.
+ assert_type (bool): A flag indicating whether to check that the types of the values being updated match
+ the types of the values already in the source object. Defaults to False.
+
+ Returns:
+ BaseParams: The updated source object.
+ """
+ for ks, v in walk_dict(kwargs):
+ try:
+ old_v = get_item_iterative(src, ks)
+ if old_v is None or isinstance(old_v, type(v)) or not assert_type:
+ set_item_iterative(src, ks, v)
+ else:
+ raise TypeError(ks, type(old_v), type(v))
+ except KeyError:
+ set_item_iterative(src, ks, v)
+ return src
+
+
+def walk_dict(dic: MutableMapping, root=None):
+ """
+ Recursively walks through a dictionary and yields keys and values in a flattened format.
+
+ Args:
+ - dic (dict): The dictionary to be walked through.
+ - root (list): The root keys to be used in the resulting flattened format. Defaults to None.
+
+ Yields:
+ - A tuple of (keys, value): keys is the list of root keys plus the current key (split on '.'), and value is the corresponding leaf value.
+
+ Example:
+ ```python
+ d = {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}
+ for k, v in walk_dict(d):
+ print(k, v)
+ # Output:
+ # (['a', 'b'], 1)
+ # (['a', 'c', 'd'], 2)
+ # (['e'], 3)
+ ```
+ """
+ if root is None:
+ root = []
+ for k, v in dic.items():
+ if isinstance(v, dict):
+ yield from walk_dict(v, [*root, *k.split('.')])
+ else:
+ yield [*root, *k.split('.')], v
+
+
+def set_item_iterative(dic: BaseParams, keys: List[str], value):
+ """
+ Sets the value of a nested key in a dictionary using an iterative approach.
+
+ Args:
+ dic (BaseParams): The config object to update.
+ keys (List[str]): A list of keys representing the path to the nested key in the dictionary.
+ value: The value to set for the nested key.
+
+ Raises:
+ ValueError: If a key in the path exists in the dictionary but the corresponding value is not a dictionary.
+
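+ Example (sketch):
+ p = BaseParams()
+ set_item_iterative(p, ['optim', 'lr'], 0.1)   # creates the nested node p.optim.lr
+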
+ """
+ if len(keys) == 1:
+ if isinstance(value, MutableMapping):
+ for ks, v in walk_dict(value):
+ set_item_iterative(dic, [*keys, *ks], v)
+ else:
+ dic.__setitem__(keys[0], value)
+ else:
+ try:
+ nex = dic.__getitem__(keys[0])
+ if not isinstance(nex, MutableMapping):
+ raise ValueError(keys[0], nex)
+ except KeyError:
+ nex = BaseParams()
+ dic.__setitem__(keys[0], nex)
+
+ set_item_iterative(nex, keys[1:], value)
+
+
+def get_item_iterative(dic: MutableMapping, keys: List[str]):
+ """
+ Gets the value of a nested key in a dictionary using an iterative approach.
+
+ Args:
+ dic (MutableMapping): The mapping to retrieve the value from.
+ keys (List[str]): A list of keys representing the path to the nested key in the dictionary.
+
+ Raises:
+ KeyError: If the nested key does not exist in the dictionary.
+
+ Returns:
+ The value of the nested key in the dictionary.
+
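+ Example (sketch):
+ get_item_iterative({'a': {'b': 1}}, ['a', 'b'])   # -> 1
+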
+ """
+ if len(keys) == 1:
+ return dic.__getitem__(keys[0])
+ else:
+ nex = dic.__getitem__(keys[0])
+ if isinstance(nex, MutableMapping):
+ return get_item_iterative(nex, keys[1:])
+ else:
+ raise KeyError(keys)
diff --git a/src/lumo/core/raises.py b/src/lumo/core/raises.py
index 01945f6f..65759d2b 100644
--- a/src/lumo/core/raises.py
+++ b/src/lumo/core/raises.py
@@ -1,4 +1,2 @@
-class BoundCheckError(BaseException): pass
-
-
-class NewParamWarning(Warning): pass
+class BoundCheckError(BaseException):
+ """Exception raised when a bound check fails."""
diff --git a/src/lumo/core/record.py b/src/lumo/core/record.py
index 7d1cda2a..d17c8183 100644
--- a/src/lumo/core/record.py
+++ b/src/lumo/core/record.py
@@ -2,7 +2,7 @@
from numbers import Number
from . import Attr
-from .meter import Meter
+from .meter import Meter, ReduceItem
import torch
import numpy as np
from typing import NewType, Mapping, Union, Sequence, Dict
@@ -20,7 +20,6 @@
def wrap_result(metric: MetricType) -> Meter:
"""
Wrap any form of metric data into Meter form.
-
"""
if isinstance(metric, (Meter,)):
return metric
@@ -34,17 +33,37 @@ def wrap_result(metric: MetricType) -> Meter:
class Record:
- def __init__(self, window_size=500, **kwargs):
+ """Record class for storing and computing metrics over a window of steps.
+
+ Args:
+ **kwargs: Additional properties stored in the record's `_prop` dictionary.
+
+ Properties:
+ stage (str): The stage of the training process, such as 'train' or 'eval'.
+
+ Methods:
+ avg(): Computes the average value of the recorded metrics.
+ agg(): Computes the aggregated value of the recorded metrics.
+ __str__(): Returns a string representation of the recorded metrics.
+ tostr(): Alias for __str__().
+ record(metric, global_step=None): Records a metric for the current step.
+ clear(): Clears the recorded metrics.
+ flush(): Clears the cache of recorded metrics.
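+
+ Example (a minimal sketch; metrics are wrapped into a Meter via wrap_result):
+ record = Record(stage='train')
+ record.record({'loss': 0.5})
+ record.record({'loss': 0.3})
+ print(record.agg())   # aggregated result, 'last' by default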
+ """
+
+ def __init__(self, **kwargs):
self._prop = {}
self._prop.update(kwargs)
self._cache = []
- self._agg = OrderedDict() # type:Dict[str,AggItem]
+ self._agg = OrderedDict() # type:Dict[str,ReduceItem]
def avg(self) -> Attr:
+ """DEPRECATED: Computes the average value of the recorded metrics."""
warnings.warn('avg will be deprecated in later version, use agg() instead.')
return self.agg()
def agg(self) -> Attr:
+ """Computes the aggregated value of the recorded metrics."""
res = Attr()
for k, v in self._agg.items():
res[k] = v.res
@@ -52,9 +71,11 @@ def agg(self) -> Attr:
@property
def stage(self):
+ """Gets the stage of the training process, such as 'train' or 'eval'."""
return self._prop['stage']
def __str__(self):
+ """Returns a string representation of the recorded metrics."""
res = self.agg()
rep = []
for k, v in res.items():
@@ -66,60 +87,28 @@ def __str__(self):
return ', '.join(rep)
def tostr(self):
+ """Alias for __str__()."""
return str(self)
def record(self, metric, global_step=None):
+ """Records a metric for the current step."""
meter = wrap_result(metric)
agg = meter._avg
-
for k, v in meter.items():
stg = agg.get(k, 'last')
item = self._agg.get(k, None)
if item is None:
- item = AggItem(stg)
+ item = ReduceItem(gb_method=stg)
item.update(v)
self._agg[k] = item
def clear(self):
+ """Clears the recorded metrics."""
self._agg.clear()
self._cache.clear()
def flush(self):
+ """Clears the cache of recorded metrics."""
self._cache.clear()
-
-
-class AggItem:
- def __init__(self, stg):
- self.stg = stg
- self._last = 0
- self.acc = 0
- self.c = 0
-
- @property
- def res(self):
- if self.stg == 'mean':
- return self.acc / self.c
-
- if self.stg in {'min', 'max', 'last'}:
- return self.acc
-
- if self.stg == 'sum':
- return self.acc
-
- return self.acc
-
- @property
- def last(self):
- return self._last
-
- def update(self, val):
- if self.stg == 'last':
- self.acc = val
- elif self.stg == 'min':
- self.acc = min(self.acc, val)
- elif self.stg == 'max':
- self.acc = max(self.acc, val)
- elif self.stg in {'mean', 'sum'}:
- self.acc += val
- self.c += 1
- self._last = val
diff --git a/src/lumo/core/record_backend/abc.py b/src/lumo/core/record_backend/abc.py
index 1a46d630..b1a81b12 100644
--- a/src/lumo/core/record_backend/abc.py
+++ b/src/lumo/core/record_backend/abc.py
@@ -3,6 +3,15 @@
class RecordBackend(ABC):
+ """
+ Defines an abstract base class for recording and logging data.
+
+ This module provides an abstract base class, RecordBackend,
+ that defines the interface for recording and logging data.
+
+ Subclasses of RecordBackend must implement the log(), log_image(), log_audio(), and log_video() methods.
+ """
+
def __init__(self, location, *args, **kwargs):
self.location = location
@@ -10,11 +19,43 @@ def log(self, data: Dict[str, Any],
step: Optional[int] = None,
commit: Optional[bool] = None,
sync: Optional[bool] = None):
+ """
+ Logs data to the backend.
+
+ Args:
+ data (Dict[str, Any]): A dictionary containing the data to be logged.
+ step (Optional[int]): The step number associated with the data.
+ commit (Optional[bool]): Whether to commit the data to storage.
+ sync (Optional[bool]): Whether to synchronize the data across multiple devices.
+ """
raise NotImplementedError()
def log_image(self, image_array, caption=None):
+ """
+ Logs an image to the backend.
+
+ Args:
+ image_array (ndarray): A NumPy array representing the image to be logged.
+ caption (Optional[str]): A caption describing the image.
+ """
raise NotImplementedError()
+
def log_audio(self, audio_array, caption=None):
+ """
+ Logs an audio clip to the backend.
+
+ Args:
+ audio_array (ndarray): A NumPy array representing the audio clip to be logged.
+ caption (Optional[str]): A caption describing the audio clip.
+ """
raise NotImplementedError()
+
def log_video(self, video_array, caption=None):
+ """
+ Logs a video clip to the backend.
+
+ Args:
+ video_array (ndarray): A NumPy array representing the video clip to be logged.
+ caption (Optional[str]): A caption describing the video clip.
+ """
raise NotImplementedError()
diff --git a/src/lumo/core/tree.py b/src/lumo/core/tree.py
index 93255d4e..5131349c 100644
--- a/src/lumo/core/tree.py
+++ b/src/lumo/core/tree.py
@@ -1,14 +1,24 @@
-"""
-
-"""
-import queue
from collections import defaultdict
class tree(dict):
- """Implementation of perl's autovivification feature."""
+ """Implements Perl's autovivification feature.
+
+ This class extends Python's built-in dict class to allow for automatic creation of nested dictionaries
+ on access of non-existent keys. It accomplishes this by overriding the __getitem__ method to recursively
+ create new nested trees on access of a non-existent key. Additionally, the walk method is provided to
+ allow for iterating over all keys and values in the tree.
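+
+ Example (sketch):
+ t = tree()
+ t['a']['b'] = 1   # intermediate trees are created on access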
+ """
def __getitem__(self, item):
+ """Override __getitem__ to automatically create new trees on access of non-existent keys.
+
+ Args:
+ item: The key being accessed.
+
+ Returns:
+ The value of the key.
+ """
try:
return dict.__getitem__(self, item)
except KeyError:
@@ -16,6 +26,14 @@ def __getitem__(self, item):
return value
def walk(self):
+ """Iterate over all keys and values in the tree.
+
+ This method yields a tuple containing the current key and value, and recursively calls itself
+ on any nested tree values.
+
+ Yields:
+ A tuple containing the current key and value.
+ """
for k, v in self.items():
yield k, v
if isinstance(v, tree):
@@ -24,6 +42,16 @@ def walk(self):
class Node:
+ """Represents a node in the Forest.
+
+ Attributes:
+ HEAD (int): A class variable indicating the head node.
+ MID (int): A class variable indicating a mid node.
+ TAIL (int): A class variable indicating the tail node.
+ value (any): The value held by the node.
+ link (list): A list of the node's adjacent nodes.
+ stage (int): The position of the node in the linked list.
+ """
HEAD = 0
MID = 1
TAIL = 2
@@ -35,25 +63,49 @@ def __init__(self):
@property
def is_head(self):
+ """Returns True if the node is a head node, else False."""
return self.stage == self.HEAD
@property
def is_mid(self):
+ """Returns True if the node is a mid node, else False."""
return self.stage == self.MID
@property
def is_tail(self):
+ """Returns True if the node is a tail node, else False."""
return self.stage == self.TAIL
def set_stage(self, stage):
+ """Sets the stage attribute of the node to the specified value.
+
+ Args:
+ stage (int): An integer indicating the position of the node in the linked list.
+
+ Returns:
+ Node: The node with the updated stage attribute.
+ """
self.stage = stage
return self
def set_value(self, val):
+ """Sets the value attribute of the node to the specified value.
+
+ Args:
+ val (any): The value to set the node's value attribute to.
+
+ Returns:
+ Node: The node with the updated value attribute.
+ """
self.value = val
return self
def add_link(self, y):
+ """Adds the specified node to the list of adjacent nodes.
+
+ Args:
+ y (Node): The node to add to the list of adjacent nodes.
+ """
self.link.append(y)
def __repr__(self):
@@ -61,20 +113,66 @@ def __repr__(self):
class Forest:
+ """
+ Represents a directed acyclic graph (DAG) where nodes are categorized as head, mid or tail.
+
+ Attributes:
+ dic (defaultdict): A dictionary to store the nodes of the graph.
+ order (list): A list to maintain the order of the nodes.
+ tail (set): A set to store the tail nodes of the graph.
+
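+ Example (sketch):
+ f = Forest()
+ f.add_head('a', 1).add_link('a', 'b', 2).add_tail('b', 'c', 3)
+ for node, value in f:   # BFS order over the graph
+     ...
+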
+ """
+
def __init__(self):
self.dic = defaultdict(Node)
self.order = []
self.tail = set()
def add_head(self, x, val=None):
+ """
+ Adds a new head node to the graph with the given value.
+
+ Args:
+ x: The node to be added.
+ val: The value associated with the node. Defaults to None.
+
+ Returns:
+ The updated Forest object.
+
+ """
self.dic[x].set_value(val).set_stage(Node.HEAD)
self.order.append(x)
return self
def check_node_type(self, x):
+ """
+ Checks if a node is already present in the graph.
+
+ Args:
+ x: The node to be checked.
+
+ Returns:
+ True if the node is already present, False otherwise.
+
+ """
return x in self.dic
def add_link(self, x, y, y_val=None):
+ """
+ Adds a new mid node to the graph and links it with an existing head node.
+
+ Args:
+ x: The head node to be linked with the new mid node.
+ y: The new mid node to be added.
+ y_val: The value associated with the new mid node. Defaults to None.
+
+ Returns:
+ The updated Forest object.
+
+ Raises:
+ AssertionError: If the head node is not already present in the graph or the mid node is already present.
+
+ """
assert x in self.dic, f'x must already existed in graph, has {self.order}, got {x}'
assert y not in self.dic, f'y must be a new node in graph, has {self.order}, got {y}'
self.dic[x].add_link(y)
@@ -83,6 +181,21 @@ def add_link(self, x, y, y_val=None):
return self
def add_tail(self, x, y, y_val=None):
+ """
+ Adds a new tail node to the graph and links it with an existing head or mid node.
+
+ Args:
+ x: The node to be linked with the new tail node.
+ y: The new tail node to be added.
+ y_val: The value associated with the new tail node. Defaults to None.
+
+ Returns:
+ The updated Forest object.
+
+ Raises:
+ AssertionError: If the head or mid node is not already present in the graph or the tail node is already present.
+
+ """
assert x in self.dic, f'x must already existed in graph, has {self.order}, got {x}'
assert y not in self.dic, f'y must be a new node in graph, has {self.order}, got {y}'
self.dic[x].add_link(y)
@@ -92,6 +205,13 @@ def add_tail(self, x, y, y_val=None):
return self
def __iter__(self):
+ """
+ Returns an iterator that iterates over the nodes in the graph in a Breadth-First-Search (BFS) order.
+
+ Returns:
+ An iterator that yields a tuple with the node and its corresponding value.
+
+ """
stack = []
mem = set()
diff --git a/src/lumo/data/__main__.py b/src/lumo/data/__main__.py
deleted file mode 100644
index 1ab0e4aa..00000000
--- a/src/lumo/data/__main__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""
-
-cd {{repo_root}}
-python -m lumo.data supervised
-
-python -m lumo.data selfsupervised
-
-python -m lumo.data semisupervised
-from datasets.objdect
-
-python -m lumo.data https://github.com/example/dataset objdect
-
-
-from datasets.objdect
-"""
diff --git a/src/lumo/data/builder.py b/src/lumo/data/builder.py
index 32f5722a..1522bc6a 100644
--- a/src/lumo/data/builder.py
+++ b/src/lumo/data/builder.py
@@ -1,6 +1,7 @@
import copy
import warnings
from functools import partial
+from itertools import cycle
from pprint import pformat
from copy import copy as builtin_copy
from typing import Callable, NewType, Dict, Any, Iterable, Sequence
@@ -129,27 +130,49 @@ def __iter__(self):
for key, outkeys in self._outs.items():
if key == '::idx::':
- warnings.warn(f'iter does not support idx, will skip {outkeys}')
- continue
+ source = enumerate(cycle(range(1)))
+ else:
+ source = self._data[key]
+ ipt_transform = self._transforms.get(key, None)
- source = self._data[key]
- for outkey in outkeys:
- self._iter_cache[outkey] = iter(source)
+ self._iter_cache[key] = iter(source), ipt_transform
return self
def __next__(self):
if len(self._iter_cache) == 0:
raise StopIteration()
try:
- outputs = {k: next(v) for k, v in self._iter_cache.items()}
- if self.mode != 'zip':
+ inputs_sample = {}
+ for (key, (ipt_iter, ipt_transform)) in (self._iter_cache.items()):
+ if key == '::idx::':
+ ipt = next(ipt_iter)[0]
+ else:
+ ipt = next(ipt_iter)
+ if ipt_transform is not None:
+ ipt = ipt_transform(ipt)
+ inputs_sample[key] = ipt
+
+ outputs = {}
+ for key, outkeys in self._outs.items():
+ for outkey in outkeys:
+ opt = inputs_sample[key]
+ opt_transform = self._transforms.get(f'::{outkey}', None)
+ if opt_transform is not None:
+ opt = opt_transform(opt)
+ outputs[outkey] = opt
+
+ if self.mode == 'chain':
outputs = [outputs[outkey] for outkey in self._outkeys]
+
return outputs
except StopIteration as e:
self._iter_cache.clear()
raise e
def __getitem__(self, index):
+ if not self.sized:
+ raise TypeError('Source is not sizable. Use iterative method.')
index = self.map_index(index)
@@ -189,6 +212,12 @@ def __len__(self):
return self._prop['__clen__']
def _update_len(self):
+ """
+ Update the length of the dataset.
+
+ Returns:
+ res (int): The length of the dataset after updating.
+ """
if not self.sized:
self._prop['__clen__'] = None
return
@@ -210,10 +239,22 @@ def _update_len(self):
@property
def inputs(self):
+ """
+ Property getter for the input sources.
+
+ Returns:
+ dict: The mapping from input names to their data sources.
+ """
return self._data
@property
def outputs(self):
+ """
+ Property getter for the output data.
+
+ Returns:
+ mapping (dict): A mapping of output keys to their corresponding sources.
+ """
mapping = {}
for key, outkeys in self._outs.items():
if key == '::idx::':
@@ -227,25 +268,61 @@ def outputs(self):
@property
def mode(self):
+ """
+ Property getter for the mode of the dataset.
+
+ Returns:
+ str: The mode of the dataset ('zip' by default).
+ """
return self._prop.get('mode', 'zip')
@property
def iterable(self):
+ """
+ Property getter for the iterability of the dataset.
+
+ Returns:
+ bool: Whether the dataset is iterable (False by default).
+ """
return self._prop.get('iterable', False)
@property
def subindices(self):
+ """
+ Property getter for the sub-indices of the dataset.
+
+ Returns:
+ Sequence[int] or None: The sub-indices of the dataset, if set.
+ """
return self._prop.get('subindices', None)
@property
def pseudo_length(self) -> int:
+ """
+ Property getter for the pseudo-length of the dataset.
+
+ Returns:
+ int or None: The pseudo-length of the dataset, if set.
+ """
return self._prop.get('pseudo_length', None)
@property
def pseudo_repeat(self) -> int:
+ """
+ Property getter for the pseudo-repeat of the dataset.
+
+ Returns:
+ int or None: The pseudo-repeat count of the dataset, if set.
+ """
return self._prop.get('pseudo_repeat', None)
def copy(self):
+ """
+ Create a copy of the dataset builder.
+
+ Returns:
+ builder (DatasetBuilder): a copy of the dataset builder.
+ """
builder = DatasetBuilder()
builder._prop = builtin_copy(self._prop)
builder._idx_keys = builtin_copy(self._idx_keys)
@@ -257,6 +334,16 @@ def copy(self):
return builder
def subset(self, indices: Sequence[int], copy=False):
+ """
+ Create a subset of the dataset builder by selecting indices.
+
+ Args:
+ indices (Sequence[int]): a sequence of indices to select.
+ copy (bool): whether to create a copy of the dataset builder or modify it in place.
+
+ Returns:
+ builder (DatasetBuilder): a new dataset builder containing only the selected indices.
+ """
if copy:
builder = self.copy()
else:
@@ -266,6 +353,15 @@ def subset(self, indices: Sequence[int], copy=False):
return builder
def scale_to_size(self, size: int):
+ """
+ Scale the dataset builder to a certain size by repeating or randomly truncating the data.
+
+ Args:
+ size (int): the size to scale the dataset builder to.
+
+ Returns:
+ self (DatasetBuilder): the scaled dataset builder.
+ """
assert isinstance(size, int)
assert 'pseudo_repeat' not in self._prop
assert 'pseudo_length' not in self._prop
@@ -275,6 +371,15 @@ def scale_to_size(self, size: int):
return self
def repeat(self, multiple: int):
+ """
+ Repeat the dataset builder multiple times to increase its size.
+
+ Args:
+ multiple (int): the number of times to repeat the dataset builder.
+
+ Returns:
+ self (DatasetBuilder): the repeated dataset builder.
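+
+ Example (sketch; assumes a sized source of length N):
+ builder.repeat(2)
+ len(builder)   # roughly 2 * N, tracked via pseudo_repeat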
+ """
assert isinstance(multiple, int)
assert 'pseudo_length' not in self._prop
assert 'pseudo_repeat' not in self._prop
@@ -285,12 +390,13 @@ def repeat(self, multiple: int):
def map_index(self, index):
"""
- Map the raw index to the final index for source data.
+ Map the raw index to the final index for the source data.
+
Args:
- index:
+ index: the raw index to map.
Returns:
-
+ index (int): the final index for the source data.
"""
if self.pseudo_length is not None or self.pseudo_repeat is not None:
if self.subindices is not None:
@@ -305,21 +411,46 @@ def map_index(self, index):
@property
def sized(self):
+ """
+ Check if the dataset builder is sized.
+
+ Returns:
+ bool: True if the dataset builder is sized, False otherwise.
+ """
return self._prop.get('sized', False)
def chain(self):
+ """
+ Set the mode of the dataset builder to chain.
+
+ Returns:
+ self (DatasetBuilder): the dataset builder with the chain mode set.
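+
+ Example (sketch, with outputs 'xs' and 'ys'):
+ sample = builder.chain()[0]   # [xs, ys] instead of zip mode's {'xs': xs, 'ys': ys}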
+ """
self._prop['mode'] = 'chain'
return self
def item(self):
+ """
+ Set the mode of the dataset builder to item.
+
+ Returns:
+ self (DatasetBuilder): the dataset builder with the item mode set.
+ """
self._prop['mode'] = 'item'
return self
def zip(self):
+ """
+ Set the mode of the dataset builder to zip.
+
+ Returns:
+ self (DatasetBuilder): the dataset builder with the zip mode set.
+ """
self._prop['mode'] = 'zip'
return self
def _check_source(self, name, source):
+ """Check the input source for validity and compatibility with the dataset builder."""
# source is sized can be itered
# source can be itered not meant it is sizable.
if self.subindices is not None:
@@ -351,6 +482,15 @@ def _check_source(self, name, source):
raise TypeError(f'Source {name} must be an iterable or sized object, but got {type(source)}.')
def add_idx(self, name):
+ """
+ Add an index pseudo source to the dataset builder.
+
+ Args:
+ name (str): the name of the index.
+
+ Returns:
+ self (DatasetBuilder): the dataset builder with the index added.
+ """
outkeys = self._outs.setdefault(f"::idx::", [])
assert name not in self._outkeys, f'Output key {name} duplicated.'
outkeys.append(name)
@@ -359,12 +499,15 @@ def add_idx(self, name):
def add_input(self, name: str, source, transform: SingleValueTransform = None):
"""
- Register a input source with the transform (if provided).
+ Add an input source to the dataset builder.
+
Args:
- name: source name
- source: source, should be a sized object.
- transform:
+ name (str): the name of the input source.
+ source: the input source to add.
+ transform (SingleValueTransform): the transform to apply to the input source.
+ Returns:
+ self (DatasetBuilder): the dataset builder with the input source added.
Notes:
Iterable objects without a `__len__` method are currently not well-tested. Be careful when using them in DatasetBuilder.
@@ -373,24 +516,36 @@ def add_input(self, name: str, source, transform: SingleValueTransform = None):
assert name not in self._data, f'Source name {name} duplicated.'
self._check_source(name, source)
self._data[name] = source
- self._transforms[name] = transform
+ self.set_input_transform(name, transform)
return self
def add_input_transform(self, name: str, transform: SingleValueTransform = None):
+ """
+ Add a transform to an existing input source.
+
+ Args:
+ name (str): the name of the input source to add the transform to.
+ transform (SingleValueTransform): the transform to add.
+
+ Returns:
+ self (DatasetBuilder): the dataset builder with the transform added.
+ """
assert name in self._data, f'Source {name} should be added.'
warnings.warn('`add` may cause confusion, use set_input_transform ')
return self.set_input_transform(name, transform)
def add_output(self, name: str, outkey: str, transform: SingleValueTransform = None):
"""
- Add a data flow from inputs[name] to outputs[outkey] with the transform (if provided).
+ Add an output flow from an input source to the dataset builder.
+
Args:
- name: source name of inputs
- outkey: output name of output
- transform: a callable function
+ name (str): the name of the input source for the output.
+ outkey (str): the name of the output.
+ transform (SingleValueTransform): the transform to apply to the output.
Returns:
-
+ self (DatasetBuilder): the dataset builder with the output added.
"""
assert name in self._data, f'Must have data source {name} first.'
@@ -399,42 +554,62 @@ def add_output(self, name: str, outkey: str, transform: SingleValueTransform = N
outkeys.append(outkey)
self._outkeys.append(outkey)
- self._transforms[f'::{outkey}'] = transform
+ self.set_output_transform(outkey, transform)
return self
def add_output_transform(self, outkey: str, transform: SingleValueTransform = None):
"""
- Add or **replace** transform of the output name.
+ Add a transform to an existing output.
+
Args:
- outkey: output name.
- transform: a callable function
+ outkey (str): the name of the output to add the transform to.
+ transform (SingleValueTransform): the transform to add.
+
+ Returns:
+ self (DatasetBuilder): the dataset builder with the transform added.
"""
assert outkey in self._outkeys, f'Output key {outkey} should be added.'
warnings.warn('add may cause confusion, use set_output_transform ')
return self.set_output_transform(outkey, transform)
def add_global_transform(self, transform: DictTransform):
+ """
+ Add a global transform to the dataset builder.
+
+ Args:
+ transform (DictTransform): the global transform to apply to the dataset.
+
+ Returns:
+ self (DatasetBuilder): the dataset builder with the global transform added.
+ """
self._transforms['::global::'] = transform
return self
def set_input_transform(self, name, transform: SingleValueTransform = None):
"""
- Add or **replace** transform of the input source {name}.
+ Set the transform for an input source.
+
Args:
- name: source name.
- transform: a callable function
+ name (str): the name of the input source to set the transform for.
+ transform (SingleValueTransform): the transform to set.
+ Returns:
+ self (DatasetBuilder): the dataset builder with the transform set.
"""
self._transforms[name] = transform
return self
def set_output_transform(self, outkey, transform: SingleValueTransform = None):
"""
- Add or **replace** transform of the output {name}.
+ Set the transform for an output.
+
Args:
- outkey: output name.
- transform: a callable function
+ outkey (str): the name of the output to set the transform for.
+ transform (SingleValueTransform): the transform to set.
+ Returns:
+ self (DatasetBuilder): the dataset builder with the transform set.
"""
self._transforms[f'::{outkey}'] = transform
return self
@@ -449,4 +624,6 @@ def __getattribute__(self, item):
return res
def get_source(self, name):
- return self._data[name]
+ """Get the input source with the given name."""
+ warnings.warn('Use inputs[name] instead.')
+ return self.inputs[name]
diff --git a/src/lumo/data/collate.py b/src/lumo/data/collate.py
index 56eb985e..3dd301c7 100644
--- a/src/lumo/data/collate.py
+++ b/src/lumo/data/collate.py
@@ -1,7 +1,12 @@
"""
+A module that provides classes and functions that act as the collate_fn in a DataLoader.
+
+Classes:
+ CollateBase: A base class for implementing collate functions.
+ IgnoreNoneCollate: A collate function that ignores None samples.
"""
-from typing import Any, Mapping, Sequence
+from typing import Mapping, Sequence
from lumo.core.params import ParamsType
import numpy as np
import torch
@@ -9,9 +14,37 @@
class CollateBase:
+ """
+ A base class for implementing collate functions.
+
+ Args:
+ params (ParamsType, optional): Parameters for the collate function. Defaults to None.
+
+ Attributes:
+ _collate_fn (callable): A function that combines a list of samples into a batch.
+ params (ParamsType): Parameters for the collate function.
+
+ Methods:
+ from_collate: Creates an instance of the class from a collate function.
+ __call__: Applies the collate function to a list of samples.
+ before_collate: A hook function called before the collate function is applied.
+ raw_collate: Applies the collate function to a list of samples without calling the `before_collate` and `after_collate` hook functions.
+ collate: Alias for `raw_collate`.
+ after_collate: A hook function called after the collate function is applied.
+ """
@classmethod
def from_collate(cls, collate_fn, params: ParamsType = None):
+ """
+ Creates an instance of the class from a collate function.
+
+ Args:
+ collate_fn (callable): A function that combines a list of samples into a batch.
+ params (ParamsType, optional): Parameters for the collate function. Defaults to None.
+
+ Returns:
+ CollateBase: An instance of the class.
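+
+ Example (sketch; uses torch's standard default_collate):
+ from torch.utils.data import default_collate
+ collate = CollateBase.from_collate(default_collate)
+ loader = DataLoader(dataset, batch_size=32, collate_fn=collate)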
+ """
instance = cls(params)
instance._collate_fn = collate_fn
return instance
@@ -22,25 +55,72 @@ def __init__(self, params: ParamsType = None) -> None:
self.params = params
def __call__(self, *args, **kwargs):
+ """
+ Applies the collate function to a list of samples.
+
+ Args:
+ *args: A list of samples.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ The batch of samples.
+ """
res = self.before_collate(*args, **kwargs)
res = self._collate_fn(res)
res = self.after_collate(res)
return res
def before_collate(self, sample_list):
+ """
+ A hook function called before the collate function is applied.
+
+ Args:
+ sample_list (Sequence): A list of samples.
+
+ Returns:
+ Sequence: The modified list of samples.
+ """
return sample_list
def raw_collate(self, sample_list):
+ """
+ Applies the collate function to a list of samples without calling the `before_collate` and `after_collate` hook functions.
+
+ Args:
+ sample_list (Sequence): A list of samples.
+
+ Returns:
+ The batch of samples.
+ """
return self._collate_fn(sample_list)
def collate(self, sample_list):
+ """
+ Alias for `raw_collate`.
+
+ Args:
+ sample_list (Sequence): A list of samples.
+
+ Returns:
+ The batch of samples.
+ """
return self._collate_fn(sample_list)
def after_collate(self, batch):
+ """
+ A hook function called after the collate function is applied.
+
+ Args:
+ batch (Any): The batch of samples.
+
+ Returns:
+ Any: The modified batch of samples.
+ """
return batch
class IgnoreNoneCollate(CollateBase):
+ """A collate function that ignores `None` samples."""
def _filter_none(self, item):
if item is None:
@@ -52,11 +132,24 @@ def _filter_none(self, item):
return True
def before_collate(self, sample_list):
+ """ before collate"""
return list(filter(self._filter_none, sample_list))
def numpy_collate(batch):
- r"""Puts each data field into a tensor with outer dimension batch size"""
+ """Collate function for numpy arrays.
+
+ Args:
+ batch (list): A list of numpy arrays or other python objects.
+
+ Returns:
+ numpy.ndarray or dict or list: Returns a collated numpy array, a dictionary of collated numpy arrays,
+ or a list of collated numpy arrays depending on the type of input elements.
+
+ Raises:
+ RuntimeError: If the elements in batch do not have consistent size.
+
+ """
elem = batch[0]
elem_type = type(elem)
diff --git a/src/lumo/data/datamodule.py b/src/lumo/data/datamodule.py
index e858b1ac..eee32145 100644
--- a/src/lumo/data/datamodule.py
+++ b/src/lumo/data/datamodule.py
@@ -1,3 +1,6 @@
+"""
+A module to manage dataloaders for different stages of training.
+"""
from typing import NoReturn, Union, overload, Optional
from torch.utils.data import DataLoader
@@ -8,47 +11,80 @@
class DataModule:
+ """
+ Used in Trainer to easily access DataLoaders for the different stages (train/test/val and others).
+ """
+
def __init__(self, params: ParamsType = None):
self._prop = {}
self.params = params
@property
def prop(self):
+ """
+ Get the dictionary of registered dataloaders.
+
+ Returns:
+ dict: the dictionary of registered dataloaders.
+ """
return self._prop
@staticmethod
def _parse_dataset(loader):
+ """
+ Parse a dataset from a dataloader.
+
+ Args:
+ loader (Union[DataLoader, DataLoaderSide]): a dataloader or a dataloader side.
+
+ Returns:
+ Dataset: the dataset behind the given loader.
+
+ Raises:
+ NotImplementedError: If `loader` is neither a `DataLoader` nor a `DataLoaderSide`.
+
+ """
if isinstance(loader, DataLoader):
return loader.dataset
elif isinstance(loader, DataLoaderSide):
return loader.dataset
- return None
+ raise NotImplementedError(type(loader))
@property
def train_dataset(self):
+ """
+ Get the train dataset.
+
+ Returns:
+ Dataset: the train dataset.
+
+ """
return self._parse_dataset(self.train_dataloader)
@property
def test_dataset(self):
+ """Get the test dataset."""
return self._parse_dataset(self.test_dataloader)
@property
def val_dataset(self):
+ """Get the validation dataset."""
return self._parse_dataset(self.val_dataloader)
@property
def train_dataloader(self) -> Optional[DataLoaderType]:
+ """Get the train dataloader."""
return self.get_loader_with_stage(TrainStage.train)
@property
def test_dataloader(self) -> Optional[DataLoaderType]:
+ """Get the test dataloader."""
return self.get_loader_with_stage(TrainStage.test)
@property
def val_dataloader(self) -> Union[NoReturn, DataLoaderType]:
+ """Get the val dataloader."""
return self.get_loader_with_stage(TrainStage.val)
def get_loader_with_stage(self, stage: TrainStage) -> DataLoaderType:
+ """Get the dataloader for a given stage."""
res = self._prop.get(stage.value, None)
if res is None:
self.idataloader(self.params, stage)
@@ -59,15 +95,46 @@ def __getitem__(self, key):
return self.prop.get(key, None)
@overload
- def regist_dataloader(self, train=None, test=None, val=None):
- ...
+ def regist_dataloader(self, train=None, test=None, val=None, **kwargs):
+ """
+ Registers the given dataloaders under the given keys.
+
+ Args:
+ train: A DataLoaderType object for the train set.
+ test: A DataLoaderType object for the test set.
+ val: A DataLoaderType object for the validation set.
+ **kwargs: DataLoaderType objects for other stages.
+ """
def regist_dataloader(self, **kwargs: dict):
+ """
+ Registers the given dataloaders under the given keys.
+
+ Args:
+ train: A DataLoaderType object for the train set.
+ test: A DataLoaderType object for the test set.
+ val: A DataLoaderType object for the validation set.
+ **kwargs: DataLoaderType objects for other stages.
+ """
for k, v in kwargs.items():
self.prop[k] = v
def regist_dataloader_with_stage(self, stage: TrainStage, dl: DataLoaderType):
- self.prop[stage.value] = dl
+ """
+ Registers the given dataloader under the given TrainStage.
+
+ Args:
+ stage: A TrainStage object.
+ dl: A DataLoaderType object.
+ """
+ self.regist_dataloader(**{stage.value: dl})
def idataloader(self, params: ParamsType = None, stage: TrainStage = None):
+ """
+ Interface function to implement in a child class to set up data loading.
+
+ Args:
+ params: A ParamsType object containing data loading parameters.
+ stage: A TrainStage object indicating which stage to set up data loading for.
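+
+ Example (sketch of a subclass; `build_loader` is a hypothetical helper, not part of lumo):
+ class MyDataModule(DataModule):
+     def idataloader(self, params=None, stage=None):
+         loader = build_loader(params, stage)   # hypothetical
+         self.regist_dataloader_with_stage(stage, loader)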
+ """
pass
diff --git a/src/lumo/data/loader.py b/src/lumo/data/loader.py
index e178b0fa..db11d35f 100644
--- a/src/lumo/data/loader.py
+++ b/src/lumo/data/loader.py
@@ -6,10 +6,24 @@
class LumoDataLoader(DataLoader):
+ """This module defines the LumoDataLoader class that inherits from the DataLoader class."""
pass
def summarize_loader(loader: DataLoader):
+ """
+ Summarize the DataLoader object and return a formatted string representation.
+
+ Args:
+ loader: A DataLoader object.
+
+ Returns:
+ A formatted string representation of the DataLoader object.
+
+ Raises:
+ ValueError: If the input argument is not a DataLoader object.
+
+ """
if isinstance(loader, DataLoaderSide):
inner = pformat({f"{k}(cycle={loader._cycle[k]})": summarize_loader(v) for k, v in loader._loaders.items()})
return f"DataLoaderSide({inner})"
@@ -40,7 +54,36 @@ def summarize_loader(loader: DataLoader):
class DataLoaderSide:
"""
- `DataLoaderSide` is used when different DataLoader with different batch_size are feeded at the same time.
+ A utility class for loading data from different DataLoaders with different batch sizes at the same time.
+
+ Example usage:
+ loader = DataLoaderSide()
+ loader.add('train', train_loader, cycle=True)
+ loader.add('val', val_loader)
+ loader.zip()
+ for batch in loader:
+ # process batch
+
+ Attributes:
+ _loaders: An ordered dictionary that maps the name of the DataLoader to the corresponding DataLoader instance.
+ _cycle: An ordered dictionary that maps the name of the DataLoader to a boolean indicating whether the DataLoader should be cycled.
+ _state: A string that indicates the current state of the DataLoaderSide instance. The possible values are 'zip' and 'chain'.
+
+ Methods:
+ dataset(): Returns a dictionary that maps the name of the DataLoader to its corresponding dataset.
+ source(): Returns the _loaders dictionary.
+ add(name, loader, cycle=False): Adds a DataLoader instance to the _loaders dictionary.
+ name is the name of the DataLoader.
+ loader is the DataLoader instance to be added.
+ cycle is a boolean indicating whether the DataLoader should be cycled. Defaults to False.
+ copy(): Returns a new DataLoaderSide instance with the same _loaders, _cycle, and _state attributes as the original.
+ zip(): Sets the _state attribute to 'zip', which means the batches are zipped together.
+ if _state is 'zip', the batches are returned as an ordered dictionary.
+ chain(): Sets the _state attribute to 'chain', which means the batches are concatenated.
+ if _state is 'chain', the batches are returned as a list.
+ __len__(): Returns the minimum length of all the DataLoaders that do not have the cycle flag set to True.
+ __iter__(): Returns an iterator that generates batches from the DataLoaders in the _loaders dictionary.
+
"""
def __init__(self):
@@ -50,18 +93,29 @@ def __init__(self):
@property
def dataset(self):
+ """Returns a dictionary that maps the name of the DataLoader to its corresponding dataset."""
return {k: v.dataset for k, v in self.source.items()}
@property
def source(self):
+ """Returns the _loaders dictionary."""
return self._loaders
def add(self, name, loader: DataLoader, cycle=False):
+ """
+ Adds a DataLoader instance to the _loaders dictionary.
+ Args:
+ name (str): The name of the DataLoader.
+ loader (DataLoader): The DataLoader instance to be added.
+ cycle (bool): A boolean indicating whether the DataLoader should be cycled. Defaults to False.
+
+ """
self._loaders[name] = loader
self._cycle[name] = cycle
return self
def copy(self):
+ """Returns a new DataLoaderSide instance with the same _loaders, _cycle, and _state attributes as the original."""
loader = DataLoaderSide()
loader._loaders = self._loaders
loader._cycle = self._cycle
@@ -69,18 +123,26 @@ def copy(self):
return loader
def zip(self):
+ """Sets the _state attribute to 'zip', which means the batches are zipped together.
+ If _state is 'zip', the batches are returned as an ordered dictionary."""
self._state = 'zip'
return self
def chain(self):
+ """
+ Sets the _state attribute to 'chain', which means the batches are concatenated.
+ If _state is 'chain', the batches are returned as a list.
+ """
self._state = 'chain'
return self
def __len__(self):
+ """Returns the minimum length of all the DataLoaders that do not have the cycle flag set to True."""
valid_keys = [k for k, cycle in self._cycle.items() if not cycle]
return min([len(self._loaders[k]) for k in valid_keys])
def __iter__(self):
+ """Returns an iterator that generates batches from the DataLoaders in the _loaders dictionary."""
iters = {k: iter(v)
for k, v in self._loaders.items()}
stop = None
diff --git a/src/lumo/decorators/clsmethod.py b/src/lumo/decorators/clsmethod.py
index 70b4682e..b63c8019 100644
--- a/src/lumo/decorators/clsmethod.py
+++ b/src/lumo/decorators/clsmethod.py
@@ -5,13 +5,18 @@
def clswrap(callable: T) -> T:
"""
- 带类型提示的 staticmethod()
+ A decorator for creating a staticmethod with type hints.
+
Args:
- callable:
+ callable: The function to be wrapped with a staticmethod.
Returns:
+ A new staticmethod that calls the original function.
+
Notes:
- 必须在类中用
+ This decorator should be used on a class method. It creates a new staticmethod that calls the original function,
+ allowing it to be called without an instance of the class. The `callable` argument should have type hints
+ for the parameters and return type. The resulting staticmethod will also have the same type hints.
"""
@staticmethod
diff --git a/src/lumo/decorators/debounce.py b/src/lumo/decorators/debounce.py
new file mode 100644
index 00000000..ed951428
--- /dev/null
+++ b/src/lumo/decorators/debounce.py
@@ -0,0 +1,38 @@
+import time
+from functools import wraps
+
+
+def debounce(wait):
+ """
+ Decorator that drops calls arriving within `wait` seconds of the last accepted call.
+
+ Args:
+ wait (float): The minimum interval, in seconds, between accepted calls.
+
+ Returns:
+ function: The decorated function.
+
+ Examples:
+ @debounce(5)
+ def my_func():
+ print("my_func called")
+
+ while True:
+ my_func()
+ time.sleep(1)
+ """
+
+ def decorator(func):
+ last_time = 0
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ nonlocal last_time
+ now = time.time()
+ if now - last_time > wait:
+ last_time = now
+ return func(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
diff --git a/src/lumo/decorators/lazy_required.py b/src/lumo/decorators/lazy_required.py
index 8814ffb3..f6c62984 100644
--- a/src/lumo/decorators/lazy_required.py
+++ b/src/lumo/decorators/lazy_required.py
@@ -5,6 +5,16 @@
def is_lib_available(lib):
+ """
+ Check if a library is available to be imported.
+
+ Args:
+ lib (str): The name of the library to check for.
+
+ Returns:
+ bool: True if the library is available, False otherwise.
+
+ """
if lib in _lib_memory:
return True
res = imputil.find_spec(lib)
@@ -16,6 +26,10 @@ def is_lib_available(lib):
def torch_required(func):
+ """
+ Wrap a function to raise an ImportError if PyTorch is not available.
+ """
+
# Chose a different decorator name than in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
@@ -28,6 +42,10 @@ def wrapper(*args, **kwargs):
def lib_required(lib_name):
+ """
+ Wrap a function to raise an ImportError if a required library is not available.
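+
+ Example (sketch):
+ @lib_required('wandb')
+ def init_wandb():
+     import wandb
+     wandb.init()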
+ """
+
def outer(func):
@wraps(func)
def wrapper(*args, **kwargs):
diff --git a/src/lumo/decorators/process.py b/src/lumo/decorators/process.py
index 9c1375b3..5a488127 100644
--- a/src/lumo/decorators/process.py
+++ b/src/lumo/decorators/process.py
@@ -4,6 +4,20 @@
def call_on_main_process_wrap(func) -> Callable:
+ """
+ Wrap a function to only execute on the main process.
+
+ If the current process is the main process, it calls the original func with the passed arguments and returns the result.
+ If it is not the main process, it does nothing and returns None.
+
+ Args:
+ func (callable): The function to wrap.
+
+ Returns:
+ callable: A wrapped function that only executes on the main process.
+
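+ Example (sketch):
+ @call_on_main_process_wrap
+ def save_checkpoint(state):
+     ...   # runs only on the main process; returns None on other ranks
+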
+ """
+
def inner(*args, **kwargs):
if is_main():
return func(*args, **kwargs)
diff --git a/src/lumo/decorators/regist.py b/src/lumo/decorators/regist.py
index 7c4bfe8c..6829f969 100644
--- a/src/lumo/decorators/regist.py
+++ b/src/lumo/decorators/regist.py
@@ -4,6 +4,25 @@
def regist_func_to(val: Union[Dict[str, Callable], List[Callable]], name_=None):
+ """
+ A decorator function that registers a function to a dictionary or list.
+
+ Args:
+ val: A dictionary or list to register the function to.
+ name_: The key or index of the function in the dictionary or list. If not provided, it defaults to the name
+ of the function.
+
+ Returns:
+ The decorated function.
+
+ Notes:
+ This decorator function is used to register a function to a dictionary or list. If the `val` argument is a
+ dictionary, the function will be added to the dictionary with a key equal to `name_` or the function name if
+ `name_` is not provided. If the `val` argument is a list, the function will be appended to the list. The
+ registered function can then be retrieved and called later. The `name_` argument is optional, but if it is
+ provided it should be a string. The `val` argument should be either a dictionary or a list of functions.
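+
+ Example (sketch):
+ handlers = {}
+
+ @regist_func_to(handlers)
+ def on_start():
+     ...
+
+ handlers['on_start']()   # registered under the function's name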
+ """
+
def wrap(func):
if name_ is None:
name = func.__name__
@@ -19,27 +38,98 @@ def wrap(func):
return wrap
-class Register():
- def __init__(self, name):
+class Register:
+ """
+ A class for registering functions.
+
+ Args:
+ name: The name of the register.
+
+ Attributes:
+ name: The name of the register.
+ source: An ordered dictionary that holds the registered functions.
+
+ Methods:
+ __str__(): Returns a string representation of the register.
+ __repr__(): Returns a string representation of the register.
+ __getitem__(item): Gets a function from the register by name.
+ __call__(wrapped, name): A decorator function that adds a function to the register.
+ regist(name): A method that returns a partial function of __call__ with the register's name.
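+
+ Example (sketch):
+ models = Register('models')
+
+ @models.regist('resnet')
+ def build_resnet():
+     ...
+
+ models['resnet']   # -> build_resnet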
+ """
+
+ def __init__(self, name: str):
+ """
+ Initialize the register.
+
+ Args:
+ name: The name of the register.
+ """
self.name = name
self.source = OrderedDict()
- def __str__(self):
+ def __str__(self) -> str:
+ """
+ Return a string representation of the register.
+
+ Returns:
+ A string representation of the register.
+ """
inner = str([(k, v) for k, v in self.source.items()])
return f"Register({self.name}{inner})"
- def __repr__(self):
+ def __repr__(self) -> str:
+ """
+ Return a string representation of the register.
+
+ Returns:
+ A string representation of the register.
+ """
return self.__str__()
- def __getitem__(self, item):
+ def __getitem__(self, item: str):
+ """
+ Get a function from the register by name.
+
+ Args:
+ item: The name of the function.
+
+ Returns:
+ The function with the given name, or None if the function is not in the register.
+ """
return self.source.get(item, None)
- def __call__(self, wrapped, name=None):
+ def __call__(self, wrapped: callable, name: str = None) -> callable:
+ """
+ Add a function to the register.
+
+ Args:
+ wrapped: The function to be added to the register.
+ name: The name of the function in the register. If not provided, the function's name will be used.
+
+ Returns:
+ The original function, unchanged.
+
+ Raises:
+ AssertionError: If no name is given and the wrapped function has no `__name__`.
+ """
if name is None:
name = wrapped.__name__
assert name is not None
self.source[name] = wrapped
return wrapped
- def regist(self, name=None):
+ def regist(self, name: str = None):
+ """
+ Returns a partial function of __call__ with the register's name.
+
+ Args:
+ name: The name of the function in the register. If not provided, the function's name will be used.
+
+ Returns:
+ A partial function of __call__ with the register's name.
+
+ Notes:
+ This method is used to create a decorator that will add a function to the register. The `name` argument
+ is optional, but if it is provided it should be a string.
+ """
return partial(self, name=name)
diff --git a/src/lumo/exp/README.md b/src/lumo/exp/README.md
deleted file mode 100644
index e57d6fe2..00000000
--- a/src/lumo/exp/README.md
+++ /dev/null
@@ -1,14 +0,0 @@
-- 记录实验,负责实验的组织、目录管理、元数据记录
-
-```
-from lumo_experiment import Experiment
-
-Experiment('name')
-```
-
-Experiment 的意义:
-
-- 为每个实验分配唯一存储空间(通过 exp_name/test_name )
-- 为每个实验保留回溯可能(通过 git)
-- 为每个实验进行标注,方便(可能存在的)检索
-- 其他状态动态记录
\ No newline at end of file
diff --git a/src/lumo/exp/__main__.py b/src/lumo/exp/__main__.py
deleted file mode 100644
index dd16f581..00000000
--- a/src/lumo/exp/__main__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""
-列出某次实验的全部信息
-"""
diff --git a/src/lumo/exp/base.py b/src/lumo/exp/base.py
index 257d925c..fbfb056f 100644
--- a/src/lumo/exp/base.py
+++ b/src/lumo/exp/base.py
@@ -1,7 +1,12 @@
from lumo.proc import glob
-class ExpHook:
+class BaseExpHook:
+ """A base class of hook for experiments that can be registered with an experiment.
+
+ Please use `ExpHook` from `exp.exphook` for better type hints.
+
+ """
name = None # type: str
configs = {}
@@ -13,21 +18,81 @@ def __new__(cls):
@property
def config_name(self):
+ """Get the configuration name for the hook.
+
+ Returns:
+ A string representing the configuration name for the hook.
+
+ """
return f'HOOK_{self.name.upper()}'
@property
def config_string(self):
+ """Get the configuration string for the hook.
+
+ Returns:
+ A string representing the configuration string for the hook.
+
+ """
+
return ', '.join(f'{k}={glob.get(k, v)}' for k, v in self.configs.items())
- def regist(self, exp): self.exp = exp
+ def regist(self, exp):
+ """Register the hook with an experiment.
+
+ Args:
+ exp: The experiment to register the hook with.
+
+ """
+ self.exp = exp
+
+ def on_start(self, exp, *args, **kwargs):
+ """Execute when the experiment starts.
- def on_start(self, exp, *args, **kwargs): pass
+ Args:
+ exp: The experiment that has started.
+ *args: Any additional arguments passed to the method.
+ **kwargs: Any additional keyword arguments passed to the method.
- def on_end(self, exp, end_code=0, *args, **kwargs): pass
+ """
- def on_progress(self, exp, step, *args, **kwargs): pass
+ def on_end(self, exp, end_code=0, *args, **kwargs):
+ """Execute when the experiment ends.
- def on_newpath(self, exp, *args, **kwargs): pass
+ Args:
+ exp: The experiment that has ended.
+ end_code (int): The exit code for the experiment.
+ *args: Any additional arguments passed to the method.
+ **kwargs: Any additional keyword arguments passed to the method.
+
+ """
+
+ def on_progress(self, exp, step, *args, **kwargs):
+ """Execute when the experiment makes progress.
+
+ Args:
+ exp: The experiment that is making progress.
+ step: The current step of the experiment.
+ *args: Any additional arguments passed to the method.
+ **kwargs: Any additional keyword arguments passed to the method.
+
+ """
+
+ def on_newpath(self, exp, *args, **kwargs):
+ """Execute when the experiment creates a new path.
+
+ Args:
+ exp: The experiment that is creating a new path.
+ *args: Any additional arguments passed to the method.
+ **kwargs: Any additional keyword arguments passed to the method.
+
+ """
def __repr__(self):
+ """Return a string representation of the hook.
+
+ Returns:
+ A string representation of the hook.
+
+ """
return f"Hook(name={self.__class__.name}, switch={self.config_name}, {self.config_string})"
diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py
index 6738f690..358e9c3a 100644
--- a/src/lumo/exp/experiment.py
+++ b/src/lumo/exp/experiment.py
@@ -1,34 +1,51 @@
+"""
+Experiment is responsible for
+ - path management (PathHelper)
+ - recording information (InfoIO) and metrics (Metric)
+ - snapshots (snap) and reruns (rerun)
+"""
import os
import random
import sys
import time
import traceback
-from pathlib import Path
-from typing import Union
-
+from typing import Any, List
+from functools import wraps
from lumo.decorators.process import call_on_main_process_wrap
from lumo.proc import glob
from lumo.proc.dist import is_dist, is_main, local_rank
-from lumo.proc.path import blobroot, libhome, progressroot
+from lumo.proc.path import blobroot, cache_dir, libhome
from lumo.proc.path import exproot, local_dir
from lumo.utils import safe_io as io
-from lumo.utils.fmt import can_be_filename
+from lumo.utils.fmt import can_be_filename, strftime
from lumo.utils.logger import Logger
-from .base import ExpHook
+from .base import BaseExpHook
from ..proc.pid import pid_hash, runtime_pid_obj
-
-
-def checkdir(path: Union[Path, str]):
- if isinstance(path, str):
- os.makedirs(path, exist_ok=True)
- elif isinstance(path, Path):
- path.mkdir(parents=True, exist_ok=True)
- return path
+from .metric import Metric
class Experiment:
"""
- (by default), the directory structure is as following:
+    Represents an experiment and manages its directory structure. An experiment consists of multiple tests,
+    each of which has its own directory to store information related to that test.
+
+    - {cache_root}
+        - progress
+            - {exp-name}
+                - {test-1}.hb
+                - {test-1}.pid
+
+    - {info_root}
+        - {exp-name}
+            - {test-1} (info_dir)
+
+    - {blob_root}
+        - {exp-name}
+            - {test-1} (blob_dir)
+
+    By default, the directory structure is as follows:
.lumo (libroot)
- progress
- ".{pid}" -> hash
@@ -37,12 +54,19 @@ class Experiment:
- experiments # (exp_root) record information (e.g., .log, params files, etc.)
- {experiment-name-1}
- {test-1}
- # infomation
- {
- progress
- pid_hash (for lumo.client monitor)
- other_info: git, file, version_lock, etc.
- }
+            metric_board.sqlite (metrics recorded during training, powered by dbrecord)
+            metric.pkl (final metrics)
+            params.yaml (hyperparameters)
+            note.md (manual notes)
+ l.0.2303062216.log (log file)
+ text/
+ exception.str
+ ...
+ info/
+ *.json
+ git.json
+ execute.json
+ lock.json
- {test-2}
- {experiment-name-2}
- {test-1}
@@ -54,44 +78,139 @@ class Experiment:
- {experiment-name-2}
- {test-1}
- {test-2}
- - metric # (metric_root) record metrics (by trainer.database)
- - {experiment-name-1}
- - {test-1}
- - {test-2}
- - {experiment-name-2}
- - {test-1}
- - {test-2}
+ {lumo.cache_dir}
+ - progress # (metric_root) record metrics (by trainer.database)
+ - hb # trigger for test information update
+ {experiment-name}
+ - {test-1} -> timestamp
+ - {test-2} -> timestamp
+ - pid # link to running process
+ - {pid1} -> test_root
+ - {pid2} -> test_root
"""
- def __init__(self, exp_name: str, root=None):
+ ENV_TEST_NAME_KEY = 'LUMO_EXP_TEST_NAME'
+
+ def __init__(self, exp_name: str, test_name=None, paths=None):
+ """
+ Initializes a new instance of the Experiment class.
+
+        Args:
+            exp_name (str): The name of the experiment. This should be a legal filename and contain only letters or
+                underscores.
+            test_name (str, optional): The name of the test. If None, it is read from the LUMO_EXP_TEST_NAME
+                environment variable, or generated lazily on first access.
+            paths (dict, optional): Overrides for the experiment root paths (info_root, cache_root, blob_root).
+
+        Raises:
+            ValueError: If the experiment name is not a legal filename.
+ """
if not can_be_filename(exp_name):
            raise ValueError(f'Experiment name should be a legal filename (better to only contain letters or underscores),'
f'but got {exp_name}.')
self._prop = {}
self._prop['exp_name'] = exp_name
+ if test_name is None:
+ test_name = os.environ.get(Experiment.ENV_TEST_NAME_KEY, None)
+ self._prop['test_name'] = test_name
+ if paths is None:
+ paths = {}
+ self._prop['paths'] = paths
+
self._hooks = {}
- if root is None:
- root = libhome()
- self._root = Path(os.path.abspath(root))
+ self._metric = None
+
+ # wrap
+ self.dump_string = self._trigger_change(self.dump_string)
+ self.dump_note = self._trigger_change(self.dump_note)
+ self.dump_info = self._trigger_change(self.dump_info)
+
self.add_exit_hook(self.end)
self.logger = Logger()
+ def __getitem__(self, item):
+ """
+ Gets a property of the experiment.
+
+ Args:
+ item (str): The name of the property to get.
+
+ Returns:
+ Any: The value of the property.
+ """
+ return self._prop[item]
+
+ def __setitem__(self, key, value):
+ """
+ Sets a property of the experiment.
+
+ Args:
+ key (str): The name of the property to set.
+ value (Any): The value to set the property to.
+ """
+ self._prop[key] = value
+
+ def __enter__(self):
+ """
+ Starts the experiment when the Experiment object is used as a context manager using the 'with' statement.
+
+ Returns:
+ Experiment: The Experiment object.
+ """
+ self.start()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """
+ Ends the experiment when the 'with' statement block exits.
+
+ Args:
+ exc_type (type): The type of the exception that occurred, if any.
+ exc_val (Exception): The exception object that was raised, if any.
+ exc_tb (traceback): The traceback object for the exception, if any.
+ """
+ extra = {}
+ if exc_type is not None:
+            # stringify the summary separately so format_exception still receives the exception class
+            exc_msg = traceback.format_exception_only(exc_type, exc_val)[-1].strip()
+            extra['exc_type'] = exc_msg
+            extra['end_info'] = exc_msg
+            extra['end_code'] = 1
+            extra['exc_stack'] = "".join(traceback.format_exception(exc_type, exc_val, exc_tb))
+ self.end(**extra)
+
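+    # Context-manager sketch (experiment name is illustrative):
+    #
+    #   with Experiment('mnist_baseline') as exp:
+    #       exp.dump_info('params', {'lr': 0.1})
+    #   # on exit, end() records end_code and any exception info
+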
+ def __repr__(self):
+ """
+ Returns a string representation of the Experiment object.
+
+ Returns:
+ str: A string representation of the Experiment object.
+ """
+ return f'{self.exp_name}->({self.test_name})'
+
+ def __str__(self):
+ """
+ Returns a string representation of the Experiment object.
+
+ Returns:
+ str: A string representation of the Experiment object.
+ """
+ return self.__repr__()
+
+ def _repr_html_(self):
+ """Return a html representation for a particular DataFrame."""
+ return self.__repr__()
+
@property
def exp_name(self):
+ """
+ str: Gets the name of the experiment.
+ """
return self._prop['exp_name']
- @property
- def _test_name(self):
- return self._prop.get('test_name', None)
-
- @_test_name.setter
- def _test_name(self, value):
- self._prop['test_name'] = value
-
@property
def test_name_with_dist(self):
"""
+ str: Gets the name of the current test with the local rank number
+ appended to it if running in distributed mode.
+
Create different test_name for each process in multiprocess training.
Returns: in main process, will just return test_name itself,
in subprocess, "{test_name}.{local_rank()}"
@@ -107,214 +226,406 @@ def test_name_with_dist(self):
@property
def test_name(self):
- """Assign unique space(directory) for this test"""
+ """
+ str: Gets the name of the current test being run.
+
+ If the test name is not set, generates a new unique name and sets it.
+ """
if self._test_name is None:
if is_dist(): # if train in distribute mode, subprocess will wait a few seconds to wait main process.
- flag_fn = f'.{os.getppid()}'
+ flag_fn = os.path.join(self.cache_root, 'dist', f'.{os.getppid()}')
+ os.makedirs(os.path.dirname(flag_fn), exist_ok=True)
+
if is_main():
- self._test_name = self._create_test_name()
- fn = self.exp_file(flag_fn)
- with open(fn, 'w') as w:
+ self._test_name = self._create_test_name(self.exp_dir)
+
+ with open(flag_fn, 'w') as w:
w.write(self._test_name)
else:
time.sleep(random.randint(2, 4))
- fn = self.exp_file(flag_fn)
- if os.path.exists(fn):
- with open(fn, 'r') as r:
+ if os.path.exists(flag_fn):
+ with open(flag_fn, 'r') as r:
self._test_name = r.readline().strip()
else:
- self._test_name = self._create_test_name()
+ self._test_name = self._create_test_name(self.exp_dir)
return self._test_name
- def _create_test_name(self):
+ @property
+ def _test_name(self):
"""
- [0-9]{6}.[0-9]{3}.[a-z0-9]{3}t
+ str: Gets the name of the current test being run.
"""
- from lumo.proc.date import timehash
- from ..utils.fmt import strftime
- fs = os.listdir(self.exp_root)
- date_str = strftime('%y%m%d')
- fs = [i for i in fs if i.startswith(date_str)]
- _test_name = f"{date_str}.{len(fs):03d}.{timehash()[-6:-4]}t"
- return _test_name
+ return self._prop.get('test_name')
+
+ @_test_name.setter
+ def _test_name(self, value):
+ """
+ Sets the name of the current test being run.
+
+ Args:
+ value (str): The name of the current test.
+ """
+ self._prop['test_name'] = value
@property
- def root_branch(self):
- val = self._root
- return checkdir(val)
+ def repo_name(self):
+ """
+ Gets the name of the repository associated with the experiment.
+
+ Returns:
+ str: The name of the repository.
+ """
+ return self.project_name
@property
- def lib_root(self):
- return self.root_branch.as_posix()
+ def project_name(self):
+ """
+ Gets the name of the project associated with the experiment.
+
+ Returns:
+ str: The name of the project.
+ """
+ return os.path.basename(self.project_root)
@property
- def exp_branch(self):
- val = Path(exproot()).joinpath(self.exp_name)
- return checkdir(val)
+ def project_root(self):
+ """
+ Gets the path to the root directory of the project associated with the experiment.
+
+ Returns:
+ str: The path to the root directory of the project.
+ """
+ return local_dir()
@property
- def blob_branch(self):
- val = Path(blobroot()).joinpath(self.exp_name, self.test_name)
- return checkdir(val)
+ def properties(self):
+ """
+ Gets a dictionary containing all properties of the experiment.
+
+ Returns:
+ dict: A dictionary containing all properties of the experiment.
+ """
+ return self._prop
@property
- def progress_branch(self):
- val = Path(progressroot())
- return checkdir(val)
+ def metric(self):
+ """
+ Gets a dictionary containing all metrics of the experiment.
+
+ Returns:
+ Metric: A dictionary containing all metrics of the experiment.
+ """
+ if self._metric is None:
+ self._metric = Metric(self.mk_ipath('metric.pkl'))
+ self._metric.dump_metrics = self._trigger_change(self._metric.dump_metrics)
+ self._metric.dump_metric = self._trigger_change(self._metric.dump_metric)
+ return self._metric
@property
- def test_branch(self):
- val = self.exp_branch.joinpath(self.test_name)
- return checkdir(val)
+ def paths(self) -> dict:
+ """
+ Gets a dictionary containing the paths to various directories associated with the experiment.
- def dump_progress(self, ratio: float, update_from=None):
- res = {'ratio': ratio}
- if update_from is None:
- res['update_from'] = update_from
- self.dump_info('progress', res, append=True)
+ Returns:
+ dict: A dictionary containing the paths to various directories associated with the experiment.
+ """
+ return {
+ 'info_root': self._prop['paths'].get('info_root', exproot()),
+ 'cache_root': self._prop['paths'].get('cache_root', cache_dir()),
+ 'blob_root': self._prop['paths'].get('blob_root', blobroot()),
+ }
- def dump_info(self, key: str, info: dict, append=False, info_dir='info', set_prop=True):
- fn = self.test_file(f'{key}.json', info_dir)
- if append:
- old_info = self.load_info(key, info_dir=info_dir)
- old_info.update(info)
- info = old_info
- if set_prop:
- self.set_prop(key, info)
- io.dump_json(info, fn)
+ @property
+ def is_alive(self):
+ """
+ Determines whether the process associated with the experiment is still running.
- def load_info(self, key: str, info_dir='info'):
- fn = self.test_file(f'{key}.json', info_dir)
- if not os.path.exists(fn):
- return {}
- return io.load_json(fn)
+ Returns:
+ bool: True if the process is still running, False otherwise.
+ """
+        pinfo = self.properties.get('pinfo')
+        if pinfo is None:
+            return False
- def dump_string(self, key: str, info: str):
- fn = self.test_file(f'{key}.str', 'text')
- io.dump_text(info, fn)
- self.set_prop(key, info)
+ hash_obj = runtime_pid_obj(pinfo['pid'])
+ if hash_obj is None:
+ return False
- def load_string(self, key: str):
- fn = self.test_file(f'{key}.str', 'text')
- if not os.path.exists(fn):
- return ''
- return io.load_text(fn)
+ return pid_hash(hash_obj) == pinfo['hash']
@property
- def tags(self):
- tags = {}
- for path in self.test_branch.joinpath('tags').glob('tag.*.json'):
- ptags = io.load_json(path.as_posix()) # type: dict
- tags.setdefault(path.suffixes[0].strip('.'), []).extend(ptags.keys())
- return tags
+ def exec_argv(self):
+ """
+ Gets the arguments used to execute the script associated with the experiment.
+
+ Returns:
+ List[str]: A list of arguments used to execute the script.
+ """
+ execute_info = self.properties.get('execute')
+ try:
+ return [os.path.basename(execute_info['exec_bin']), *execute_info['exec_argv']]
+        except (TypeError, KeyError):
+ return []
+
+ def _trigger_change(self, func):
+        # touch the heartbeat file whenever test information changes
+        @wraps(func)
+        def inner(*args, **kwargs):
+            fn = self.heartbeat_fn
+            io.dump_text(self.info_dir, fn)
+            return func(*args, **kwargs)
+
+ return inner
+
+ @classmethod
+ def _create_test_name(cls, exp_dir):
+ """
+ Generates a unique test name based on the current date and time.
+        regex pattern: [0-9]{6}.[0-9]{3}.[a-z0-9]{2}t
+
+ Returns:
+ str: The generated test name.
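+            e.g. '230306.002.4ct' on 2023-03-06 when two tests already exist
+            for that day (the two hash characters vary per run).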
+ """
+ from lumo.proc.date import timehash
+ from ..utils.fmt import strftime
+ fs = os.listdir(exp_dir)
+ date_str = strftime('%y%m%d')
+ fs = [i for i in fs if i.startswith(date_str)]
+ _test_name = f"{date_str}.{len(fs):03d}.{timehash()[-6:-4]}t"
+ return _test_name
+
+ def get_prop(self, key, default=None):
+ """
+ Gets the value of a property of the experiment.
+
+ Args:
+ key (str): The name of the property to get.
+ default (Any, optional): The default value to return if the property does not exist. Defaults to None.
- def add_tag(self, tag: str, name_space: str = 'default'):
- self.dump_info(f'tag.{name_space}', {
- tag: None
- }, append=True, info_dir='tags', set_prop=False)
+ Returns:
+ Any: The value of the property, or the default value if the property does not exist.
+ """
+ return self._prop.get(key, default)
- def exp_file(self, filename, *args):
+ def has_prop(self, key):
"""
+ Determines whether the experiment has a certain property.
Args:
- filename:
- *args:
- mkdir:
+ key (str): The name of the property to check for.
Returns:
+ bool: True if the experiment has the property, False otherwise.
+ """
+ return key in self._prop
+ def set_prop(self, key, value):
"""
- parent = self.exp_branch.joinpath(*args)
- return checkdir(parent).joinpath(filename).as_posix()
+ Sets a property of the experiment.
- def test_file(self, filename, *args):
- parent = self.test_branch.joinpath(*args)
- return checkdir(parent).joinpath(filename).as_posix()
+ Args:
+ key (str): The name of the property to set.
+ value (Any): The value to set the property to.
+ """
+ self._prop[key] = value
- def exp_dir(self, *args):
+ def dump_progress(self, ratio: float, update_from=None):
"""
+ Saves progress information about the experiment.
Args:
- filename:
- *args:
- mkdir:
+ ratio (float): The progress ratio as a number between 0 and 1.
+ update_from: The process from which the progress update came from.
+ """
+ res = {'ratio': max(min(ratio, 1), 0)}
+        if update_from is not None:
+ res['update_from'] = update_from
+ res['last_edit_time'] = strftime()
+ self.dump_info('progress', res, append=True)
- Returns:
+ def dump_info(self, key: str, info: Any, append=False):
+ """
+ Saves information about the experiment to a file.
+ Args:
+ key (str): The key under which the information will be stored.
+ info (Any): The information to store.
+ append (bool, optional): Whether to append to the file or overwrite it. Defaults to False.
"""
- parent = self.exp_branch.joinpath(*args)
- return checkdir(parent).as_posix()
+ fn = self.mk_ipath('info', f'{key}.json')
+ if append:
+ old_info = self.load_info(key)
+ old_info.update(info)
+ info = old_info
- def root_file(self, filename, *args):
- parent = self.root_branch.joinpath(*args)
- return checkdir(parent).joinpath(filename).as_posix()
+ self.set_prop(key, info)
+ io.dump_json(info, fn)
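+
+        # e.g. exp.dump_info('dataset', {'name': 'cifar10'}) writes
+        # <info_dir>/info/dataset.json and mirrors it into exp.properties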
- def root_dir(self, *args):
+ def load_info(self, key: str):
"""
+ Loads information about the experiment from a file.
Args:
- filename:
- *args:
- mkdir:
+ key (str): The key under which the information is stored.
Returns:
-
+ Any: The information stored under the specified key.
"""
- parent = self.root_branch.joinpath(*args)
- return checkdir(parent).as_posix()
+ fn = self.mk_ipath('info', f'{key}.json')
+ if not os.path.exists(fn):
+ return {}
+ try:
+ return io.load_json(fn)
+        except ValueError:
+ return {}
- def test_dir(self, *args):
- parent = self.test_branch.joinpath(*args)
- return checkdir(parent).as_posix()
+    def load_note(self):
+        """Loads the manual note (note.md) of this test, returning '' if absent."""
+ fn = self.mk_ipath('note.md')
+ if os.path.exists(fn):
+ return io.load_text(fn)
+ return ''
- def blob_file(self, filename, *args):
- parent = self.blob_branch.joinpath(*args)
- return checkdir(parent).joinpath(filename).as_posix()
+    def dump_tags(self, *tags):
+        """Saves the given tags for this test (stored under the 'tags' info key)."""
+ self.dump_info('tags', tags)
- def progress_file(self, filename):
- return self.progress_branch.joinpath(filename).as_posix()
+    def dump_note(self, note: str):
+        """Saves the manual note of this test to note.md."""
+ fn = self.mk_ipath('note.md')
+ self.set_prop('note', note)
+ io.dump_text(note, fn)
- def blob_dir(self, *args):
+ def dump_string(self, key: str, info: str, append=False):
"""
+ Saves a string to a file.
Args:
- filename:
- *args:
- mkdir:
+ key (str): The key under which the string will be stored.
+            info (str): The string to store.
+            append (bool, optional): Whether to append to the file instead of overwriting it. Defaults to False.
+ """
+ fn = self.mk_ipath('text', f'{key}.str')
+ io.dump_text(info, fn, append=append)
+ if not append:
+ self.set_prop(key, info)
- Returns:
+ def load_string(self, key: str):
+ """
+ Loads a string from a file.
+ Args:
+ key (str): The key under which the string is stored.
+
+ Returns:
+ str: The string stored under the specified key.
"""
- parent = self.blob_branch.joinpath(*args)
- return checkdir(parent).as_posix()
+ fn = self.mk_ipath('text', f'{key}.str')
+ if not os.path.exists(fn):
+ return ''
+ return io.load_text(fn)
- def __enter__(self):
- self.start()
- return self
+    def dump_metric(self, key, value, cmp: str, flush=True, **kwargs):
+        """Records a metric, keeping the better value under `cmp` ('max' or 'min'). See `Metric.dump_metric`."""
+ return self.metric.dump_metric(key, value, cmp, flush, **kwargs)
- def __exit__(self, exc_type, exc_val, exc_tb):
- extra = {}
- if exc_type is not None:
- exc_type = traceback.format_exception_only(exc_type, exc_val)[-1].strip()
- extra['exc_type'] = exc_type
- extra['end_info'] = str(exc_type)
- extra['end_code'] = 1
- extra['exc_stack'] = "".join(traceback.format_exception(exc_type, exc_val, exc_tb))
- self.end(**extra)
+    def dump_metrics(self, dic: dict, cmp: str):
+        """Records multiple metrics at once. See `Metric.dump_metrics`."""
+ return self.metric.dump_metrics(dic, cmp)
- @call_on_main_process_wrap
- def add_exit_hook(self, func):
- import atexit
- def exp_func():
- func(self)
+ @property
+ def info_root(self):
+ return self.paths['info_root']
- atexit.register(exp_func)
+ @property
+ def cache_root(self):
+ return self.paths['cache_root']
+
+ @property
+ def blob_root(self):
+ return self.paths['blob_root']
+
+ @property
+ def pid_fn(self):
+ fn = os.path.join(self.cache_root, 'pid', self.exp_name, f'{self.test_name}.pid')
+ os.makedirs(os.path.dirname(fn), exist_ok=True)
+ return fn
+
+ @property
+ def heartbeat_fn(self):
+ fn = os.path.join(self.cache_root, 'heartbeat', self.exp_name, f'{self.test_name}.hb')
+ os.makedirs(os.path.dirname(fn), exist_ok=True)
+ return fn
+
+ @property
+ def exp_dir(self):
+ d = os.path.join(self.info_root, self.exp_name)
+ os.makedirs(d, exist_ok=True)
+ return d
+
+ @property
+ def info_dir(self):
+ d = os.path.join(self.info_root, self.exp_name, self.test_name)
+ os.makedirs(d, exist_ok=True)
+ return d
+
+ @property
+ def cache_dir(self):
+ d = os.path.join(self.cache_root, self.exp_name, self.test_name)
+ os.makedirs(d, exist_ok=True)
+ return d
+
+ @property
+ def blob_dir(self):
+ d = os.path.join(self.blob_root, self.exp_name, self.test_name)
+ os.makedirs(d, exist_ok=True)
+ return d
+
+ def _mk_path(self, *path: str, is_dir) -> str:
+ path = os.path.join(*path)
+ if is_dir:
+ os.makedirs(path, exist_ok=True)
+ else:
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ return path
+
+ def mk_ipath(self, *path, is_dir=False):
+ return self._mk_path(self.info_dir, *path, is_dir=is_dir)
+
+ def mk_cpath(self, *path, is_dir=False):
+ return self._mk_path(self.cache_dir, *path, is_dir=is_dir)
+
+ def mk_bpath(self, *path, is_dir=False):
+ return self._mk_path(self.blob_dir, *path, is_dir=is_dir)
+
+ def mk_rpath(self, *path, is_dir=False):
+ return self._mk_path(libhome(), *path, is_dir=is_dir)
+
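+    # Path-helper sketch: exp.mk_bpath('checkpoints', 'best.ckpt') returns
+    # <blob_root>/<exp_name>/<test_name>/checkpoints/best.ckpt and creates the
+    # parent directory; pass is_dir=True when the last segment is a directory.
+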
+ @classmethod
+ @property
+ def Class(cls):
+ return cls
+
+ def rerun(self, arg_list: List[str]):
+ """rerun this test in another"""
+ # self.properties['']
+ new_test_name = self._create_test_name(self.exp_dir)
+ new_exp = Experiment(self.exp_name, test_name=new_test_name)
+ self.dump_info('deprecated', {'rerun_at': {new_exp.test_name: True}}, append=True)
+ old_rerun_info = self.properties.get('rerun', None)
+ count = 1
+ if old_rerun_info is not None:
+            count += old_rerun_info['repeat']
+ new_exp.dump_info('rerun', {'from': self.test_name, 'repeat': count})
+ from lumo.utils.subprocess import run_command
+ old_exec = self.properties['execute']
+ command = ' '.join([old_exec['exec_bin'], old_exec['exec_file'], *old_exec['exec_argv'], *arg_list])
+ env = os.environ.copy()
+ env[Experiment.ENV_TEST_NAME_KEY] = new_exp.test_name
+ return run_command(command, cwd=old_exec['cwd'], env=env)
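+
+    # Sketch: exp.rerun(['--lr=0.01']) re-executes the recorded command with the
+    # extra flag; the new test name is handed over via LUMO_EXP_TEST_NAME.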
@call_on_main_process_wrap
def initial(self):
- self.add_tag(self.__class__.__name__, 'exp_type')
- self.dump_progress(0)
+ """
+ Initializes the experiment by setting up progress, information, and PID tracking.
+ """
+ self.dump_info('exp_name', self.exp_name)
+ self.dump_info('test_name', self.test_name)
+ self.dump_info('paths', self.paths)
+
self.dump_info('execute', {
'repo': self.project_root,
'cwd': os.getcwd(),
@@ -328,144 +639,159 @@ def initial(self):
'obj': runtime_pid_obj(),
})
+ # register start
+ self.dump_info('progress', {'start': strftime(), 'finished': False}, append=True)
+ self.dump_progress(0)
# register progress
- io.dump_text(self.test_root, self.progress_file(f'{os.getpid()}'))
+ io.dump_text(self.info_dir, self.pid_fn)
@call_on_main_process_wrap
def start(self):
- if self.get_prop('start', False):
+ """
+ Starts the experiment.
+ """
+ if self.properties.get('progress', None) is not None:
return
self.initial()
self.set_prop('start', True)
- for hook in self._hooks.values(): # type: ExpHook
+ for hook in self._hooks.values(): # type: BaseExpHook
hook.on_start(self)
return self
@call_on_main_process_wrap
def end(self, end_code=0, *args, **extra):
- if not self.get_prop('start', False):
+ """
+ Ends the experiment.
+
+ Args:
+ end_code (int): The exit code to set for the experiment.
+ *args: Additional arguments to pass to the end hooks.
+ **extra: Additional keyword arguments to pass to the end hooks.
+ """
+        if self.properties.get('progress', None) is None:
+            return
+        if not self.is_alive:
return
- if self.get_prop('end', False):
+ if self.properties['progress'].get('end', False):
return
- self.dump_progress(1)
+
self.set_prop('end', True)
- for hook in self._hooks.values(): # type: ExpHook
+ if end_code == 0:
+ self.dump_progress(1)
+
+ self.dump_info('progress', {'end': strftime(), 'finished': end_code == 0}, append=True)
+ for hook in self._hooks.values(): # type: BaseExpHook
hook.on_end(self, end_code=end_code, *args, **extra)
return self
- @property
- def repo_name(self):
- """repository name"""
- return self.project_name
-
- @property
- def project_name(self):
- """same as repository name, directory name of project root"""
- return os.path.basename(self.project_root)
-
- @property
- def project_root(self):
- return local_dir()
-
- @property
- def exp_root(self):
- """path to multiple tests of this experiment"""
- return self.exp_branch.as_posix()
-
- @property
- def test_root(self):
- """path to record information of one experiment"""
- return self.test_branch.as_posix()
+ @call_on_main_process_wrap
+ def set_hook(self, hook: BaseExpHook):
+ """
+ Registers a hook to be executed during the experiment.
- @property
- def blob_root(self):
- """path to storing big binary files"""
- return self.blob_branch.as_posix()
+ Args:
+ hook (BaseExpHook): The hook to register.
+ """
+ if not glob.get(hook.config_name, True):
+ self.dump_info('hooks', {
+ hook.__class__.__name__: {'loaded': False, 'msg': 'disabled by config'}
+ }, append=True)
+ return self
+ else:
+ hook.regist(self)
+ self.dump_info('hooks', {
+ hook.__class__.__name__: {'loaded': True, 'msg': ''}
+ }, append=True)
+ self.logger.info(f'Register {hook}.')
+ self._hooks[hook.__class__.__name__] = hook
+ return self
- def __getitem__(self, item):
- return self._prop[item]
+ @call_on_main_process_wrap
+ def add_exit_hook(self, func):
+ """
+ Registers a function to be called when the program exits.
- def __setitem__(self, key, value):
- self._prop[key] = value
+ Args:
+ func (callable): The function to register.
+ """
+ import atexit
+ def exp_func():
+ """Function executed before process exit."""
+ func(self)
- def get_prop(self, key, default=None):
- return self._prop.get(key, default)
+ atexit.register(exp_func)
- def has_prop(self, key):
- return key in self._prop
+ @classmethod
+ def from_cache(cls, dic: dict):
+        """Creates an Experiment from a cached property dict (as produced by `cache()`)."""
+        paths = dic.pop('paths', {})
+        _ = dic.pop('metrics', None)
+ self = cls(exp_name=dic['exp_name'], test_name=dic['test_name'], paths=paths)
+ self._prop.update(dic)
+ return self
- def set_prop(self, key, value):
- self._prop[key] = value
+ @classmethod
+ def from_disk(cls, path):
+ """
+ Creates an Experiment object from a test root directory on disk.
- @property
- def properties(self):
- return self._prop
+ Args:
+ path (str): The path to the test root directory.
- @property
- def paths(self) -> dict:
- return {
- 'root': self.root_branch.as_posix(),
- 'exp_root': self.exp_root,
- 'test_root': self.test_root,
- 'blob_root': self.blob_root,
- }
+ Returns:
+ Experiment: An Experiment object created from the test root directory.
- @property
- def enable_properties(self) -> set:
- return set(self._prop.keys())
+ Raises:
+ ValueError: If the path is not a valid test root directory.
+ """
+ from .finder import is_test_root
+ if not is_test_root(path):
+ raise ValueError(f'{path} is not a valid test_root')
+ path = os.path.abspath(path)
+ exp_dir = os.path.dirname(path)
+
+        paths_fn = os.path.join(path, 'info', 'paths.json')
+        if os.path.exists(paths_fn):
+            try:
+                paths = io.load_json(paths_fn)
+            except ValueError:
+ paths = {}
+ else:
+ paths = {}
- @call_on_main_process_wrap
- def set_hook(self, hook: ExpHook):
- hook.regist(self)
- if not glob.get(hook.config_name, True):
- self.dump_info(hook.name, {
- 'code': -1,
- 'msg': f'{hook.name} disabled'
- })
- return self
- self.logger.info(f'Register {hook}.')
- self._hooks[hook.__class__.__name__] = hook
- self.add_tag(hook.__class__.__name__, 'hooks')
- return self
+ self = cls(os.path.basename(exp_dir), test_name=os.path.basename(path), paths=paths)
- def load_prop(self):
- for f in os.listdir(self.test_dir('info')):
+ # load prop
+ for f in os.listdir(self.mk_ipath('info', is_dir=True)):
key = os.path.splitext(f)[0]
self.set_prop(key, self.load_info(key))
- for f in os.listdir(self.test_dir('text')):
+ for f in os.listdir(self.mk_ipath('text', is_dir=True)):
key = os.path.splitext(f)[0]
self.set_prop(key, self.load_string(key))
- @classmethod
- def from_disk(cls, path):
- from .finder import is_test_root
- if not is_test_root(path):
- raise ValueError(f'{path} is not a valid test_root')
+ self.set_prop('note', self.load_note())
- test_root = Path(path)
- root = test_root.parent.parent.parent.as_posix()
- self = cls(test_root.parent.name, root=root)
- self._test_name = test_root.name
- self.load_prop()
return self
- @property
- def exec_argv(self):
- execute_info = self.get_prop('execute')
- try:
- return [os.path.basename(execute_info['exec_bin']), *execute_info['exec_argv']]
- except:
- return []
-
- def __repr__(self):
- return f'{self.exp_name}->({self.test_name})'
+    def cache(self):
+        """Returns a cacheable dict of the experiment's properties plus its metrics."""
+ return {
+ **self.properties,
+ 'metrics': self.metric.value,
+ }
- def __str__(self):
- return self.__repr__()
+    def dict(self):
+        """Returns a dict of the experiment's properties plus liveness state and metrics."""
+ return {
+ **self.properties,
+ 'is_alive': self.is_alive,
+ 'metrics': self.metric.value,
+ }
class SimpleExperiment(Experiment):
+ """
+    A simple-to-use Experiment subclass that sets up some useful hooks to
+    execute before and after the experiment.
+ """
-    def __init__(self, exp_name: str, root=None):
-        super().__init__(exp_name, root)
+    def __init__(self, exp_name: str, paths=None):
+        super().__init__(exp_name, paths=paths)
@@ -475,5 +801,5 @@ def __init__(self, exp_name: str, root=None):
self.set_hook(exphook.GitCommit())
self.set_hook(exphook.RecordAbort())
self.set_hook(exphook.Diary())
- self.set_hook(exphook.TimeMonitor())
+ # self.set_hook(exphook.TimeMonitor())
self.set_hook(exphook.FinalReport())
diff --git a/src/lumo/exp/exphook.py b/src/lumo/exp/exphook.py
index 41a21b54..68500fbb 100644
--- a/src/lumo/exp/exphook.py
+++ b/src/lumo/exp/exphook.py
@@ -15,10 +15,12 @@
from lumo.utils.exithook import wrap_before
from lumo.utils.fmt import strftime, indent_print
from . import Experiment
-from .base import ExpHook as BaseExpHook
+from .base import BaseExpHook
class ExpHook(BaseExpHook):
+ """A base class of hook for experiments that can be registered with an experiment."""
+
def regist(self, exp: Experiment):
self.exp = exp
@@ -32,6 +34,11 @@ def on_newpath(self, exp: Experiment, *args, **kwargs): pass
class LastCmd(ExpHook):
+ """A hook to save the last command executed in an experiment.
+
+ This hook saves the last command executed in an experiment to a shell script file in a specified directory. The saved
+ file can be used to re-run the experiment with the same command.
+ """
configs = {'HOOK_LASTCMD_DIR': os.getcwd()}
def on_start(self, exp: Experiment, *args, **kwargs):
@@ -55,20 +62,25 @@ def on_start(self, exp: Experiment, *args, **kwargs):
os.chmod(fn, st.st_mode | stat.S_IEXEC)
-class PathRecord(ExpHook):
-
- def on_newpath(self, exp: Experiment, *args, **kwargs):
- super().on_newpath(exp, *args, **kwargs)
+# class PathRecord(ExpHook):
+#
+# def on_newpath(self, exp: Experiment, *args, **kwargs):
+# super().on_newpath(exp, *args, **kwargs)
class Diary(ExpHook):
+ """A hook for logging experiment information to a diary file."""
+
def on_start(self, exp: Experiment, *args, **kwargs):
super().on_start(exp, *args, **kwargs)
- with open(exp.root_file(f'{strftime("%y%m%d")}.log', 'diary'), 'a') as w:
- w.write(f'{strftime("%H:%M:%S")}, {exp.test_root}\n')
+ # with open(exp.root_file(f'{strftime("%y%m%d")}.log', 'diary'), 'a') as w:
+ # w.write(f'{strftime("%H:%M:%S")}, {exp.test_root}\n')
class RecordAbort(ExpHook):
+ """A hook to record and handle experiment aborts.
+ """
+
def regist(self, exp: Experiment):
super().regist(exp)
wrap_before(self.exc_end)
@@ -84,29 +96,29 @@ def exc_end(self, exc_type, exc_val, exc_tb):
)
-class TimeMonitor(ExpHook):
- def _create_agent(self, exp: Experiment):
- from lumo.exp import agent
- cmd = [
- sys.executable, '-m', agent.__spec__.name,
- f"--state_key=state",
- f"--pid={os.getpid()}",
- f"--exp_name={exp.exp_name}",
- f"--test_name={exp.test_name}",
- f"--test_root={exp.test_root}",
- # f"--params={sys.argv}" # TODO add sys.argv
- ]
- subprocess.Popen(' '.join(cmd),
- stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True,
- start_new_session=True)
-
- def on_start(self, exp: Experiment, *args, **kwargs):
- super().on_start(exp)
- self._create_agent(exp)
- exp.dump_info('state', {
- 'start': strftime(),
- 'end': strftime()
- })
+# class TimeMonitor(ExpHook):
+# def _create_agent(self, exp: Experiment):
+# from lumo.exp import agent
+# cmd = [
+# sys.executable, '-m', agent.__spec__.name,
+# f"--state_key=state",
+# f"--pid={os.getpid()}",
+# f"--exp_name={exp.exp_name}",
+# f"--test_name={exp.test_name}",
+# f"--test_root={exp.test_root}",
+# # f"--params={sys.argv}"
+# ]
+# subprocess.Popen(' '.join(cmd),
+# stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True,
+# start_new_session=True, cwd=os.getcwd(), env=os.environ)
+#
+# def on_start(self, exp: Experiment, *args, **kwargs):
+# super().on_start(exp)
+# self._create_agent(exp)
+# exp.dump_info('state', {
+# 'start': strftime(),
+# 'end': strftime()
+# })
class GitCommit(ExpHook):
@@ -121,7 +133,6 @@ def on_start(self, exp: Experiment, *args, **kwargs):
from lumo.utils.repository import git_enable, git_commit, git_dir
from lumo.utils.ast import analyse_module_dependency
- from lumo.utils.hash import hash_iter
import inspect
if not git_enable():
@@ -149,8 +160,8 @@ def on_start(self, exp: Experiment, *args, **kwargs):
except OSError:
pass
- dep_hash = hash_iter(*dep_source)
- commit_ = git_commit(key='lumo', info=exp.test_root, filter_files=filter_files)
+ dep_hash = hash(dep_source)
+ commit_ = git_commit(key='lumo', info=exp.info_dir, filter_files=filter_files)
if commit_ is None:
exp.dump_info('git', {
@@ -165,32 +176,48 @@ def on_start(self, exp: Experiment, *args, **kwargs):
'repo': exp.project_root,
'dep_hash': dep_hash,
})
-
- file = exp.root_file(hash(exp.project_root), 'repos')
+ file = exp.mk_rpath('repos', hash(exp.project_root))
exps = {}
if os.path.exists(file):
exps = io.load_json(file)
res = exps.setdefault(exp.project_root, list())
- if exp.exp_root not in res:
- res.append(exp.exp_root)
+ if exp.exp_dir not in res:
+ res.append(exp.exp_dir)
io.dump_json(exps, file)
class LockFile(ExpHook):
+ """A class for locking dependencies for an experiment.
+ Locks the specified dependencies for the experiment and saves them to a file.
+ """
def on_start(self, exp: Experiment, *args, **kwargs):
- exp.dump_info('lock', get_lock('torch', 'numpy',
- 'joblib',
- 'psutil',
- 'decorator',
- 'torch',
- 'numpy',
- 'accelerate',
- 'hydra',
- 'omegaconf', ))
+ basic = get_lock('lumo',
+ 'joblib',
+ 'fire',
+ 'psutil',
+ 'accelerate',
+ 'hydra',
+ 'omegaconf',
+ 'decorator',
+
+ 'numpy',
+ 'torch',
+ )
+ if basic['torch'] is not None:
+ import torch
+ if torch.cuda.is_available():
+ basic['torch.version.cuda'] = torch.version.cuda
+
+ exp.dump_info('lock', basic)
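+        # resulting info sketch (version numbers illustrative):
+        #   {'lumo': '0.15.0', 'torch': '1.13.1', 'torch.version.cuda': '11.7', ...}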
class FinalReport(ExpHook):
+ """A class for generating a final report for an experiment.
+
+ Prints the experiment's properties, tags, paths, and execute command.
+ """
+
def on_end(self, exp: Experiment, end_code=0, *args, **kwargs):
# if end_code == 0:
print('-----------------------------------')
@@ -198,10 +225,6 @@ def on_end(self, exp: Experiment, end_code=0, *args, **kwargs):
print('Properties:')
indent_print(pformat(exp.properties))
- print('Tags:')
- indent_print(pformat(exp.tags))
- print('Use paths:')
- indent_print(pformat(exp.paths))
print('Execute:')
indent_print(' '.join(exp.exec_argv))
print('-----------------------------------')
diff --git a/src/lumo/exp/finder.py b/src/lumo/exp/finder.py
index 28c63b47..53989174 100644
--- a/src/lumo/exp/finder.py
+++ b/src/lumo/exp/finder.py
@@ -9,7 +9,7 @@
"""
from pprint import pformat
import os
-from typing import List, Dict
+from typing import List, Dict, Any
from lumo.proc.path import libhome, exproot, metricroot
from lumo.utils.fmt import indent_print
@@ -18,32 +18,86 @@
from . import Experiment
-def list_experiment_paths(exp_root=None):
+def list_experiment_paths(exp_root=None) -> List[str]:
+ """
+ Returns a list of experiment paths under exp_root directory.
+
+ Args:
+ exp_root: The root directory to search for experiments. Default is None, which uses the default experiment root directory.
+
+ Returns:
+ A list of experiment paths.
+ """
if exp_root is None:
exp_root = exproot()
return [os.path.join(exp_root, i) for i in os.listdir(exp_root)]
-def _get_exp_name(exp_path: str):
+def _get_exp_name(exp_path: str) -> str:
+ """
+ Returns the name of the experiment directory.
+
+ Args:
+ exp_path: The path to the experiment directory.
+
+ Returns:
+ The name of the experiment directory.
+ """
return os.path.basename(exp_path.rstrip('/'))
def list_all(exp_root=None) -> Dict[str, List[Experiment]]:
+ """
+ Returns a dictionary of all experiments under exp_root directory.
+
+ Args:
+ exp_root: The root directory to search for experiments. Default is None, which uses the default experiment root directory.
+
+ Returns:
+ A dictionary of all experiments, where the keys are the names of the experiments and the values are lists of corresponding Experiment objects.
+ """
return {
_get_exp_name(exp_path): retrieval_tests_from_experiment(exp_path)
for exp_path in list_experiment_paths(exp_root)
}
-def retrieval_tests_from_experiment(exp_path) -> List[Experiment]:
+def retrieval_tests_from_experiment(exp_path: str) -> List[Experiment]:
+ """
+ Returns a list of Experiment objects found in the specified experiment directory.
+
+ Args:
+ exp_path: The path to the experiment directory.
+
+ Returns:
+ A list of Experiment objects.
+ """
return [retrieval_experiment(os.path.join(exp_path, f)) for f in os.listdir(exp_path)]
-def list_test_names_from_experiment(experiment_name):
+def list_test_names_from_experiment(experiment_name: str) -> List[str]:
+ """
+ Returns a list of test names found in the specified experiment directory.
+
+ Args:
+ experiment_name: The name of the experiment directory.
+
+ Returns:
+ A list of test names.
+ """
return os.listdir(os.path.join(exproot(), experiment_name))
-def find_path_from_test_name(test_name: str):
+def find_path_from_test_name(test_name: str) -> str:
+ """
+ Returns the path of the specified test name.
+
+ Args:
+ test_name: The name of the test.
+
+ Returns:
+ The path of the test, or None if not found.
+ """
if not is_test_name(test_name):
return None
@@ -58,21 +112,42 @@ def find_path_from_test_name(test_name: str):
return None
-def is_test_name(test_name: str):
+def is_test_name(test_name: str) -> bool:
"""
- ^[0-9]{6}.[0-9]{3}.[a-z0-9]{2}t$
+ Determines if the specified string is a valid test name.
+
+ Args:
+ test_name: The string to check.
+
+ Returns:
+ True if the string is a valid test name, False otherwise.
"""
return re.search(r'^\d{6}\.\d{3}\.[a-z\d]{2}t$', test_name) is not None
-def is_test_root(path: str):
+def is_test_root(path: str) -> bool:
+ """
+ Determines if the specified path is a valid test root.
+
+ Args:
+ path: The path to check.
+
+ Returns:
+ True if the path is a valid test root, False otherwise.
+ """
test_name = os.path.basename(path.rstrip('/'))
return is_test_name(test_name)
-def retrieval_test_root(test_flag: str):
+def retrieval_test_root(test_flag: str) -> str:
"""
- test_flag can be a name like `230214.037.62t` or path like `path/to/230214.037.62t`
+ Returns the test root directory for the specified test name or test root.
+
+ Args:
+        test_flag: The test name (like `230214.037.62t`) or a test root path (like `path/to/230214.037.62t`).
+
+ Returns:
+ The test root directory, or None if not found.
"""
if is_test_name(test_flag):
test_root = find_path_from_test_name(test_flag)
@@ -86,7 +161,21 @@ def retrieval_test_root(test_flag: str):
return test_root
-def retrieval_experiment(test_name=None, test_root: str = None):
+def retrieval_experiment(test_name=None, test_root: str = None) -> Experiment:
+ """
+ Loads an Experiment object from disk for the given test name or test root.
+
+ Args:
+ test_name (str, optional): The name of the test to load. If not provided,
+ the test root directory must be provided instead. Defaults to None.
+ test_root (str, optional): The root directory of the test to load. If not
+ provided, the root directory is determined from the test name using
+ the retrieval_test_root function. Defaults to None.
+
+ Returns:
+ Optional[Experiment]: The loaded Experiment object, or None if the test
+ root cannot be determined or the Experiment cannot be loaded from disk.
+ """
if test_root is None:
test_root = retrieval_test_root(test_name)
if test_root is None:
@@ -96,6 +185,13 @@ def retrieval_experiment(test_name=None, test_root: str = None):
def summary_experiment(test_name: str = None, test_root: str = None):
+ """
+ Prints a summary of the experiment specified by test_name or test_root.
+
+ Args:
+ test_name: The name of the test.
+ test_root: The path to the test root directory.
+ """
if test_root is None:
if test_name is None:
raise ValueError()
@@ -117,7 +213,16 @@ def summary_experiment(test_name: str = None, test_root: str = None):
print('-----------------------------------')
-def format_experiment(exp: Experiment):
+def format_experiment(exp: Experiment) -> Dict[str, Any]:
+ """
+ Formats the Experiment object into a dictionary.
+
+ Args:
+ exp: An Experiment object.
+
+ Returns:
+ A dictionary of the Experiment properties, tags, paths, and execution arguments.
+ """
return {
'Properties': exp.properties,
-        'tags': exp.tags,
+        'tags': exp.properties.get('tags', []),
@@ -126,7 +231,16 @@ def format_experiment(exp: Experiment):
}
-def list_all_metrics(metric_root=None):
+def list_all_metrics(metric_root=None) -> Dict[str, List[str]]:
+ """
+ Returns a dictionary of all metrics found under metric_root directory.
+
+ Args:
+ metric_root: The root directory to search for metrics. Default is None, which uses the default metric root directory.
+
+ Returns:
+ A dictionary of all metrics, where the keys are the metric names and the values are lists of corresponding metric files.
+ """
if metric_root is None:
metric_root = metricroot()
diff --git a/src/lumo/exp/metric.py b/src/lumo/exp/metric.py
new file mode 100644
index 00000000..abf5f618
--- /dev/null
+++ b/src/lumo/exp/metric.py
@@ -0,0 +1,60 @@
+import os
+from lumo.utils import safe_io as IO
+
+
+class Metric:
+ """
+ """
+
+    def __init__(self, metric_fn, persistent=True):
+        """Loads existing metrics from `metric_fn` if present; `persistent=False` disables writing back."""
+ os.makedirs(os.path.dirname(os.path.abspath(metric_fn)), exist_ok=True)
+ self.fn = metric_fn
+ self._metric = {}
+ if os.path.exists(metric_fn):
+ self._metric = IO.load_pkl(metric_fn)
+ self.persistent = persistent
+
+ @property
+ def value(self):
+ """
+        A property that returns the stored metric values.
+
+        Returns:
+            dict: A dictionary containing the stored metric values.
+ """
+ return self._metric
+
+    def dump_metric(self, key, value, cmp: str, flush=True, **kwargs):
+        """Updates `key` with `value` when it beats the stored value under `cmp` ('max' or 'min');
+        extra kwargs are stored alongside on update."""
+ dic = self.value
+ older = dic.setdefault(key, None)
+
+ update = False
+ if older is None or cmp is None:
+ update = True
+ else:
+ if cmp == 'max':
+ if older < value:
+ update = True
+ elif cmp == 'min':
+ if older > value:
+ update = True
+ else:
+ raise NotImplementedError()
+
+ if update:
+ dic[key] = value
+ for kk, vv in kwargs.items():
+ dic[kk] = vv
+
+ if flush:
+ self.flush()
+ return value
+
+    def dump_metrics(self, dic: dict, cmp: str):
+        """Records every key-value pair in `dic` via `dump_metric` with the same `cmp`."""
+ for k, v in dic.items():
+ self.dump_metric(k, v, cmp)
+
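+    # Usage sketch (file path is illustrative):
+    #
+    #   m = Metric('./metric.pkl')
+    #   m.dump_metric('acc', 0.91, cmp='max')  # stored: no previous value
+    #   m.dump_metric('acc', 0.88, cmp='max')  # ignored: worse under 'max'
+    #   assert m.value['acc'] == 0.91
+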
+ def flush(self):
+ """Writes the value of the row to a file."""
+ if self.persistent:
+ IO.dump_pkl(self.value, self.fn)
diff --git a/src/lumo/exp/watch.py b/src/lumo/exp/watch.py
new file mode 100644
index 00000000..e2dc8675
--- /dev/null
+++ b/src/lumo/exp/watch.py
@@ -0,0 +1,536 @@
+"""
+Watcher 可以在运行实验后在 jupyter 或者网页上展示正在运行和已经运行结束的实验(按时间顺序?)
+以及可以简化记录实验的烦恼
+
+现在的核心痛点是
+ - [ ] 所有元信息都有了,但是找不到哪个实验是哪个实验
+ - [ ] 同时跑的多个实验有一个失败了,重跑时会混淆,或许需要一种覆盖手段 ->
+ - > 怎么 rerun?
+ lumo rerun test_name √
+ lumo note html (用 streamlit 之类的生成动态网页)
+ lumo note cmd (类似 top 的视角,按时间顺序排列)
+- > rerun 将已经跑的实验 move
+
+可以代替 analysis 的作用。主要有
+
+-> 按照 progress 目录,获取所有的实验
+-> 根据获取的实验,按顺序记录
+-> 每次只记录
+
+"""
+import numbers
+import os.path
+from typing import List, Dict, overload
+from pprint import pformat
+import pandas as pd
+from dbrecord import PDict
+from datetime import datetime
+from operator import gt, ge, le, lt, eq, ne
+
+from lumo.proc.path import progressroot, exproot, dbroot, cache_dir
+from .experiment import Experiment
+from lumo.utils import safe_io as IO
+from lumo.utils.fmt import format_timedelta, strptime, strftime
+
+PID_ROOT = os.path.join(progressroot(), 'pid')
+HB_ROOT = os.path.join(progressroot(), 'hb')
+EXP_ROOT = progressroot()
+
+styles = {
+ 'row-radio': """""",
+ 'widget-box': """
+
+
+ """
+}
+
+
+def in_(ser, value):
+ """pandas operation"""
+ return ser.apply(lambda x: x in value)
+
+
+def not_in_(ser, value):
+ """pandas operation"""
+ return ser.apply(lambda x: x not in value)
+
+
+# supported conditions
+mapping = {
+ '>=': ge,
+ '<=': le,
+ '==': eq,
+ '!=': ne,
+ '>': gt,
+ '<': lt,
+ 'in': in_,
+ 'notin': not_in_,
+}
+
+
+class Condition:
+    """A lazily built comparison (name, op, value) used to filter experiment DataFrames."""
+
+ def __init__(self, name: str = None, value=None, op=None):
+ self.name = name
+ self.value = value
+ self.op = op
+
+ def __getattr__(self, item):
+ return Condition(item)
+
+ def __getitem__(self, item):
+ return Condition(item)
+
+ def __neg__(self):
+ self.drop = True
+ return self
+
+ def __ge__(self, other):
+ if other is None:
+ raise AssertionError()
+ self.value = other
+ self.op = ">="
+ return self
+
+ def __le__(self, other):
+ if other is None:
+ raise AssertionError()
+ self.value = other
+ self.op = "<="
+ return self
+
+ def __eq__(self, other):
+ self.value = other
+ self.op = "=="
+ return self
+
+ def __ne__(self, other):
+ self.value = other
+ self.op = "!="
+ return self
+
+ def __gt__(self, other):
+ if other is None:
+ raise AssertionError()
+ self.value = other
+ self.op = ">"
+ return self
+
+ def __lt__(self, other):
+ assert other is not None
+ self.value = other
+ self.op = "<"
+ return self
+
+ def __repr__(self):
+ return f'C({self.name} {self.op} {self.value})'
+
+ def in_(self, lis):
+ """condition of `in` operation"""
+ self.op = 'in'
+ self.value = set(lis)
+ return self
+
+ def not_in_(self, lis):
+ """condition of `.duplicated(value) == False` operation"""
+ self.op = 'notin'
+ self.value = set(lis)
+ return self
+
+    def mask(self, df):
+        """Returns a boolean row mask over `df`; dotted names index into dict-valued columns."""
+        names = self.name.split('.')
+        value = df
+        for i in names:
+            if isinstance(value, pd.DataFrame):
+                value = value[i]
+            else:
+                value = value.apply(lambda x, key=i: x[key])
+ return mapping[self.op](value, self.value)
+
+    def apply(self, df):
+        """Returns the rows of `df` that satisfy this condition."""
+ return df[self.mask(df)]
+
+
+C = Condition()
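+
+# Filter sketch (column and key names are illustrative):
+#
+#   cond = C['metrics.acc'] >= 0.9       # dotted names reach into dict-valued columns
+#   good = cond.apply(Watcher().load())  # keeps rows whose metrics['acc'] >= 0.9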
+
+
+class Watcher:
+ """List and watch experiments with time order
+
+ Cache test_information in
+ metrics/.sqlite
+ """
+
+ def __init__(self, exp_root=None, hb_root=None, pid_root=None, db_root=None):
+ if exp_root is None:
+ exp_root = os.path.join(exproot(), 'hb')
+
+ if hb_root is None:
+ hb_root = os.path.join(cache_dir(), 'heartbeat')
+
+ if pid_root is None:
+ pid_root = os.path.join(cache_dir(), 'pid')
+
+ if db_root is None:
+ db_root = dbroot()
+ self.db_root = db_root
+ self.exp_root = exp_root
+ self.hb_root = hb_root
+ self.pid_root = pid_root
+
+    def load(self):
+        """Loads all experiments into a DataFrame, refreshing the sqlite cache from heartbeat files."""
+ res = {}
+ updates = {}
+ if not os.path.exists(self.hb_root):
+ return pd.DataFrame()
+ for root, dirs, fs in os.walk(self.hb_root):
+ if root == self.hb_root:
+ continue
+ for f in fs:
+ if f.endswith('hb'):
+ hb_file = os.path.join(root, f)
+ test_root = IO.load_text(hb_file)
+ try:
+ exp = Experiment.from_disk(test_root)
+ updates.setdefault(exp.exp_name, []).append(exp.cache())
+ except KeyboardInterrupt as e:
+ raise e
+ except:
+ continue
+
+ for exp_name, tests in updates.items():
+ dic = PDict(os.path.join(self.db_root, f'{exp_name}.sqlite'))
+ for test_name, test_prop in dic.items():
+ res[test_name] = test_prop
+
+ for test in tests:
+ dic[test['test_name']] = test
+ res[test['test_name']] = test
+ dic.flush()
+
+ df = pd.DataFrame(res.values())
+ df = df.sort_values(['exp_name', 'test_name'])
+ return df.reset_index(drop=True)
+
+ def progress(self, is_alive=True):
+ """return the alive process"""
+ res = []
+ for root, dirs, fs in os.walk(self.pid_root):
+ for f in fs:
+ if not f.endswith('.pid'):
+ continue
+ try:
+ test_root = IO.load_text(os.path.join(root, f))
+ exp = Experiment.from_disk(test_root)
+ if exp.is_alive == is_alive:
+ res.append(exp.dict())
+ except:
+ continue
+ return pd.DataFrame(res)
+
+ def interactive(self):
+ """interactive, mark, label, note in ipython environment."""
+ pass
+
+ def server(self):
+ """simple server which make you note your experiments"""
+ pass
+
+ def list_all(self, exp_root=None, limit=100) -> Dict[str, List[Experiment]]:
+ """
+ Returns a dictionary of all experiments under exp_root directory.
+
+ Args:
+ exp_root: The root directory to search for experiments. Default is None, which uses the default experiment root directory.
+
+ Returns:
+ A dictionary of all experiments, where the keys are the names of the experiments and the values are lists of corresponding Experiment objects.
+ """
+
+ def widget(self,
+ is_finished: bool = None,
+ is_alive: bool = None,
+ time_filter: list = None,
+ params_filter: list = None,
+ metric_filter: list = None
+ ):
+ assert params_filter is None or isinstance(params_filter, list)
+ assert metric_filter is None or isinstance(metric_filter, list)
+
+ from ipywidgets import widgets, interact, Label
+ from IPython.display import display
+
+ def make_row(dic: dict):
+ exp = Experiment.from_cache(dic.copy())
+
+ def on_note_update(sender):
+ exp.dump_note(sender['new'])
+
+ def on_tag_update(sender):
+ exp.dump_tags(*sender['new'])
+
+ note_ui = widgets.Textarea(dic['note'])
+
+ note_ui.continuous_update = False
+ note_ui.observe(on_note_update, names='value', type='change')
+
+ tags = dic.get('tags', [])
+ try:
+ tags = list(tags)
+ except:
+ tags = []
+ tag_ui = widgets.TagsInput(value=tags)
+ tag_ui.observe(on_tag_update, names='value', type='change')
+
+ now = datetime.now()
+ start = strptime(datestr=dic['progress']['start'])
+ end = strptime(datestr=dic['progress']['last_edit_time'])
+
+ human = widgets.VBox([
+
+ ])
+ return [
+ widgets.Label(dic['exp_name']),
+ widgets.Label(dic['test_name']),
+ widgets.Label(f"""{strftime('%y-%m-%d %H:%M:%S', dateobj=start)}"""),
+ widgets.Label(f"""{strftime('%y-%m-%d %H:%M:%S', dateobj=end)}"""),
+ widgets.HTML('\n'.join([
+ f'{k}: {v}'
+ for k, v in dic['metrics'].items()
+ if isinstance(v, numbers.Number)
+ ])),
+ widgets.HBox([note_ui,
+ tag_ui, ])
+
+ ]
+
+ test_status = widgets.RadioButtons(options=['full', 'running', 'failed', 'succeed', 'finished'])
+ start_filter = widgets.DatetimePicker()
+ end_filter = widgets.DatetimePicker()
+
+ def status_filter(sender):
+ print(sender)
+ make()
+
+ test_status.observe(status_filter, names='value', type='change')
+
+ # display()
+
+ @interact
+ def make(
+ status=widgets.RadioButtons(options=['full', 'running', 'failed', 'succeed', 'finished']),
+ start=widgets.DatetimePicker(),
+ end=widgets.DatetimePicker(),
+ ):
+ if status == 'running':
+ df = self.progress()
+ elif status == 'finished':
+ df = self.progress(is_alive=False)
+ else:
+ df = self.load()
+ if status == 'succeed':
+ df = df[df['progress'].apply(lambda x: x['finished'])]
+ elif status == 'failed':
+                df = df[df['exception'].notna()]
+
+ if start:
+ df = df.pipe(
+ lambda x: x[x['progress'].apply(lambda y: strptime(datestr=y['start'])) > start]
+ )
+ if end:
+ df = df.pipe(
+ lambda x: x[x['progress'].apply(lambda y: strptime(datestr=y['end'])) < end]
+ )
+
+ if params_filter is not None:
+ df_params = df['params']
+ masks = None
+ for condition in params_filter:
+ mask = condition.mask(df_params)
+ if masks is None:
+ masks = mask
+ else:
+ masks *= mask
+ df = df[masks]
+
+ if metric_filter is not None:
+ df_params = df['metrics']
+ masks = None
+ for condition in metric_filter:
+ mask = condition.mask(df_params)
+ if masks is None:
+ masks = mask
+ else:
+ masks *= mask
+ df = df[masks]
+
+ exps = df.to_dict(orient='records')
+ # grid = widgets.GridspecLayout(len(exps) + 1, 7)
+
+ children = [
+ widgets.Label('exp_name'),
+ widgets.Label('test_name'),
+ widgets.Label('start'),
+ widgets.Label('end'),
+ widgets.Label('metrics'),
+ widgets.Label('note & tags'),
+ ]
+ # grid[0, 0] = widgets.Label('Meta')
+ # grid[0, 1] = widgets.Label('Metrics')
+ # grid[0, 2] = widgets.Label('Notes')
+ for i, exp in enumerate(exps, start=1):
+ row = make_row(exp)
+ children.extend(row)
+ # display(widgets.HBox(row))
+ # for j, item in enumerate(row):
+ # grid[i, j] = item
+
+ grid = widgets.GridBox(children=children,
+
+ layout=widgets.Layout(
+ width='100%',
+ grid_template_columns=' '.join(['auto'] * 5) + ' auto',
+ # grid_template_rows='80px auto 80px',
+ grid_gap='5px 10px')
+ )
+ display(
+ widgets.HTML("""
+
+ """),
+ grid,
+
+ )
+
+ # return display(
+ # widgets.HTML(styles['row-radio']),
+ # widgets.HTML("""
+ #
+ # """),
+ # grid, clear=True)
+
+
+class ExperimentWidget:
+ @overload
+ def __init__(self, exp_name, test_name,
+ progress: dict,
+ params: dict, metrics: dict, note: str, tags: set, exp: Experiment):
+ pass
+
+ def __init__(self, **kwargs):
+ from ipywidgets import widgets
+ self.wid = widgets
+ self.exp = kwargs.pop('exp') # type: Experiment
+ self._prop = kwargs
+
+ self._widgets = {
+ 'exp_name': widgets.HTML(self._prop['exp_name']),
+ 'test_name': widgets.HTML(self._prop['test_name']),
+ 'metrics': widgets.VBox(
+ [widgets.HTML(f'{k}: {v}') for k, v in self._prop['metrics'].items() if
+ isinstance(v, numbers.Number)]),
+ }
+
+ self._params_widgets = {}
+
+ note_ui = widgets.Textarea(self._prop['note'])
+
+ note_ui.continuous_update = False
+ note_ui.observe(self.on_note_update, names='value', type='change')
+ self._widgets['note'] = note_ui
+
+ tag_ui = widgets.TagsInput(value=list(self._prop['tags']))
+ self._widgets['tags'] = tag_ui
+ tag_ui.observe(self.on_tag_update, names='value', type='change')
+
+ def on_note_update(self, sender):
+ self.exp.dump_note(sender['new'])
+
+ def on_tag_update(self, sender):
+ self.exp.dump_tags(*sender['new'])
+
+ def set_key_params(self, keys: list):
+ self._params_widgets.clear()
+ for key in keys:
+ self._params_widgets[key] = self.wid.HTML(
+ f"""{key}: {pformat(self._prop['params'][key], width=10, indent=2, compact=True)}""")
+
+ def sep(self):
+ return self.wid.Output(layout={'border': '1px solid black'})
+
+ def id_flag(self):
+ return self.wid.VBox([
+ self._widgets['exp_name'],
+ self._widgets['test_name'],
+ ])
+
+ def key_params(self):
+ return self.wid.VBox([
+ *self._params_widgets.values()
+ ])
+
+ def editable(self):
+ return self.wid.VBox([
+ self._widgets['note'],
+ self.sep(),
+ self._widgets['tags'],
+ ])
+
+ def time(self):
+ now = datetime.now()
+ start = strptime(datestr=self._prop['progress']['start'])
+        end = strptime(datestr=self._prop['progress']['last_edit_time'])
+ return self.wid.VBox([
+ self.wid.HTML(f"""Start at: {format_timedelta(now - start)}"""),
+ self.wid.HTML(f"""End at: {format_timedelta(now - end)}"""),
+ ])
+
+ def widget_dict(self):
+ return {
+ 'id_flag': self.id_flag(),
+ 'time': self.time(),
+ 'editable': self.editable(),
+ 'params': self.key_params(),
+ }
+
+    def widget(self):
+        """Assembles the full row widget for this experiment."""
+ hbox = self.wid.HBox([
+ self.id_flag(),
+ self.time(),
+ self._widgets['metrics'],
+ self.editable(),
+ self.key_params(),
+ ])
+
+ return hbox
+
+ @classmethod
+ def from_experiment(cls, exp: Experiment):
+ tags = exp.properties.get('tags', [])
+ try:
+ tags = set(tags)
+ except:
+ tags = set()
+ return cls(
+ exp_name=exp.exp_name,
+ test_name=exp.test_name,
+ progress=exp.properties.get('progress', {}),
+ params=exp['params'],
+ metrics=exp.metric.value,
+ note=exp.properties.get('note', ''),
+ tags=tags,
+ exp=exp,
+ )
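+
+
+# Notebook usage sketch (the test-root path is illustrative):
+#   exp = Experiment.from_disk('path/to/230214.037.62t')
+#   ExperimentWidget.from_experiment(exp).widget()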
diff --git a/src/lumo/proc/config.py b/src/lumo/proc/config.py
index e3b74399..925c4a37 100644
--- a/src/lumo/proc/config.py
+++ b/src/lumo/proc/config.py
@@ -4,7 +4,6 @@
__all__ = ['debug_mode', 'glob', 'global_config_path', 'local_config_path']
import tempfile
-from typing import overload
GLOBAL_DEFAULT = {
'home': os.path.expanduser("~/.lumo/"),
@@ -14,10 +13,25 @@
def global_config_path():
+ """
+ Returns the path to the global configuration file.
+
+ Returns:
+ str: The path to the global configuration file.
+
+ Notes:
+ The relative path of global config path should never change (~/.lumorc.json)
+ """
return os.path.expanduser("~/.lumorc.json")
def local_config_path():
+ """
+ Returns the path to the local configuration file.
+
+ Returns:
+ str: The path to the local configuration file, if found. Otherwise, None.
+ """
from lumo.utils.repository import git_dir
res = git_dir()
if res:
@@ -25,7 +39,34 @@ def local_config_path():
return None
+def local_public_config_path():
+ """
+ Returns the path to the local configuration file that can be shared and public.
+
+ Returns:
+ str: The path to the local configuration file, if found. Otherwise, None.
+ """
+ from lumo.utils.repository import git_dir
+ res = git_dir()
+ if res:
+ return os.path.join(res, ".lumorc.public.json")
+ return None
+
+
def get_config(path, default):
+ """
+ Reads the configuration file at the given path or creates it if it doesn't exist.
+
+ Args:
+ path (str): The path to the configuration file.
+ default (dict): The default configuration to use if the file doesn't exist.
+
+ Returns:
+ dict: The configuration read from the file or the default configuration if the file doesn't exist.
+
+ Raises:
+ Exception: If there was an error reading the configuration file.
+ """
if path is None:
return default
@@ -44,24 +85,50 @@ def get_config(path, default):
def get_runtime_config():
- glob_cfg = get_config(global_config_path(), GLOBAL_DEFAULT)
- local_cfg = get_config(local_config_path(), {})
+ """
+ Returns the runtime configuration by merging the global and local configurations.
+
+ Returns:
+ dict: The merged runtime configuration.
+ """
+    # default (copied so the module-level defaults are not mutated)
-    cfg = GLOBAL_DEFAULT
+    cfg = dict(GLOBAL_DEFAULT)
+
+ # global config (~/.lumorc.json)
+ glob_cfg = get_config(global_config_path(), GLOBAL_DEFAULT)
cfg.update(glob_cfg)
+
+ # local private config ({repo}/.lumorc.json)
+ local_cfg = get_config(local_config_path(), {})
cfg.update(local_cfg)
+
+ # local public config ({repo}/.lumorc.public.json)
+ local_public_cfg = get_config(local_public_config_path(), {})
+ cfg.update(local_public_cfg)
return cfg
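+
+
+# Merge-order sketch: with {"home": "/data/lumo"} in ~/.lumorc.json and
+# {"exp_root": "./exps"} in {repo}/.lumorc.public.json, the merged config holds
+# both keys; later layers win on conflict (defaults < global < private < public).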
def debug_mode(base_dir=None, disable_git=True):
+ """Sets up global variables for debugging mode.
+
+ Args:
+ base_dir (str, optional): The directory to create temporary directories in. Defaults to None.
+ disable_git (bool, optional): Whether to disable git hooks. Defaults to True.
+
+ Returns:
+ None
+ """
glob['exp_root'] = tempfile.mkdtemp(dir=base_dir)
+ glob['db_root'] = tempfile.mkdtemp(dir=base_dir)
glob['progress_root'] = tempfile.mkdtemp(dir=base_dir)
+ glob['metric_root'] = tempfile.mkdtemp(dir=base_dir)
+
glob['home'] = tempfile.mkdtemp(dir=base_dir)
glob['cache_dir'] = tempfile.mkdtemp(dir=base_dir)
glob['blob_root'] = tempfile.mkdtemp(dir=base_dir)
- glob['metric_root'] = tempfile.mkdtemp(dir=base_dir)
- glob['HOOK_LOCKFILE'] = False
+ # glob['HOOK_LOCKFILE'] = False
glob['HOOK_LASTCMD_DIR'] = tempfile.mkdtemp(dir=base_dir)
- glob['HOOK_RECORDABORT'] = False
+ # glob['HOOK_RECORDABORT'] = False
glob['HOOK_TIMEMONITOR'] = False
if disable_git:
diff --git a/src/lumo/proc/dependency.py b/src/lumo/proc/dependency.py
index a67bcc6f..ae0e6328 100644
--- a/src/lumo/proc/dependency.py
+++ b/src/lumo/proc/dependency.py
@@ -1,23 +1,8 @@
import importlib
-from accelerate import __version__ as accelerate_version
-from fire import __version__ as fire_version
-from joblib import __version__ as joblib_version
-from psutil import __version__ as psutil_version
-
-from lumo import __version__ as lumo_version
-
__all__ = ['get_lock']
-class Version:
- # lumo = lumo_version
- joblib = joblib_version
- fire = fire_version
- psutil = psutil_version
- accelerate = accelerate_version
-
-
def get_lock(*others):
"""
Used to record the specific version of the run-time dependencies to ensure reproducibility.
@@ -30,8 +15,6 @@ def get_lock(*others):
"""
res = {}
- res['lumo'] = lumo_version
- res.update({k: v for k, v in Version.__dict__.items() if not k.startswith('__')})
for lib in others:
mod = importlib.import_module(lib)
diff --git a/src/lumo/proc/path.py b/src/lumo/proc/path.py
index acb45f3c..8cff01b1 100644
--- a/src/lumo/proc/path.py
+++ b/src/lumo/proc/path.py
@@ -5,6 +5,16 @@
def home():
+ """
+ Returns the home directory of the current user.
+
+ This function returns the home directory of the current user, which is
+ different from the home directory of the library. If you need to access
+ the library home directory, use `libhome()` instead.
+
+ Returns:
+ str: The home directory of the current user.
+ """
return os.path.expanduser("~")
@@ -31,7 +41,17 @@ def cache_dir():
def libhome():
- """Library home to store configs. Default is `~/.lumo`"""
+ """
+ Returns the library home directory.
+
+ This function returns the library home directory, which is used to store
+ configuration files. By default, the library home directory is `~/.lumo`.
+ If a custom home directory is set in the `home` global variable, that
+ directory will be returned instead.
+
+ Returns:
+ str: The library home directory.
+ """
LIBHOME = glob.get('home', None)
if LIBHOME:
return LIBHOME
@@ -56,7 +76,7 @@ def progressroot():
if PROGRESS_ROOT:
res = PROGRESS_ROOT
else:
- res = os.path.join(libhome(), 'progress')
+ res = os.path.join(cache_dir(), 'progress')
os.makedirs(res, exist_ok=True)
return res
@@ -91,6 +111,16 @@ def metricroot():
return res
+def dbroot():
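+    """
+    Returns the database root directory.
+
+    Defaults to `{libhome}/database`; can be overridden by the `db_root`
+    global variable.
+    """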
+ DB_ROOT = glob.get('db_root', None)
+ if DB_ROOT:
+ res = DB_ROOT
+ else:
+ res = os.path.join(libhome(), 'database')
+ os.makedirs(res, exist_ok=True)
+ return res
+
+
def local_dir():
"""
Project root, default is the parent directory of .git.
diff --git a/src/lumo/proc/pid.py b/src/lumo/proc/pid.py
index 174be98a..5427f215 100644
--- a/src/lumo/proc/pid.py
+++ b/src/lumo/proc/pid.py
@@ -1,20 +1,46 @@
-from psutil import Process
-import sys
+"""
+Helpers for collecting information about a process (by PID or the current process) and computing its hash.
+"""
+from psutil import Process, pid_exists
from joblib import hash
import os
def runtime_pid_obj(pid=None):
+ """Returns a dictionary containing information about a process identified by the given PID.
+
+ Args:
+ pid (int, optional): The PID of the process to get information about. If None, uses the PID of the current process. Defaults to None.
+
+ Returns:
+        dict or None: A dictionary with the following keys, or None if the process does not exist:
+ - pid (int): The process ID.
+ - pname (str): The name of the process.
+ - pstart (float): The process creation time, in seconds since the epoch.
+ - argv (list): A list of command-line arguments passed to the process.
+ """
if pid is None:
pid = os.getpid()
- p = Process(pid)
- obj = {
- "pid": p.pid, "pname": p.name(), 'pstart': p.create_time(), 'argv': p.cmdline()
- }
- return obj
+
+ if pid_exists(pid):
+ p = Process(pid)
+ obj = {
+ "pid": p.pid, "pname": p.name(), 'pstart': p.create_time(), 'argv': p.cmdline()
+ }
+ return obj
+
+ return None
def pid_hash(pid_obj=None):
+ """Computes the hash of a process object.
+
+ Args:
+ pid_obj (dict, optional): A dictionary containing information about a process. If None, uses the information about the current process. Defaults to None.
+
+ Returns:
+ str: The hash of the process object.
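+
+    Examples:
+        # Illustrative: the hash is deterministic for the same snapshot.
+        >>> obj = runtime_pid_obj()
+        >>> pid_hash(obj) == pid_hash(obj)
+        True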
+ """
if pid_obj is None:
pid_obj = runtime_pid_obj()
return hash(pid_obj)
diff --git a/src/lumo/sketch/filelock.py b/src/lumo/sketch/filelock.py
deleted file mode 100644
index 3524463c..00000000
--- a/src/lumo/sketch/filelock.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import os
-import warnings
-
-CAN_USE_LOCK = True
-if os.name == 'nt':
- import win32con, win32file, pywintypes
-
- LOCK_EX = win32con.LOCKFILE_EXCLUSIVE_LOCK
- LOCK_SH = 0 # The default value
- LOCK_NB = win32con.LOCKFILE_FAIL_IMMEDIATELY
- __overlapped = pywintypes.OVERLAPPED()
-
-
- def lock(file, flags):
- hfile = win32file._get_osfhandle(file.fileno())
- win32file.LockFileEx(hfile, flags, 0, 0xffff0000, __overlapped)
-
-
- def unlock(file):
- hfile = win32file._get_osfhandle(file.fileno())
- win32file.UnlockFileEx(hfile, 0, 0xffff0000, __overlapped)
-elif os.name == 'posix':
- import fcntl
-
-
- def lock(file, flags=fcntl.LOCK_EX):
- fcntl.flock(file.fileno(), flags)
-
-
- def unlock(file):
- fcntl.flock(file.fileno(), fcntl.LOCK_UN)
-else:
- CAN_USE_LOCK = False
-
- warnings.warn(f'You are in an unknown platform {os.name}, filelock may cant be used.')
-
-
- def lock(file, flat):
- raise NotImplementedError(f'an UNKNOWN platform {os.name}')
-
-
- def unlock(file):
- fcntl.flock(file.fileno(), fcntl.LOCK_UN)
diff --git a/src/lumo/decorators/map_extract.py b/src/lumo/sketch/map_extract.py
similarity index 100%
rename from src/lumo/decorators/map_extract.py
rename to src/lumo/sketch/map_extract.py
diff --git a/src/lumo/utils/timer.py b/src/lumo/sketch/timer.py
similarity index 82%
rename from src/lumo/utils/timer.py
rename to src/lumo/sketch/timer.py
index e509aa56..ffe0a1b0 100644
--- a/src/lumo/utils/timer.py
+++ b/src/lumo/sketch/timer.py
@@ -6,22 +6,7 @@
import warnings
from collections import OrderedDict
-from .fmt import strftime
-
-
-def format_second(sec: int) -> str:
- """convert seconds from int to string"""
- sec, ms = divmod(sec, 1)
- if sec > 60:
- min, sec = divmod(sec, 60)
- if min > 60:
- hour, min = divmod(min, 60)
- fmt = "{}h{}m{}s".format(hour, min, int(sec))
- else:
- fmt = "{}m{}s".format(min, int(sec))
- else:
- fmt = "{}s".format(int(sec))
- return fmt
+from lumo.utils.fmt import strftime
class Timer:
diff --git a/src/lumo/vis/__init__.py b/src/lumo/sketch/vis/__init__.py
similarity index 100%
rename from src/lumo/vis/__init__.py
rename to src/lumo/sketch/vis/__init__.py
diff --git a/src/lumo/vis/__main__.py b/src/lumo/sketch/vis/__main__.py
similarity index 92%
rename from src/lumo/vis/__main__.py
rename to src/lumo/sketch/vis/__main__.py
index d3f60694..526c02ba 100644
--- a/src/lumo/vis/__main__.py
+++ b/src/lumo/sketch/vis/__main__.py
@@ -34,10 +34,10 @@ def make_test(self, test_root: str):
st.write(finder.format_experiment(exp))
# with st.expander("Visualize Metrics"):
if exp.has_prop('tensorboard_args'):
- tb = exp.get_prop('tensorboard_args')
+ tb = exp.properties.get('tensorboard_args')
metrics = parser.parse_fron_tensorboard(tb['log_dir'])
elif exp.has_prop('logger_args'):
- tb = exp.get_prop('logger_args')
+ tb = exp.properties.get('logger_args')
metrics = parser.parse_from_log(tb['log_dir'])
else:
metrics = {}
@@ -52,10 +52,10 @@ def make_test(self, test_root: str):
k, v = metrics[i + 1]
m.write(k)
m.line_chart(np.array([vv.value for vv in v]))
- # if i + 2 >= len(metrics):
- # break
- # k, v = metrics[i + 2]
- # r.line_chart({'k': np.array([vv.value for vv in v])})
+ # if i + 2 >= len(metrics):
+ # break
+ # k, v = metrics[i + 2]
+ # r.line_chart({'k': np.array([vv.value for vv in v])})
def select_head(self):
left, right = st.columns([1, 3])
diff --git a/src/lumo/vis/parser.py b/src/lumo/sketch/vis/parser.py
similarity index 94%
rename from src/lumo/vis/parser.py
rename to src/lumo/sketch/vis/parser.py
index c549d9cb..b1f94068 100644
--- a/src/lumo/vis/parser.py
+++ b/src/lumo/sketch/vis/parser.py
@@ -30,10 +30,10 @@ def find_metric_fron_test_root(test_root):
exp = Experiment.from_disk(test_root)
if exp.has_prop('tensorboard_args'):
- tb = exp.get_prop('tensorboard_args')
+ tb = exp.properties.get('tensorboard_args')
metrics = parse_fron_tensorboard(tb['log_dir'])
elif exp.has_prop('logger_args'):
- tb = exp.get_prop('logger_args')
+ tb = exp.properties.get('logger_args')
metrics = parse_from_log(tb['log_dir'])
else:
fs = [i for i in os.listdir(exp.test_root)]
diff --git a/src/lumo/vis/parser_tb.py b/src/lumo/sketch/vis/parser_tb.py
similarity index 100%
rename from src/lumo/vis/parser_tb.py
rename to src/lumo/sketch/vis/parser_tb.py
diff --git a/src/lumo/exp/agent.py b/src/lumo/sketch/wait_pid_stop.py
similarity index 100%
rename from src/lumo/exp/agent.py
rename to src/lumo/sketch/wait_pid_stop.py
diff --git a/src/lumo/trainer/base.py b/src/lumo/trainer/base.py
index cac76624..bab315fe 100644
--- a/src/lumo/trainer/base.py
+++ b/src/lumo/trainer/base.py
@@ -1,15 +1,4 @@
"""
- - 提供主要的训练流
- - 加载数据集 DataLoader
- - train、test、eval
- - 提供对训练流的控制、回调
- - callbacks
- - 提供对训练中间的状态保存
- - metric: Meter
- - checkpoints: Saver
-
-
-trainer = Trainer()
"""
import inspect
import os
@@ -26,6 +15,25 @@
def _exit_hook_v0(exc_type, exc, tb, *args):
+ """Prints the traceback information when an unhandled exception occurs.
+
+ Args:
+ exc_type (type): The type of the exception.
+ exc (Exception): The instance of the exception.
+ tb (traceback): The traceback object containing the call stack.
+ *args: Additional arguments to be passed to the function.
+
+ Returns:
+ None
+
+ Raises:
+ None
+
+ This function is designed to be used as an exit hook with the `sys.excepthook` function.
+ It formats the traceback information and removes any lines related to the `_newfunc` function.
+ The resulting traceback is printed to the `sys.stderr` stream.
+
+ """
import traceback
res = traceback.format_exception(exc_type, exc, tb)
res = [i for i in res if 'in _newfunc' not in i]
@@ -33,6 +41,25 @@ def _exit_hook_v0(exc_type, exc, tb, *args):
def _exit_hook(exc_type, exc, tb, *args):
+ """Prints an error traceback and displays it using the rich library.
+
+ Args:
+ exc_type: Type of the exception that was raised.
+ exc: The exception instance that was raised.
+ tb: Traceback object that contains information about the exception.
+ *args: Optional additional arguments to be passed to _exit_hook_v0.
+
+ Returns:
+ None.
+
+ Raises:
+ Any exceptions that were not caught by _exit_hook_v0.
+
+ Examples:
+        # Install as the global exception hook:
+        >>> import sys
+        >>> sys.excepthook = _exit_hook
+
+ """
from rich.console import Console
console = Console()
_exit_hook_v0(exc_type, exc, tb, *args)
@@ -54,40 +81,6 @@ def _exit_hook(exc_type, exc, tb, *args):
console.print(traceback)
-def wrapper(self, func, _call_set: list):
- """
- 对每个 Trainer 的 _call_backs 类变量中定义的函数尝试绑定回调
- Args:
- func:
- _call_set:
-
- Returns:
-
- """
-
- @wraps(func)
- def _newfunc(*aargs, **kkwargs):
- """执行前回调 on_begin() 、执行后回调 on_end()、执行异常则回调 on_exception() """
- for callback in _call_set:
- callback.on_begin(self, func, self.params, *aargs, **kkwargs)
- try:
- _meter = func(*aargs, **kkwargs)
- except BaseException as e:
- _handles = [callback.on_exception(self, func, self.params, e, *aargs, **kkwargs)
- for callback in _call_set]
-
- if any(_handles):
- return None
- else:
- raise e
-
- for callback in _call_set:
- callback.on_end(self, func, self.params, _meter, *aargs, **kkwargs)
- return _meter
-
- return _newfunc
-
-
init_function = ['icallbacks', 'imodels']
call_dependency = {
'train': init_function,
@@ -95,6 +88,8 @@ def _newfunc(*aargs, **kkwargs):
class _BaseTrainer:
+ """Base class for training neural network models.
+ """
__exp_name__ = None
callback_function = {}
@@ -122,8 +117,17 @@ def __new__(cls, *args, **kwargs):
replace(_exit_hook)
def init_wrapper(func):
+ """
+        Wraps the train/test/eval functions so that initialization happens implicitly before they run.
+
+ Notes:
+ Before calling the train/test/eval functions, the `trainer.initialize` method is called,
+ and then the corresponding DataLoader for the stage is initialized through the `process_loader` method.
+ """
+
@wraps(func)
def inner(dm=None, params=None, *args, **kwargs):
+ """The inner function that wraps the train/test/eval function."""
init_fn = getattr(self, 'initialize', None)
if init_fn is not None:
init_fn()
@@ -136,18 +140,21 @@ def inner(dm=None, params=None, *args, **kwargs):
def cb_wrapper(func, call_set: list):
"""
- 对每个 Trainer 的 _call_backs 类变量中定义的函数尝试绑定回调
+ Wraps the given function with callback functions.
+
Args:
- func:
- call_set:
+ func (function): The function to wrap.
+ call_set (list): A list of callback functions.
Returns:
-
+ A wrapped function.
"""
@wraps(func)
def _newfunc(*aargs, **kkwargs):
- """执行前回调 on_begin() 、执行后回调 on_end()、执行异常则回调 on_exception() """
+ """
+ Executes the callback functions before and after the given function and on exception.
+ """
# on_begin
self._contexts.append(func.__name__)
for callback in call_set:
@@ -214,7 +221,7 @@ def __setattr__(self, name, value):
elif callable(getattr(value, "state_dict", None)) and callable(getattr(value, "load_state_dict", None)):
type_name = 'others'
else:
- super().__setattr__(name, value)
+ # super().__setattr__(name, value)
return
# if name in self.__dict__: TODO workaround multi-gpu error: Expected to mark a variable ready only once
@@ -225,10 +232,14 @@ def __setattr__(self, name, value):
@property
def contexts(self):
+ """Get the name stack of function call contexts.
+        The first element is the name of the Trainer class.
+ """
return self._contexts
@property
def context(self):
+ """Get the name of the most recent function call context."""
return self._contexts[-1]
# def __getattr__(self, name):
@@ -245,11 +256,22 @@ def context(self):
@classmethod
def dirname(cls):
+ """Get the directory name of the file where the class is defined.
+
+ Returns:
+ A string representing the directory name of the file where the class is defined.
+ """
file = inspect.getfile(cls)
return os.path.basename(os.path.dirname(file))
@classmethod
def filebasename(cls):
+ """Get the basename of the file where the class is defined.
+
+ Returns:
+ A string representing the basename of the file where the class is defined.
+ If an exception occurs, returns 'builtin'.
+ """
try:
file = inspect.getfile(cls)
pre = os.path.splitext(os.path.basename(file))[0]
@@ -259,6 +281,12 @@ def filebasename(cls):
@classmethod
def generate_exp_name(cls) -> str:
+ """Generate an experiment name based on the file basename and the class name.
+
+ Returns:
+            A string representing the experiment name, formatted as
+            '<file_basename>.<exp_name>' in lowercase. If `__exp_name__` is not
+            defined, the class name with 'trainer' replaced by 'exp' is used.
+ """
pre = cls.filebasename()
exp_name = cls.__exp_name__
@@ -268,4 +296,5 @@ def generate_exp_name(cls) -> str:
return "{}.{}".format(pre.lower(), exp_name.lower())
def on_trainer_exception(self, func: Callable, exception: BaseException):
+ """Called when an exception occurs during training."""
pass
diff --git a/src/lumo/trainer/callbacks.py b/src/lumo/trainer/callbacks.py
index e6c7a0b8..2f71b228 100644
--- a/src/lumo/trainer/callbacks.py
+++ b/src/lumo/trainer/callbacks.py
@@ -1,7 +1,6 @@
"""
"""
import inspect
-import json
import os
import tempfile
import time
@@ -10,13 +9,15 @@
from datetime import datetime
from functools import wraps
from typing import NewType, Any, Optional, Dict, Union
-from lumo.utils.memory_grab import DeviceMem
+
+import psutil
from torch.utils.data import DataLoader
from lumo.core import Meter, MetricType, Record, TrainStage, wrap_result, ParamsType
from lumo.data import DataModule
from lumo.data.loader import summarize_loader, DataLoaderType
from lumo.utils import fmt
+from lumo.utils.memory_grab import DeviceMem
from lumo.utils.screen import inlinetqdm
from .trainer import Trainer
from ..proc.dist import world_size
@@ -342,6 +343,7 @@ def on_end(self, source: Trainer, func, params: ParamsType, result, *args, **kwa
class LoggerCallback(TrainCallback, InitialCallback):
+ """A callback for logging the training process."""
priority = 99999
def __init__(self, step_frequence=3, break_in=1000):
@@ -385,13 +387,14 @@ def on_train_begin(self, trainer: Trainer, func, params: ParamsType, *args, **kw
file=self.temp)
def renew(self, stage):
- """创建一个新的"""
+ """Renew when change stage(train/eval/test)"""
self.cur_tqdm = inlinetqdm(total=self.stage[stage], position=0, leave=True,
bar_format='{desc}{elapsed}<{remaining} ({percentage:3.0f}%){postfix}',
file=self.temp)
self.record = Record()
def update(self, trainer: Trainer):
+ """Update"""
self.c += 1
self.cur_tqdm.update()
if self.c % self.step == 0:
@@ -399,10 +402,11 @@ def update(self, trainer: Trainer):
if self.c % self.breakin == 0 or (
TrainStage.train in self.stage and ((trainer.idx + 1) == self.stage[TrainStage.train])):
- trainer.logger.info(self.cur_tqdm.full_str())
- # trainer.logger.newline()
+ trainer.logger.inline(self.cur_tqdm.full_str())
+ trainer.logger.newline()
def flush(self, trainer: Trainer):
+ """Flush"""
self.c = 0
trainer.logger.inline(self.cur_tqdm)
trainer.logger.newline()
@@ -447,6 +451,7 @@ def format_interval(t: float):
def format_train_epoch_time(self, n, total, elapsed, ncols=None, prefix='', ascii=False, unit='it',
unit_scale=False, rate=None, bar_format=None, postfix=None,
unit_divisor=1000, initial=0, colour=None, **extra_kwargs):
+ """Format"""
elapsed_str = self.format_interval(elapsed)
remaining = (total - n) / rate if rate and total else 0
remaining_str = self.format_interval(remaining) if rate else '?'
@@ -508,61 +513,6 @@ def on_test_step_end(self, trainer: Trainer, func, params: ParamsType,
self.update(trainer)
-class EpochCheckpoint(TrainCallback):
- """
- 在 Trainer 训练过程中定时保存模型
- """
- only_main_process = True
-
- def __init__(self, per_epoch=50):
- self.per_epoch = per_epoch
-
- def on_train_epoch_end(self, trainer: Trainer, func, params: ParamsType, record: Optional[Record], *args,
- **kwargs):
- meter = record.avg()
- if trainer.eidx % self.per_epoch == 0 and trainer.eidx > 0:
- trainer.save_checkpoint(meta_info=Meter.wrap_result(meter))
-
- def __repr__(self) -> str:
- return self._repr_by_val("per_epoch")
-
-
-class GlobalStepCheckpoint(TrainCallback):
- only_main_process = True
-
- def __init__(self, per_step=2500):
- self.per = per_step
-
- def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: Meter, *args, **kwargs):
- super().on_train_step_end(trainer, func, params, metric, *args, **kwargs)
- if trainer.global_steps % self.per == 0 and trainer.global_steps > 0:
- trainer.save_checkpoint(meta_info=Meter.wrap_result(metric))
-
-
-class KeyErrorSave(TrainCallback):
- """
- Callback to save checkpoints when you interrupt the program.
- """
- only_main_process = True
- only_single_gpu = True
- priority = -1
-
- def __init__(self, wait_input=False):
- self.wait_input = wait_input
-
- def on_first_exception(self, source: Trainer, func, params: ParamsType, e: BaseException, *args, **kwargs):
- if isinstance(e, KeyboardInterrupt):
- source.logger.info("KeyErrorSave trigged, save checkpoint")
- source.save_checkpoint({"mode": "KeyboardInterrupt"})
-
- tp = "n"
- if self.wait_input:
- tp = input("continue train step? (y/other)")
-
- if tp.lower() == "y":
- return True
-
-
class EMAUpdate(TrainCallback):
"""
Callback to update EMA model every train step.
@@ -572,7 +522,7 @@ class EMAUpdate(TrainCallback):
- name is started with 'ema'
"""
- def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: MetricType, *args, **kwargs):
+ def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: MetricType = None, *args, **kwargs):
super().on_train_step_end(trainer, func, params, metric, *args, **kwargs)
for k, v in trainer.model_dict.items():
if k.lower().startswith('ema'):
@@ -656,58 +606,23 @@ def log_matrix(self, metrics: Dict, step: int, namespace: str):
def on_hooked(self, source: Trainer, params: ParamsType):
super().on_hooked(source, params)
- source.exp.set_prop('AutoRecord', self.__class__.__name__)
+ source.exp.dump_string('AutoRecord', self.__class__.__name__)
- def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: MetricType, *args, **kwargs):
+ def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: MetricType = None, *args, **kwargs):
super().on_train_step_end(trainer, func, params, metric, *args, **kwargs)
self.log(metric, step=trainer.global_steps, namespace='train.step')
def on_train_epoch_end(self, trainer: Trainer, func, params: ParamsType, record: Record, *args, **kwargs):
super().on_train_epoch_end(trainer, func, params, record, *args, **kwargs)
- self.log(record.avg(), step=trainer.global_steps, namespace='train.epoch')
+ self.log(record.agg(), step=trainer.global_steps, namespace='train.epoch')
- def on_test_end(self, trainer: Trainer, func, params: ParamsType, record: Record, *args, **kwargs):
+ def on_test_end(self, trainer: Trainer, func, params: ParamsType, record: Record = None, *args, **kwargs):
super().on_test_end(trainer, func, params, record, *args, **kwargs)
- self.log(record.avg(), step=trainer.global_steps, namespace='test')
+ self.log(record.agg(), step=trainer.global_steps, namespace='test')
- def on_eval_end(self, trainer: Trainer, func, params: ParamsType, record: Record, *args, **kwargs):
+ def on_eval_end(self, trainer: Trainer, func, params: ParamsType, record: Record = None, *args, **kwargs):
super().on_eval_end(trainer, func, params, record, *args, **kwargs)
- self.log(record.avg(), step=trainer.global_steps, namespace='evaluate')
-
-
-class WandbCallback(RecordCallback):
- only_main_process = True
-
- def __init__(self, metric_step=500) -> None:
- super().__init__()
- self.metric_step = metric_step
- self.c = 0
-
- def log(self, metrics: MetricType, step, namespace):
- self.c += 1
- if self.c % self.metric_step == 0:
- metrics = {
- f"{namespace}.{k}": v
- for k, v in wrap_result(metrics).items()}
- self._hooked.wandb.log(metrics, step=step)
-
- def log_text(self, metrics: Dict, step: int, namespace: str):
- wandb = self._hooked.wandb
- metrics = {k: v for k, v in metrics.items()}
- wandb.log(metrics, step=step)
-
- def log_scalars(self, metrics: Dict, step: int, namespace: str):
- wandb = self._hooked.wandb
- metrics = {k: wandb.Html(v) for k, v in metrics.items()}
- wandb.log(metrics, step=step)
-
- def log_matrix(self, metrics: Dict, step: int, namespace: str):
- wandb = self._hooked.wandb
- metrics = {k: wandb.Image(v) for k, v in metrics.items()}
- wandb.log(metrics, step=step)
-
- def on_first_exception(self, source: Trainer, func, params: ParamsType, e: BaseException, *args, **kwargs):
- super().on_first_exception(source, func, params, e, *args, **kwargs)
+ self.log(record.agg(), step=trainer.global_steps, namespace='evaluate')
class TensorBoardCallback(RecordCallback):
@@ -740,7 +655,8 @@ class StopByCode(TrainCallback):
def __init__(self, step=100):
self.step = step
- def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: Meter, *args, **kwargs):
+ def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: MetricType = None, *args, **kwargs):
+ super().on_train_step_end(trainer, func, params, metric, *args, **kwargs)
if trainer.global_steps % self.step == 0:
if os.path.exists(trainer.exp.test_file('.stop')):
trainer.exp.add_tag('lumo.early_stop')
@@ -857,7 +773,7 @@ def on_first_exception(self, source: Trainer, func, params: ParamsType, e: BaseE
}))
-class CUDAMemoryRecord(TrainCallback):
+class ResourceRecord(TrainCallback):
"""
-    Record CUDA GPU maximum memory used during training.
+    Record CUDA GPU maximum memory and CPU memory (RSS/VMS) used during training.
@@ -868,16 +784,20 @@ class CUDAMemoryRecord(TrainCallback):
def on_hooked(self, source: Trainer, params: ParamsType):
super().on_hooked(source, params)
- self.max_memory = 0
self.device = source.device
self.pid = os.getpid()
self.mem = DeviceMem()
def on_train_epoch_end(self, trainer: Trainer, func, params: ParamsType, record: Record, *args, **kwargs):
super().on_train_epoch_end(trainer, func, params, record, *args, **kwargs)
- self.max_memory = max(self.max_memory, self.mem.get_pid_device_mem(self.pid, self.device))
- trainer.database.update('memory', self.max_memory)
- trainer.exp.dump_info('max_memory', self.max_memory)
+ trainer.exp.dump_metric('CUDA_memory', self.mem.get_pid_device_mem(self.pid, self.device), cmp='max')
+
+        # Collect memory info of the current process
+        memory_info = psutil.Process(self.pid).memory_info()
+
+        # Record CPU memory usage (in MB)
+ trainer.exp.dump_metric('CPU_memory_rss', memory_info.rss / 1024 / 1024)
+ trainer.exp.dump_metric('CPU_memory_vms', memory_info.vms / 1024 / 1024)
class SkipWhenParamsEq(TrainCallback, InitialCallback):
@@ -886,7 +806,7 @@ def on_hooked(self, source: Trainer, params: ParamsType):
super().on_hooked(source, params)
from dbrecord import PDict
from lumo.exp.finder import is_test_root
- self.fn = source.exp.exp_file('params_key.sqlite')
+ self.fn = source.exp.mk_rpath('contrib', 'params_key.sqlite')
olds = PDict(self.fn)
current = source.params.hash()
@@ -896,7 +816,7 @@ def on_hooked(self, source: Trainer, params: ParamsType):
if isinstance(old, str) and is_test_root(old) and os.path.exists(old):
source.stop_train()
source.stop_train_epoch()
- source.database.update('skiped', True, flush=True)
+ source.exp.dump_info('early_stop', True)
source.logger.info(
-                f'Find finished test with equal params ({current}) from {olds[current]}. '
-                f'To runStored in:\n {self.fn}')
+                f'Found a finished test with equal params ({current}) at {olds[current]}. '
+                f'Record stored in:\n {self.fn}')
@@ -905,6 +825,6 @@ def on_train_end(self, trainer: Trainer, func, params: ParamsType, record: Recor
super().on_train_end(trainer, func, params, record, *args, **kwargs)
from dbrecord import PDict
olds = PDict(self.fn)
- olds[params.hash()] = trainer.exp.test_root
+ olds[params.hash()] = trainer.exp.info_dir
olds.flush()
trainer.logger.info(f'Save current params ({params.hash()}) to {self.fn}')
diff --git a/src/lumo/trainer/components.py b/src/lumo/trainer/components.py
index 17ff9e98..37d5bf7e 100644
--- a/src/lumo/trainer/components.py
+++ b/src/lumo/trainer/components.py
@@ -1,5 +1,3 @@
-from typing import NewType
-
import torch
from lumo.core import Params
@@ -8,14 +6,15 @@
class TrainerExperiment(SimpleExperiment):
+ """A class for helping manage an experiment by Trainer."""
@property
def log_dir(self):
- return self.test_root
+ return self.info_dir
@property
def params_fn(self):
- res = self.test_file('params.yaml')
+ res = self.mk_ipath('params.yaml')
self.dump_string('params.yaml', res)
return res
@@ -25,7 +24,7 @@ def board_args(self):
if self.has_prop(key):
return self.get_prop(key)
else:
- log_dir = self.test_dir('board')
+ log_dir = self.mk_ipath('board', is_dir=True)
res = {
'filename_suffix': '.bd',
'log_dir': log_dir,
@@ -35,23 +34,24 @@ def board_args(self):
@property
def state_dict_dir(self):
- res = self.blob_dir('state_dict')
+ res = self.mk_bpath('state_dict', is_dir=True)
return res
def dump_train_eidx(self, eidx, epoch: int):
"""
- Args:
- eidx: start from 0, end at `epoch-1`
- epoch:
+ Dumps the progress of the trainer.
+
+        Args:
+            eidx (int): The index of the current epoch (starting from 0).
+            epoch (int): The total number of epochs to train for.
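+
+        Examples:
+            # Illustrative: report epoch 5 of 10 as 50% progress
+            # (`exp` is assumed to be a TrainerExperiment instance).
+            >>> exp.dump_train_eidx(4, 10)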
"""
self.dump_progress((eidx + 1) / epoch, update_from='trainer')
-class ReimplementExperiment(TrainerExperiment):
- pass
-
-
class TrainerParams(Params):
+ """
+    A class holding parameters for the trainer.
+ """
OPTIM = OptimFactory
SCHE = INTERP = InterpFactory
diff --git a/src/lumo/trainer/factory.py b/src/lumo/trainer/factory.py
index e32cc6dd..939f057f 100644
--- a/src/lumo/trainer/factory.py
+++ b/src/lumo/trainer/factory.py
@@ -8,6 +8,24 @@
class InterpFactory:
+ """A factory class for creating instances of various interpolation classes.
+
+ This class provides convenient access to various interpolation classes that are defined in the `interp` module.
+
+ Attributes:
+ Cos (class): An interpolation class for cosine interpolation.
+ Linear (class): An interpolation class for linear interpolation.
+ Exp (class): An interpolation class for exponential interpolation.
+ Log (class): An interpolation class for logarithmic interpolation.
+ Constant (class): An interpolation class for constant interpolation.
+ PeriodCos (class): An interpolation class for periodic cosine interpolation.
+ PeriodHalfCos (class): An interpolation class for periodic half-cosine interpolation.
+ PeriodTriangle (class): An interpolation class for periodic triangle interpolation.
+ PeriodLinear (class): An interpolation class for periodic linear interpolation.
+ PowerDecay (class): An interpolation class for power-decay interpolation.
+ List (class): An interpolation class for list interpolation.
+
+ """
Cos = interp.Cos
Linear = interp.Linear
Exp = interp.Exp
@@ -22,14 +40,44 @@ class InterpFactory:
class OptimBuilder(BaseParams):
+ """A class for building an optimizer with specified parameters.
+
+ Methods:
+ from_kwargs(cls, **kwargs): Creates a new instance of OptimBuilder class and updates its attributes with the given keyword arguments.
+ build(self, parameters, optim_cls=None) -> Optimizer: Builds and returns an optimizer with the specified parameters.
+
+ """
@classmethod
def from_kwargs(cls, **kwargs):
+ """Creates a new instance of OptimBuilder class and updates its attributes with the given keyword arguments.
+
+ Args:
+ **kwargs: A dictionary containing the optimizer parameters.
+
+ Returns:
+            OptimBuilder: A new instance with the given keyword arguments applied.
+ """
self = cls()
self.update(kwargs)
return self
def build(self, parameters, optim_cls=None) -> Optimizer:
+ """Builds and returns an optimizer with the specified parameters.
+
+ Args:
+ parameters: The parameters for the optimizer.
+ optim_cls: The class of the optimizer to be built.
+
+ Returns:
+            Optimizer: The constructed optimizer instance.
+
+ Raises:
+ ModuleNotFoundError: If the specified optimizer class cannot be found in the corresponding module.
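+
+        Examples:
+            # Illustrative sketch; `model` is assumed to be an nn.Module.
+            >>> builder = OptimBuilder.from_kwargs(name='SGD', lr=0.1, momentum=0.9)
+            >>> optimizer = builder.build(model.parameters())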
+ """
res = self.copy()
name = res['name']
lname = name.lower()
@@ -42,64 +90,83 @@ def build(self, parameters, optim_cls=None) -> Optimizer:
else:
optim_lib = importlib.import_module("torch.optim.{}".format(lname))
optim_cls = getattr(optim_lib, name, None)
+ if optim_cls is None:
+ raise ModuleNotFoundError("Cannot find {} in {}".format(name, optim_lib))
return optim_cls(parameters, **args)
class _OptimFactory:
+ """
+ A factory class that provides different optimization algorithms to be used during training.
+
+ Methods:
+ create_optim(name=None, **kwargs) -> OptimBuilder:
+ Creates an instance of OptimBuilder for a specified optimization algorithm.
+
+ Examples:
+ To create an instance of OptimBuilder for Adam optimizer with default values:
+ >>> optim_builder = OptimFactory.create_optim(name='Adam')
+
+ To create an instance of OptimBuilder for SGD optimizer with specific values:
+ >>> optim_builder = OptimFactory.create_optim(name='SGD', lr=0.01, momentum=0.9)
+
+ """
+
@overload
def create_optim(self, name='SGD', lr=None, momentum=0, dampening=0, weight_decay=0,
nesterov=False) -> OptimBuilder:
- pass
+ """Creates an instance of OptimBuilder for the SGD optimizer."""
@overload
def create_optim(self, name='Adam', lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
amsgrad=False) -> OptimBuilder:
- pass
+ """Creates an instance of OptimBuilder for the Adam optimizer."""
@overload
def create_optim(self, name='Adadelta', lr=1.0, rho=0.9, eps=1e-6, weight_decay=0) -> OptimBuilder:
- pass
+ """Creates an instance of OptimBuilder for the Adadelta optimizer."""
@overload
def create_optim(self, name='Adagrad', lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0,
eps=1e-10) -> OptimBuilder:
- pass
+ """Creates an instance of OptimBuilder for the Adagrad optimizer."""
@overload
def create_optim(self, name='AdamW', lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0) -> OptimBuilder:
- pass
+ """Creates an instance of OptimBuilder for the AdamW optimizer."""
@overload
def create_optim(self, name='AdamW',
lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False) -> OptimBuilder:
- pass
+ """Creates an instance of OptimBuilder for the AdamW optimizer."""
@overload
def create_optim(self, name='ASGD',
lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0) -> OptimBuilder:
- pass
+ """Creates an instance of OptimBuilder for the ASGD optimizer."""
@overload
def create_optim(self, name='LBFGS',
lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-7, tolerance_change=1e-9, history_size=100,
line_search_fn=None) -> OptimBuilder:
- pass
+ """Creates an instance of OptimBuilder for the LBFGS optimizer."""
@overload
def create_optim(self, name='RMSprop', lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0,
centered=False) -> OptimBuilder:
- pass
+ """Creates an instance of OptimBuilder for the RMSprop optimizer."""
@overload
def create_optim(self, name='Rprop', lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)) -> OptimBuilder:
- pass
+ """Creates an instance of OptimBuilder for the Rprop optimizer."""
@overload
def create_optim(self, name='SparseAdam', lr=1e-3, betas=(0.9, 0.999), eps=1e-8) -> OptimBuilder:
- pass
+ """Creates an instance of OptimBuilder for the SparseAdam optimizer."""
def create_optim(self, name=None, **kwargs) -> OptimBuilder:
+ """Create."""
return OptimBuilder.from_kwargs(name=name, **kwargs)
diff --git a/src/lumo/trainer/rnd.py b/src/lumo/trainer/rnd.py
index b7509c28..6c685aca 100644
--- a/src/lumo/trainer/rnd.py
+++ b/src/lumo/trainer/rnd.py
@@ -1,41 +1,29 @@
-import os
-import time
from typing import Union
-from joblib import hash
-from lumo.proc.path import cache_dir
from lumo.utils import random
class RndManager:
"""
- A seed manager for trainer. Provide interface for `~lumo.utils.random`
+ A seed manager for the trainer. Provides an interface for `~lumo.utils.random`.
"""
- def __init__(self):
- self.save_dir = os.path.join(cache_dir(), 'rnd')
-
def mark(self, seed: Union[int, str]):
"""
- 用于数据集读取一类的,需要特定步骤每一次试验完全相同
- Args:
- seed: 该次标记固定种子的名字,第一次调用该方法会在特定目录存放当前状态,
- 第二次调用会在该位置读取当前随机种子状态
-
- Returns:
+ Fixes the random seed to a specific state for reproducibility.
+ Args:
+            seed (Union[int, str]): A name or value hashed into a deterministic seed; the same value always restores the same random state.
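+
+        Examples:
+            # Illustrative: the same mark always restores the same random state.
+            >>> rnd = RndManager()
+            >>> rnd.mark('dataset-split')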
"""
random.fix_seed(random.hashseed(seed))
def shuffle(self, seed=None):
"""
- 打乱,一般用于复现试验的时候随机一个种子
- Args:
- name:
- seed:
-
- Returns:
+        Re-randomizes the seed, typically to draw a fresh random state for a new run.
+        Args:
+            seed (int, optional): The random seed to use. If None, a seed based
+                on the current time will be used.
"""
if seed is None:
random.fix_seed(random.int_time())
diff --git a/src/lumo/trainer/saver.py b/src/lumo/trainer/saver.py
index e1215e93..1088e03c 100644
--- a/src/lumo/trainer/saver.py
+++ b/src/lumo/trainer/saver.py
@@ -9,17 +9,68 @@
# state_dict_tuple = namedtuple('state_dict_tuplt', ['state_dict', 'meta_info'], defaults=[None])
class state_dict_tuple:
+ """
+ A class that stores a state dictionary and corresponding meta information.
+
+ Args:
+ state_dict (dict, optional): The state dictionary to be stored. Defaults to None.
+ meta_info (Any, optional): Any meta information to be stored. Defaults to None.
+
+ Attributes:
+ state_dict (dict): The stored state dictionary.
+ meta_info (Any): The stored meta information.
+
+ Returns:
+ A state_dict_tuple instance.
+
+ Raises:
+ IndexError: If the index is not 0 or 1.
+
+ Examples:
+ # Create an instance of state_dict_tuple with state_dict and meta_info
+ >>> sd = state_dict_tuple({'a': 1, 'b': 2}, 'meta')
+
+ # Access the state_dict and meta_info using the [] operator
+ >>> sd[0]
+ {'a': 1, 'b': 2}
+ >>> sd[1]
+ 'meta'
+ """
+
def __init__(self, state_dict=None, meta_info=None):
self.state_dict = state_dict
self.meta_info = meta_info
def __getitem__(self, item):
+ """
+ Get the stored state dictionary or meta information.
+
+ Args:
+ item (int): Index of the desired item.
+
+ Returns:
+ The state dictionary (if item is 0) or the meta information (if item is 1).
+
+ Raises:
+ IndexError: If the index is not 0 or 1.
+
+ Examples:
+ # Access the state_dict and meta_info using the [] operator
+ >>> sd = state_dict_tuple({'a': 1, 'b': 2}, 'meta')
+ >>> sd[0]
+ {'a': 1, 'b': 2}
+ >>> sd[1]
+ 'meta'
+            >>> sd[2]
+            Traceback (most recent call last):
+                ...
+            IndexError: 2
+ """
if item == 0:
return self.state_dict
elif item == 1:
return self.meta_info
raise IndexError(item)
+
class Saver:
"""
Write state_dict into test dirs, record save log into /.lumo/save..log
@@ -86,13 +137,13 @@ def dump_state_dict(self, obj, fn, meta_info: Union[str, dict] = None):
Returns:
-            saved filepath, None if something went wrong.
+            The saved filepath.
"""
- res = io.dump_state_dict(obj, fn)
- if res and meta_info is not None:
+ io.dump_state_dict(obj, fn)
+ if meta_info is not None:
if isinstance(meta_info, str):
meta_info = {'msg': meta_info}
json_fn = f"{fn}.json"
io.dump_json(meta_info, json_fn)
- return res
+ return fn
def load_state_dict(self, fn: str, with_meta=False, map_location='cpu') -> state_dict_tuple:
"""
@@ -225,6 +276,20 @@ def save_model(self, step: int, state_dict, meta_info: Union[str, dict] = None,
return None
def load_checkpoint(self, index=-1, best_if_exist=False, fn=None, with_meta=False, map_location='cpu'):
+ """
+ Loads a checkpoint file.
+
+ Args:
+ index (int, optional): Index of the checkpoint file in the list of checkpoints. Defaults to -1.
+ best_if_exist (bool, optional): If True, load the best checkpoint file. Defaults to False.
+ fn (str, optional): The filename of the checkpoint file. Defaults to None.
+            with_meta (bool, optional): If True, also returns the saved meta information. Defaults to False.
+            map_location (str, optional): Device onto which the loaded tensors are mapped. Defaults to 'cpu'.
+
+ Returns:
+ Union[None, Any]: None if the checkpoint file could not be loaded, otherwise the loaded checkpoint.
+
+ """
if fn is None and best_if_exist:
fn = self.best_checkpoint()
if fn is None:
@@ -237,6 +302,19 @@ def load_checkpoint(self, index=-1, best_if_exist=False, fn=None, with_meta=Fals
return None
def load_keypoint(self, index=-1, fn=None, with_meta=False, map_location='cpu'):
+ """
+        Loads a keypoint file (a checkpoint saved as a key milestone).
+
+ Args:
+ index (int, optional): Index of the keypoint file in the list of keypoints. Defaults to -1.
+ fn (str, optional): The filename of the keypoint file. Defaults to None.
+            with_meta (bool, optional): If True, also returns the saved meta information. Defaults to False.
+            map_location (str, optional): Device onto which the loaded tensors are mapped. Defaults to 'cpu'.
+
+ Returns:
+ Union[None, Any]: None if the keypoint file could not be loaded, otherwise the loaded keypoint.
+
+ """
if fn is None:
try:
fn = self.list_keypoints()[index]
@@ -247,6 +325,20 @@ def load_keypoint(self, index=-1, fn=None, with_meta=False, map_location='cpu'):
return None
def load_model(self, index=-1, best_if_exist=False, fn=None, with_meta=False, map_location='cpu'):
+ """
+ Loads a model file.
+
+ Args:
+ index (int, optional): Index of the model file in the list of models. Defaults to -1.
+ best_if_exist (bool, optional): If True, load the best model file. Defaults to False.
+ fn (str, optional): The filename of the model file. Defaults to None.
+            with_meta (bool, optional): If True, also returns the saved meta information. Defaults to False.
+            map_location (str, optional): Device onto which the loaded tensors are mapped. Defaults to 'cpu'.
+
+ Returns:
+ Union[None, Any]: None if the model file could not be loaded, otherwise the loaded model.
+
+ """
if fn is None and best_if_exist:
fn = self.best_model()
if fn is None:
@@ -259,38 +351,52 @@ def load_model(self, index=-1, best_if_exist=False, fn=None, with_meta=False, ma
return None
def best_checkpoint(self):
+ """
+ Returns the filename of the best checkpoint file if it exists, otherwise None.
+
+ Returns:
+ Union[None, str]: The filename of the best checkpoint file if it exists, otherwise None.
+
+ """
fn = os.path.join(self.save_dir, 'best.checkpoint.pt')
if os.path.exists(fn):
return fn
return None
def best_model(self):
+ """Returns the filename of the best model file if it exists, otherwise None."""
fn = os.path.join(self.save_dir, 'best.model.pt')
if os.path.exists(fn):
return fn
return None
def _is_pkl(self, x: str, start, end):
+ """Helper function that returns True if the string x starts with start and ends with end."""
return x.startswith(start) and x.endswith(end)
def list_checkpoints(self) -> List[str]:
+ """Returns a sorted list of filenames of all checkpoint files in the save directory."""
return sorted(list(filter(lambda x: self._is_pkl(x, 'checkpoints', 'pt'),
os.listdir(self.save_dir))),
key=lambda x: os.stat(os.path.join(self.save_dir, x)).st_ctime)
def list_keypoints(self) -> List[str]:
+ """Returns a sorted list of filenames of all keypoint files in the save directory."""
return sorted(list(filter(lambda x: self._is_pkl(x, 'key', 'pt'),
os.listdir(self.save_dir))),
key=lambda x: os.stat(os.path.join(self.save_dir, x)).st_ctime)
def list_models(self) -> List[str]:
+ """Returns a sorted list of filenames of all model files in the save directory."""
return sorted(list(filter(lambda x: self._is_pkl(x, 'model', 'pt'),
os.listdir(self.save_dir))),
key=lambda x: os.stat(os.path.join(self.save_dir, x)).st_ctime)
def list(self):
+ """Returns a sorted list of all filenames in the save directory."""
return sorted(os.listdir(self.save_dir))
def summary(self):
+ """Prints a summary of the contents of the save directory."""
print(f"saved in : {self.save_dir}")
print(textwrap.indent('\n'.join(self.list()), ' - '))
diff --git a/src/lumo/trainer/trainer.py b/src/lumo/trainer/trainer.py
index c465ad1c..27338f92 100644
--- a/src/lumo/trainer/trainer.py
+++ b/src/lumo/trainer/trainer.py
@@ -13,7 +13,7 @@
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
-
+import json
from lumo.contrib.accelerate import Accelerator
from lumo.contrib.accelerate.utils import send_to_device
from lumo.core import TrainStage, Record, MetricType, Meter
@@ -24,7 +24,6 @@
from lumo.proc import glob
from lumo.trainer.rnd import RndManager
from lumo.utils.logger import Logger
-from lumo.utils.fmt import strftime
from .base import _BaseTrainer
from .components import TrainerExperiment, TrainerParams
from .saver import Saver
@@ -32,6 +31,8 @@
# overwrite send_to_device to resolve https://github.com/pytorch/pytorch/issues/83015
# from accelerate import Accelerator
# from accelerate.utils import send_to_device
+from ..utils.fmt import strftime
+
ParamsType = TrainerParams
@@ -65,6 +66,12 @@ class Trainer(_BaseTrainer):
'process_loader', 'regist_dataloader'
}
+ def __init_subclass__(cls, **kwargs):
+ super().__init_subclass__(**kwargs)
+ if cls.__name__ == 'Trainer':
+ raise TypeError(
+ f"Can't instantiate abstract class {cls.__name__} directly, please create a subclass of it.")
+
def __init__(self, params: ParamsType, dm: DataModule = None):
if dm is None:
dm = DataModule(params)
@@ -80,10 +87,12 @@ def __init__(self, params: ParamsType, dm: DataModule = None):
self.params.iparams()
self.exp = TrainerExperiment(self.generate_exp_name())
- self.database = TableRow(self.exp.project_name, self.exp.exp_name, self.exp.test_name_with_dist)
- self.metric_board = Metrics(self.exp.test_root)
+ self._database = TableRow(self.exp.mk_ipath('metric.pkl'), persistent=self.is_main)
+ self.metric_board = Metrics(self.exp.mk_bpath('board.sqlite'), persistent=self.is_main)
+ self.metric = self.exp.metric
+
self.exp.dump_info('metric_board', self.metric_board.fpath)
- self.exp.dump_info('table_row', self.database.fpath)
+ self.exp.dump_info('table_row', self._database.fpath)
self.rnd = RndManager()
self.train_epoch_toggle = False
@@ -100,12 +109,13 @@ def __init__(self, params: ParamsType, dm: DataModule = None):
if dist.is_main():
self.params.to_yaml(self.exp.params_fn)
+ self.exp.dump_info('params', self.params.to_dict())
self.set_global_steps(0)
self.set_epoch_idx(0)
self.set_idx(0)
if params.get('debug', False):
- self.exp.set_prop('debug', True)
+ self.exp.dump_info('debug', True)
@property
def metrics(self):
@@ -113,7 +123,12 @@ def metrics(self):
@property
def db(self):
- return self.database
+ return self._database
+
+ @property
+ def database(self):
+ warnings.warn('TableRow is deprecated and will be removed soon, please use self.metric instead')
+ return self._database
@property
def saver(self) -> Saver:
@@ -221,6 +236,7 @@ def writer(self):
res = SummaryWriter(**kwargs)
def close(*args):
+ """close writer"""
res.flush()
res.close()
@@ -257,73 +273,167 @@ def eidx(self):
@property
def global_steps(self) -> int:
- # started from 0
+ """started from 0"""
return self._prop['global_steps']
@property
- def trainer_state(self):
+ def trainer_state(self) -> Any:
+ """
+ Get the state of the Trainer object.
+
+ Returns:
+ Any: The state of the Trainer object.
+ """
return self._prop
@property
def devices(self) -> Dict[str, torch.device]:
- return self._state_dicts['devices']
+ """
+ Get the dictionary of devices used in the training session.
+
+ Returns:
+ Dict[str, torch.device]: A dictionary containing the devices used in the training session.
+ """
+ return {key: self[key] for key in self._state_dicts['devices']}
@property
def model_dict(self) -> Dict[str, nn.Module]:
- return {key: self[key]
- for key in self._state_dicts['models']}
+ """
+ Get the dictionary of model objects used in the training session.
+
+ Returns:
+ Dict[str, nn.Module]: A dictionary containing the model objects used in the training session.
+ """
+ return {key: self[key] for key in self._state_dicts['models']}
@property
def optim_dict(self) -> Dict[str, Optimizer]:
+ """
+ Get the dictionary of optimizer objects used in the training session.
+
+ Returns:
+ Dict[str, Optimizer]: A dictionary containing the optimizer objects used in the training session.
+ """
return {key: self[key] for key in self._state_dicts['optims']}
@property
def torch_tensor(self) -> Dict[str, torch.Tensor]:
+ """
+ Get the dictionary of PyTorch tensor objects used in the training session.
+
+ Returns:
+ Dict[str, torch.Tensor]: A dictionary containing the PyTorch tensor objects used in the training session.
+ """
return {key: self[key] for key in self._state_dicts['tensor.th']}
@property
def numpy_tensor(self) -> Dict[str, np.ndarray]:
+ """
+ Get the dictionary of NumPy array objects used in the training session.
+
+ Returns:
+ Dict[str, np.ndarray]: A dictionary containing the NumPy array objects used in the training session.
+ """
return {key: self[key] for key in self._state_dicts['tensor.np']}
@property
def others(self) -> Dict[str, Any]:
+ """
+ A dictionary of additional attributes stored in the Trainer.
+
+ Returns:
+ Dict[str, Any]: The dictionary of additional attributes.
+ """
return {key: self[key] for key in self._state_dicts['others']}
@property
def datamodule(self) -> DataModule:
+ """
+ Returns the DataModule associated with this Trainer.
+
+ Returns:
+ DataModule: The DataModule associated with this Trainer.
+ """
return self.dm
@property
def train_dataloader(self) -> Optional[DataLoaderType]:
+ """
+ Returns the DataLoader for the training data.
+
+ Returns:
+ Optional[DataLoaderType]: The DataLoader for the training data, or None if it is not available.
+ """
return self.datamodule['train']
@property
def test_dataloader(self) -> Optional[DataLoaderType]:
+ """
+ Returns the DataLoader for the test data.
+
+ Returns:
+ Optional[DataLoaderType]: The DataLoader for the test data, or None if it is not available.
+ """
return self.datamodule['test']
@property
def val_dataloader(self) -> Optional[DataLoaderType]:
+ """
+ Returns the DataLoader for the validation data.
+
+ Returns:
+ Optional[DataLoaderType]: The DataLoader for the validation data, or None if it is not available.
+ """
return self.datamodule['val']
@property
def device(self):
+ """
+ Returns the device used for training.
+
+ Returns:
+ The device used for training.
+ """
return self.accelerate.device
- def _load_fun_state_dict(self, src: dict, tgt: dict):
- for k, v in tgt.items():
- if k in src:
- v.load_state_dict(src[k])
+ def _load_fun_state_dict(self, src: dict):
+ """
+ Loads state dicts into the Trainer's attributes.
+
+ Args:
+ src (dict): A dictionary of state dicts to be loaded.
+ """
+ for k, v in src.items():
+ if self._rev_index.get(k, None) is not None:
+ self[k].load_state_dict(v)
def regist_dataloader(self, dataloader: DataLoader, stage: TrainStage):
+ """
+ Registers a dataloader with a given training stage to the current datamodule.
+
+ Args:
+ dataloader (DataLoader): The dataloader to be registered.
+ stage (TrainStage): The training stage to which the dataloader will be associated.
+
+ Returns:
+ None
+ """
self.datamodule.regist_dataloader_with_stage(stage, dataloader)
def process_loader(self, dm: Union[DataModule, DataLoader] = None, stage: TrainStage = TrainStage.train):
"""
- automatically called before train()/test()/evaluate(), see __new__ function of Trainer
- :param dm:
- :param stage:
- :return:
+ Prepares and registers a dataloader with the given training stage to the current datamodule.
+
+ Args:
+ dm (Union[DataModule, DataLoader], optional): The datamodule or dataloader to be processed. If not provided,
+ the current datamodule will be used if it exists.
+ stage (TrainStage, optional): The training stage to which the dataloader will be associated. Defaults to TrainStage.train.
+
+ Returns:
+ DataLoader: The prepared and registered dataloader.
+ None: If the dataloader cannot be prepared or registered.
"""
+
assert stage is not None, '`stage` cannot be None'
if dm is None and self.dm is not None:
dm = self.dm
@@ -347,12 +457,24 @@ def process_loader(self, dm: Union[DataModule, DataLoader] = None, stage: TrainS
return loader
def save_state_dict(self, name='latest.pth', dirpath=None, only_main=True):
+ """
+ Saves the current state dictionary to a file.
+
+ Args:
+ name: The name of the file to save the state dictionary to. Defaults to 'latest.pth'.
+ dirpath: The directory path to save the state dictionary file to. If None, defaults to the state dictionary
+ directory of the Trainer's experiment.
+ only_main: If True, saves the state dictionary to a single file. If False and the Trainer is distributed,
+ saves the state dictionary to multiple files, one for each process.
+
+ Returns:
+ The path to the saved state dictionary file.
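+
+        Examples:
+            # Illustrative call from inside a Trainer subclass:
+            >>> fn = trainer.save_state_dict('epoch-10.pth')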
+ """
if not only_main and self.is_dist:
pre, ext = os.path.splitext(name)
name = f'{pre}-{self.local_rank}{ext}'
if dirpath is None:
-
- fn = self.exp.state_dict_dir
+ fn = os.path.join(self.exp.state_dict_dir, name)
else:
fn = os.path.join(dirpath, name)
torch.save(self.state_dict(), fn)
@@ -360,19 +482,27 @@ def save_state_dict(self, name='latest.pth', dirpath=None, only_main=True):
return fn
def load_state_dict(self, state_dict: dict):
+ """Load state dictionary from a given dictionary.
+ Args:
+ state_dict (dict): A dictionary containing the state dictionary to be loaded.
+
+ Returns:
+ None
+ """
_sub = {'models', 'optims', 'other'}
_missing = []
for k, v in state_dict.items():
if k in _sub:
- self._load_fun_state_dict(v, self._state_dicts[k])
+ self._load_fun_state_dict(v)
else:
- self._state_dicts[k] = v
+ for kk, vv in v.items():
+ self[kk] = vv
return
def to_device(self, item: Optional[Union[nn.Module, torch.Tensor, Sequence, Mapping]] = None,
device: torch.device = None):
-
+ """Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device."""
if item is None:
for k, v in list(self.model_dict.items()):
self[k] = self.accelerate.prepare(v)
@@ -385,39 +515,35 @@ def to_device(self, item: Optional[Union[nn.Module, torch.Tensor, Sequence, Mapp
return item
def on_trainer_exception(self, func: Callable, exception: BaseException):
- self.database.update_dict(dict(end=datetime.now(),
- finished=False,
- error=str(exception),
- trainer_frame=str(func)),
- flush=True)
+ """Updates database with error information when an exception occurs during training."""
+ self.exp.dump_info('exception', dict(end=strftime(),
+ finished=False,
+ error=str(exception),
+ trainer_frame=str(func)))
@property
def is_initialized(self):
+ """Whether this Trainer is initialized."""
if self._prop.get('initial', False):
return True
return False
def initialize(self):
+ """
+        Initializes the Trainer object and updates meta information in the Experiment and TableRow.
+
+        If the Trainer object is already initialized, this method does nothing.
+
+        This function is called automatically when train()/test()/evaluate() starts.
+ """
+
if self.is_initialized:
return
self.exp.start()
- commit_info = self.exp.get_prop('git')
- commit_hex = None
- if commit_info is not None and 'commit' in commit_info:
- commit_hex = commit_info['commit']
- self.database.update('commit_hex', commit_hex)
- self.database.update_dict(dict(
- test_name=self.exp.test_name,
- exp_name=self.exp.exp_name,
- project=self.exp.project_name,
- path=self.exp.test_root,
- start=datetime.now()))
- self.database.set_params(self.params.to_dict())
- self.database.update('command', ' '.join(sys.argv))
params_hash = self.params.hash()
- self.database.update('params_hash', params_hash)
self.exp.dump_string('params_hash', params_hash)
+
self.icallbacks(self.params)
self.set_property('initial.callbacks', True)
self.imodels(self.params)
@@ -425,10 +551,12 @@ def initialize(self):
self.set_property('initial', True)
def stop_train(self):
+ """Toggle to stop train."""
self.train_toggle = True
self.train_epoch_toggle = True
def stop_train_epoch(self):
+ """Toggle to skip current train epoch."""
self.train_epoch_toggle = True
def prepare_dataloader(self, loader: DataLoaderType, stage: TrainStage = None):
@@ -456,6 +584,21 @@ def prepare_dataloader(self, loader: DataLoaderType, stage: TrainStage = None):
return loader
def train(self, dm: Union[DataModule, DataLoaderType] = None, params: ParamsType = None, limit_global_steps=None):
+ """Trains the model using the specified data loader and parameters.
+
+ Args:
+ dm (Union[DataModule, DataLoaderType], optional): The data loader or data module to use for training.
+ Defaults to self.train_dataloader.
+ params (ParamsType, optional): The training parameters to use. Defaults to None.
+ limit_global_steps (int, optional): The maximum number of global steps to train for. Defaults to None.
+
+ Returns:
+ Dict[str, Any]: A dictionary of training results.
+
+ Raises:
+ ValueError: If no data loader is available for training.
+
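+        Examples:
+            # Illustrative sketch; `MyTrainer`, `params` and `dm` are assumed
+            # to be defined by the user.
+            >>> trainer = MyTrainer(params, dm)
+            >>> trainer.train()
+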
+ """
loader = self.select_loader(dm)
if not loader:
loader = self.train_dataloader
@@ -490,13 +633,23 @@ def train(self, dm: Union[DataModule, DataLoaderType] = None, params: ParamsType
# update when train finished
self.exp.end()
- self.database.update_dict(dict(end=datetime.now(), finished=True), flush=True)
- self.database.flush()
return self._prop
def train_epoch(self, loader: DataLoaderType, params: ParamsType = None,
limit_step=None,
limit_global_steps=None) -> Record:
+ """Trains the model for one epoch using the specified data loader and parameters.
+
+ Args:
+ loader (DataLoaderType): The data loader to use for training.
+ params (ParamsType, optional): The training parameters to use. Defaults to None.
+ limit_step (int, optional): The maximum number of steps to train for. Defaults to None.
+ limit_global_steps (int, optional): The maximum number of global steps to train for. Defaults to None.
+
+ Returns:
+ Record: A record of training results for the epoch.
+
+ """
stage = TrainStage.train
self.change_stage(stage)
record = self.create_record(stage=stage)
@@ -520,22 +673,57 @@ def train_epoch(self, loader: DataLoaderType, params: ParamsType = None,
self._prop['global_steps'] += 1
metric = self.train_step(batch, params)
record.record(metric)
- self.database.flush()
record.flush()
- self.database.update_dict(dict(eidx=self.eidx, end=datetime.now()))
return record
- def set_property(self, key, value):
+    def set_property(self, key: str, value: Any) -> None:
+ """
+ Sets a property with the given key to the given value.
+
+ Args:
+ key: A string representing the name of the property.
+ value: The value to assign to the property.
+
+ Returns:
+ None
+ """
self._prop[key] = value
- def set_global_steps(self, val):
+ def set_global_steps(self, val: int) -> None:
+ """
+ Sets the global step count to the given value.
+
+ Args:
+ val: An integer representing the global step count.
+
+ Returns:
+ None
+ """
self.set_property('global_steps', val)
- def set_epoch_idx(self, val):
+ def set_epoch_idx(self, val: int) -> None:
+ """
+ Sets the current epoch index to the given value.
+
+ Args:
+ val: An integer representing the current epoch index.
+
+ Returns:
+ None
+ """
self.set_property('eidx', val)
- def set_idx(self, val):
+ def set_idx(self, val: int) -> None:
+ """
+ Sets the current index to the given value.
+
+ Args:
+ val: An integer representing the current index.
+
+ Returns:
+ None
+ """
self.set_property('idx', val)
@property
@@ -543,13 +731,23 @@ def trainstage(self) -> TrainStage:
return self._prop.get('stage', TrainStage.default)
def set_stage(self, val: TrainStage):
+ """
+ Sets the training stage to the given value.
+
+ Args:
+ val (TrainStage): The value to set the training stage to.
+ """
self.set_property('stage', val)
def add_callback(self, callback):
"""
- 添加一个回调函数,注意,不能添加重复的 callback,这不推荐,也没有必要。
- :param callback:
- :return:
+ Adds a callback function. Duplicate callbacks cannot be added; they are neither recommended nor necessary.
+
+ Args:
+ callback: The callback function to add.
+
+ Returns:
+ bool: True if the callback was added successfully, False otherwise.
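+
+ Example (a sketch; assumes `MyCallback` is a user-defined callback instance):
+
+ trainer.add_callback(MyCallback())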
"""
msg = None
cb_name = callback.__class__.__name__
@@ -579,10 +777,22 @@ def add_callback(self, callback):
return True
def remove_callback(self, cur):
+ """
+ Removes the given callback from the list of callbacks.
+
+ Args:
+ cur: The callback to remove.
+ """
self.callbacks.remove(cur)
pass
def change_stage(self, stage: TrainStage):
+ """
+ Changes the training stage to the given value.
+
+ Args:
+ stage (TrainStage): The value to change the training stage to.
+ """
if self.trainstage == stage:
return
@@ -597,6 +807,15 @@ def change_stage(self, stage: TrainStage):
@classmethod
def select_loader(cls, dm=None):
+ """
+ Selects the appropriate loader based on the given data module.
+
+ Args:
+ dm (DataModule or DataLoader or DataLoaderSide, optional): The data module to use. Defaults to None.
+
+ Returns:
+ DataLoader or None: The appropriate loader based on the given data module, or None if dm is None.
+ """
loader = None
if dm:
if isinstance(dm, DataModule):
@@ -608,6 +827,16 @@ def select_loader(cls, dm=None):
return loader
def test(self, dm: Union[DataModule, DataLoader] = None, params: ParamsType = None, limit_step=None):
+ """
+ Tests the model on a given dataset and returns a `Record` object containing the evaluation results.
+
+ Args:
+ dm (Union[DataModule, DataLoader], optional): A `DataModule` or `DataLoader` object for the dataset to test on.
+ params (ParamsType, optional): A dictionary containing hyperparameters for the test.
+ limit_step (int, optional): An integer specifying the maximum number of batches to test on.
+ Returns:
+ A `Record` object containing the evaluation results.
+ """
stage = TrainStage.test
self.change_stage(stage)
@@ -635,6 +864,15 @@ def test(self, dm: Union[DataModule, DataLoader] = None, params: ParamsType = No
return record
def evaluate(self, dm: Union[DataModule, DataLoader] = None, params: ParamsType = None, limit_step: int = None):
+ """
+ Evaluates the model on a given dataset and returns a `Record` object containing the evaluation results.
+ Args:
+ dm (Union[DataModule, DataLoader], optional): A `DataModule` or `DataLoader` object for the dataset to evaluate on.
+ params (ParamsType, optional): A dictionary containing hyperparameters for the evaluation.
+ limit_step (int, optional): An integer specifying the maximum number of batches to evaluate on.
+ Returns:
+ A `Record` object containing the evaluation results.
+ """
stage = TrainStage.val
self.change_stage(stage)
@@ -660,59 +898,108 @@ def evaluate(self, dm: Union[DataModule, DataLoader] = None, params: ParamsType
return record
def train_step(self, batch, params: ParamsType = None) -> MetricType:
+ """
+ Runs a single training step on a batch of data and returns a dictionary of training metrics.
+ Args:
+ batch: A batch of data to train on.
+ params (ParamsType, optional): A dictionary containing hyperparameters for the training step.
+ Returns:
+ A dictionary of training metrics.
+ """
pass
def test_step(self, batch, params: ParamsType = None) -> MetricType:
+ """
+ Runs a single testing step on a batch of data and returns a dictionary of evaluation metrics.
+ Args:
+ batch: A batch of data to test on.
+ params (ParamsType, optional): A dictionary containing hyperparameters for the testing step.
+ Returns:
+ A dictionary of evaluation metrics.
+ """
pass
def evaluate_step(self, batch, params: ParamsType = None) -> MetricType:
+ """
+ Runs a single evaluation step on a batch of data and returns a dictionary of evaluation metrics.
+ Args:
+ batch: A batch of data to evaluate on.
+ params (ParamsType, optional): A dictionary containing hyperparameters for the evaluation step.
+ Returns:
+ A dictionary of evaluation metrics.
+ """
pass
def imodels(self, params: ParamsType):
+ """Initialize model in here"""
pass
def icallbacks(self, params: ParamsType):
+ """Initialize callbacks in here"""
pass
def inference(self, batch):
+ """Perform inference on a batch of data."""
raise NotImplementedError()
def predict(self, batch):
+ """Make a prediction on a batch of data."""
raise NotImplementedError()
def optim_state_dict(self, wrap=True):
+ """Get a dictionary of the state of the optimizers."""
res = {k: v.state_dict() for k, v in self.optim_dict.items()}
if wrap:
res = {'optim': res}
return res
def model_state_dict(self, wrap=True):
+ """Get a dictionary of the state of the models."""
res = {k: self.accelerate.unwrap_model(v).state_dict() for k, v in self.model_dict.items()}
if wrap:
res = {'model': res}
return res
def other_state_dict(self, wrap=True):
+ """Get a dictionary of the state of the other objects."""
res = {k: v.state_dict() for k, v in self.others.items()}
if wrap:
res = {'other': res}
return res
def state_dict(self):
+ """Get a dictionary of the state of the object."""
res = {
'optims': self.optim_state_dict(wrap=False),
'models': self.model_state_dict(wrap=False),
'others': self.other_state_dict(wrap=False),
'thtensor': self.torch_tensor,
'nptensor': self.numpy_tensor,
+ # 'devices': self.devices,
}
return res
def Meter(self):
+ """
+ Returns a new instance of the Meter class.
+
+ Returns:
+ Meter: A new instance of the Meter class.
+ """
return Meter()
def create_record(self, stage: TrainStage = None):
+ """
+ Creates a new Record object with the specified TrainStage.
+
+ Args:
+ stage (TrainStage, optional): The TrainStage to use for the new Record object. If not provided, the TrainStage
+ from the Trainer object will be used.
+
+ Returns:
+ Record: A new Record object with the specified TrainStage.
+ """
if stage is None:
stage = self.trainstage
record = Record(stage=stage)
@@ -720,35 +1007,33 @@ def create_record(self, stage: TrainStage = None):
def wait_for_everyone(self):
"""
- making sure all processes have reached this point before continuing.
+ Blocks execution of the current process until every other process has reached this point.
"""
self.accelerate.wait_for_everyone()
- def save_model(self, is_best=False, meta_info: Union[str, dict] = None):
- info = self._build_trainer_meta_info(meta_info)
- val = self.saver.save_model(self.eidx, self.model_state_dict(),
- meta_info=info,
- is_best=is_best)
+ def save_best_model(self):
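+ """Save the full trainer state as the best-model checkpoint; each non-main rank writes a rank-suffixed file."""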
+ if self.is_main:
+ file = self.exp.mk_bpath('models', 'best_model.ckpt')
+ file_info = self.exp.mk_bpath('models', 'best_model.json')
+ else:
+ file = self.exp.mk_bpath('models', f'best_model-{self.local_rank}.ckpt')
+ file_info = self.exp.mk_bpath('models', f'best_model-{self.local_rank}.json')
+ torch.save(self.state_dict(), file)
+
+ with open(file_info, 'w') as w:
+ w.write(json.dumps({'global_steps': self.global_steps, 'metric': self.exp.metric.value}))
+ self.logger.info(f'saved best model at {file}')
self.wait_for_everyone()
- return val
-
- def _build_trainer_meta_info(self, meta_info: Union[str, dict] = None):
- info = dict()
- info['eidx'] = self.eidx
- if meta_info is not None:
- if isinstance(meta_info, str):
- info['msg'] = meta_info
- if isinstance(meta_info, Meter):
- meta_info = meta_info.serialize()
- if isinstance(meta_info, dict):
- info.update(meta_info)
- return info
-
- def save_checkpoint(self, max_keep=10, is_best=False, meta_info: Union[str, dict, Meter] = None):
- info = self._build_trainer_meta_info(meta_info)
- val = self.saver.save_checkpoint(self.eidx, self.state_dict(),
- meta_info=info,
- max_keep=max_keep,
- is_best=is_best)
+
+ def save_last_model(self):
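+ """Save the full trainer state as the last-model checkpoint; each non-main rank writes a rank-suffixed file."""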
+ if self.is_main:
+ file = self.exp.mk_bpath('models', 'last_model.ckpt')
+ file_info = self.exp.mk_bpath('models', 'last_model.json')
+ else:
+ file = self.exp.mk_bpath('models', f'last_model-{self.local_rank}.ckpt')
+ file_info = self.exp.mk_bpath('models', f'last_model-{self.local_rank}.json')
+ torch.save(self.state_dict(), file)
+ with open(file_info, 'w') as w:
+ w.write(json.dumps({'global_steps': self.global_steps, 'metric': self.exp.metric.value}))
+ self.logger.info(f'saved last model at {file}')
self.wait_for_everyone()
- return val
diff --git a/src/lumo/utils/ast.py b/src/lumo/utils/ast.py
index 9798ee81..b23c4b09 100644
--- a/src/lumo/utils/ast.py
+++ b/src/lumo/utils/ast.py
@@ -3,6 +3,17 @@
def analyse_module_dependency(module, mem=None, root=None):
+ """
+ Recursively analyse the dependencies of a module and return a dictionary of modules and their associated files.
+
+ Args:
+ module (module): The module to analyse.
+ mem (dict, optional): A dictionary of previously analysed modules and their associated files.
+ root (str, optional): The root directory to use as a reference when determining whether a file is a dependency.
+
+ Returns:
+ dict: A dictionary of modules and their associated files.
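+
+ Example (a sketch):
+
+ import lumo
+ deps = analyse_module_dependency(lumo)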
+ """
if mem is None:
mem = {}
diff --git a/src/lumo/utils/exithook.py b/src/lumo/utils/exithook.py
index 4d18702e..4f6f29c6 100644
--- a/src/lumo/utils/exithook.py
+++ b/src/lumo/utils/exithook.py
@@ -1,20 +1,24 @@
"""
-Do something you want before the programe exit.
+This module provides functions to wrap or replace the default exception handler of Python's sys module.
"""
import sys
from functools import wraps
def replace(func):
+ """Replace the default exception handler with the provided function."""
sys.excepthook = func
def wrap_after(func):
+ """Wrap the default exception handler, executing the provided function after it."""
old = sys.excepthook
def outer(fun):
+ """wrap function"""
@wraps(fun)
def inner(*args, **kwargs):
+ """wrap function"""
old(*args, **kwargs)
fun(*args, **kwargs)
@@ -24,11 +28,14 @@ def inner(*args, **kwargs):
def wrap_before(func):
+ """Wrap the default exception handler, executing the provided function before it."""
old = sys.excepthook
def outer(fun):
+ """wrap function"""
@wraps(fun)
def inner(*args, **kwargs):
+ """wrap function"""
fun(*args, **kwargs)
old(*args, **kwargs)
diff --git a/src/lumo/utils/filelock.py b/src/lumo/utils/filelock.py
index 8f752027..6c20b23b 100644
--- a/src/lumo/utils/filelock.py
+++ b/src/lumo/utils/filelock.py
@@ -1,58 +1,35 @@
-import time
-
+from filelock import Timeout, FileLock
import os
-import random
-
-from lumo.utils.exithook import wrap_before
class Lock:
- def __init__(self, name, sleep=1):
- from lumo.proc.path import cache_dir
- self.file = os.path.join(cache_dir(), f"LUMO_LOCK_{name}")
- self.sleep = sleep
- wrap_before(self.clear)
+ """
+ A class for obtaining and releasing file-based locks using FileLock.
- def clear(self, *_, **__):
- self.release()
+ Args:
+ name (str): The name of the lock.
- def abtain(self):
- mulp = 1
- while True:
- mulp += 1
- if mulp > 10:
- raise TimeoutError(f'Can not abtain resource of {self.file}')
-
- while os.path.exists(self.file):
- time.sleep(random.randint(mulp, mulp ** 2))
- mulp += 1
- if mulp > 10:
- raise TimeoutError(f'Can not abtain resource of {self.file}')
-
- while True:
- flag = f'{os.getpid()}'
- with open(self.file, 'w') as w:
- w.write(flag)
-
- if os.path.exists(self.file):
- with open(self.file, 'r') as r:
- lock_flag = r.read()
- if flag == lock_flag:
- return True
- mulp += 1
- time.sleep(random.randint(mulp, mulp ** 2))
- if mulp > 10:
- raise TimeoutError(f'Can not abtain resource of {self.file}')
+ Attributes:
+ fn (str): The file path of the lock file.
+ lock (FileLock): The FileLock object used for obtaining and releasing the lock.
- def release(self):
- try:
- os.remove(self.file)
- except:
- pass
- return True
+ Example:
+ lock = Lock('my_lock')
+ lock.abtain()
+ # critical section
+ lock.release()
+ """
- def __enter__(self):
- self.abtain()
+ def __init__(self, name):
+ """Initialize the lock file path and FileLock object"""
+ from lumo.proc.path import cache_dir
+ self.fn = os.path.join(cache_dir(), f"LUMO_LOCK_{name}")
+ self.lock = FileLock(self.fn)
+
+ def abtain(self):
+ """Acquire the lock"""
+ self.lock.acquire()
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.release()
+ def release(self):
+ """Release the lock"""
+ self.lock.release()
diff --git a/src/lumo/utils/filelock2.py b/src/lumo/utils/filelock2.py
deleted file mode 100644
index f18916cc..00000000
--- a/src/lumo/utils/filelock2.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import time
-
-import os
-import random
-from joblib import hash
-import mmap
-from lumo.utils.exithook import wrap_before
-
-
-class Lock:
- def __init__(self, name, sleep=1, group=None):
- from lumo.proc.path import cache_dir
- if group is None:
- group = ""
-
- self.size = 10000
- self.flag = f'{os.getpid()}'
- self.flagsize = len(self.flag)
- self.pos = int(hash(name), 16) % (self.size - len(self.flag))
- self.file = os.path.join(cache_dir(), f'LUMO_LOCK_V3_{group}')
- if not os.path.exists(self.file):
- with open(self.file, 'w') as w:
- w.write('0' * self.size)
-
- self.sleep = sleep
- self.fhdl = None
- wrap_before(self.clear)
-
- def clear(self, *_, **__):
- self.release()
-
- def abtain(self):
- mulp = 0
- self.fhdl = r = open(self.file, 'r+b')
- mm = mmap.mmap(r.fileno(), 0, access=mmap.ACCESS_WRITE)
- flag = f'{os.getpid()}'
- while True:
- mulp += 1
- if mulp > 10:
- raise TimeoutError(f'Can not abtain resource of {self.file}')
-
- if mm[self.pos:self.pos + len(flag)].decode() != '0' * len(flag):
- time.sleep(random.randint(mulp, mulp ** 2))
- continue
-
- mm[self.pos:self.pos + len(flag)] = flag.encode()
- mm.flush()
-
- if mm[self.pos:self.pos + len(flag)].decode() != flag:
- mm.close()
- time.sleep(random.randint(mulp, mulp ** 2))
- continue
-
- return True
-
- def release(self):
- if self.fhdl is None:
- return
-
- r = self.fhdl
- mm = mmap.mmap(r.fileno(), 0, access=mmap.ACCESS_WRITE)
- flag = b'0' * self.flagsize
- mm[self.pos:self.pos + len(flag)] = flag
- mm.flush()
- r.close()
- self.fhdl = None
- return True
-
- def __enter__(self):
- self.abtain()
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.release()
diff --git a/src/lumo/utils/fmt.py b/src/lumo/utils/fmt.py
index 480c812f..7ba7897c 100644
--- a/src/lumo/utils/fmt.py
+++ b/src/lumo/utils/fmt.py
@@ -2,7 +2,7 @@
Format date/filename, check array shape, convert item from torch.Tensor to ndarray or scalar.
"""
import textwrap
-from datetime import datetime
+from datetime import datetime, timedelta
import numpy as np
import torch
@@ -11,18 +11,21 @@
def to_ndarray(item):
+ """Convert a PyTorch tensor or any other array-like object to a NumPy ndarray."""
if isinstance(item, torch.Tensor):
item = item.detach().cpu()
return np.array(item)
def detach(item):
+ """Detach a PyTorch tensor and convert it to a NumPy ndarray."""
if isinstance(item, torch.Tensor):
item = item.detach().cpu().numpy()
return item
def validate_scalar_shape(ndarray, name=''):
+ """Validate that a given numpy array is a scalar."""
if ndarray.ndim != 0:
raise ValueError(
"Expected scalar value for %r but got %r" % (name, ndarray)
@@ -31,6 +34,7 @@ def validate_scalar_shape(ndarray, name=''):
def is_scalar(ndarray: np.ndarray):
+ """Check whether a numpy array is a scalar."""
return ndarray.size == 1
@@ -41,7 +45,8 @@ def strftime(fmt='%y-%m-%d-%H%M%S', dateobj: datetime = None):
return datetime.now().strftime(fmt)
-def strptime(fmt='%y-%m-%d-%H%M%S', datestr: str = None):
+def strptime(datestr: str = None, fmt='%y-%m-%d-%H%M%S'):
+ """Convert a string to a datetime object using the specified format."""
return datetime.strptime(datestr, fmt)
@@ -52,12 +57,38 @@ def strptime(fmt='%y-%m-%d-%H%M%S', datestr: str = None):
def to_filename(basename):
+ """Replace invalid characters in a basename with an underscore."""
return re.sub(_invalid_fc, '_', basename)
def can_be_filename(basename):
+ """Check whether a basename can be converted to a valid filename."""
return re.search(_invalid_fc, basename) is None
def indent_print(text, indent=' '):
+ """Prints the specified text with a given indentation."""
print(textwrap.indent(text, indent))
+
+
+def format_second(sec: int) -> str:
+ """Formats a duration given in seconds into a human-readable string."""
+ sec, ms = divmod(sec, 1)
+ if sec > 60:
+ min, sec = divmod(sec, 60)
+ if min > 60:
+ hour, min = divmod(min, 60)
+ if hour > 24:
+ day, hour = divmod(hour, 24)
+ fmt = "{}d{}h{}m{}s".format(int(day), int(hour), int(min), int(sec))
+ else:
+ fmt = "{}h{}m{}s".format(int(hour), int(min), int(sec))
+ else:
+ fmt = "{}m{}s".format(int(min), int(sec))
+ else:
+ fmt = "{}s".format(int(sec))
+ return fmt
+
+
+def format_timedelta(td: timedelta):
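+ """Formats a timedelta into a human-readable string, e.g. format_timedelta(timedelta(seconds=3661)) -> '1h1m1s'."""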
+ return format_second(td.total_seconds())
diff --git a/src/lumo/utils/hash.py b/src/lumo/utils/hash.py
deleted file mode 100644
index 75f173cc..00000000
--- a/src/lumo/utils/hash.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from hashlib import md5
-
-
-def hash_iter(*object: str):
- hasher = md5()
- for i in object:
- hasher.update(i.encode())
- return hasher.hexdigest()
diff --git a/src/lumo/utils/logger.py b/src/lumo/utils/logger.py
index f590c353..04c3e4f2 100644
--- a/src/lumo/utils/logger.py
+++ b/src/lumo/utils/logger.py
@@ -29,6 +29,12 @@
def _get_print_func():
+ """
+ Returns the `rich.print` function if available, or the built-in `print` function if not.
+
+ Returns:
+ callable: The `rich.print` function if available, or the built-in `print` function if not.
+ """
try:
from rich import print
except ImportError:
@@ -37,6 +43,12 @@ def _get_print_func():
def get_global_logger():
+ """
+ Returns the global logger object, creating it if it does not exist.
+
+ Returns:
+ logging.Logger: The global logger object.
+ """
global logger
if logger is None:
logger = Logger()
@@ -44,18 +56,39 @@ def get_global_logger():
def set_global_logger(logger_):
+ """
+ Sets the global logger object to the specified logger instance.
+
+ Args:
+ logger_ (logging.Logger, optional): The logger instance to set as the global logger. If `None`,
+ a new logger instance will be created. Defaults to `None`.
+
+ Returns:
+ logging.Logger: The global logger object.
+ """
global logger
logger = logger_
return logger
def process_str():
+ """
+ Returns a string representing the current process, suitable for use in log messages or other output.
+
+ Returns:
+ str: A string representing the current process. If running in a distributed context (as determined
+ by `is_dist()`), the string will be of the form '[local_rank]', where 'local_rank' is the local
+ rank of the current process. Otherwise, an empty string will be returned.
+ """
if is_dist():
return f'[{local_rank()}]'
return ''
class Logger:
+ """A logger adapted for deep learning training experiments that can print output without a newline and
+ adds process index in distributed training.
+ """
VERBOSE = 20
VVV_DEBUG = VVV_DEBUG
VV_DEBUG = VV_DEBUG
@@ -230,6 +263,24 @@ def print(self, *args, end='\n', file=sys.stdout):
self._print_func(*args, end=end, flush=True, file=file)
def print_rich(self, *args, end='\n', file=sys.stdout):
+ """
+ Prints the specified arguments to the given file object, using rich formatting if available.
+
+ Args:
+ *args: The arguments to be printed.
+ end: The string to append at the end of the printed output (default: '\n').
+ file: The file object to which the output should be directed (default: sys.stdout).
+
+ Returns:
+ None
+
+ Notes:
+ - If self.use_stdout is True and rich formatting is available, the output will be formatted using rich.
+ - If self.use_stdout is True and rich formatting is not available, the output will be printed using the default print function.
+ """
if self.use_stdout:
if self._try_rich:
print = _get_print_func()
diff --git a/src/lumo/utils/memory_grab.py b/src/lumo/utils/memory_grab.py
index 8457229e..a49c31d7 100644
--- a/src/lumo/utils/memory_grab.py
+++ b/src/lumo/utils/memory_grab.py
@@ -16,10 +16,22 @@
class DeviceMem:
+ """
+ A class that represents device memory usage.
+ """
def __init__(self):
self.line_mem = tree()
def _parse_device_pid_mem_pair(self, lines):
+ """
+ Parses device ID, process ID, and memory usage (in MiB) from the given list of strings.
+
+ Args:
+ lines (List[str]): List of strings to parse.
+
+ Yields:
+ Tuple[int, int, int]: Tuple containing device ID, process ID, and memory usage (in MiB) for each match found.
+ """
for lid, line in enumerate(lines):
res = re.search(match_mem, line)
if res is not None:
@@ -28,12 +40,18 @@ def _parse_device_pid_mem_pair(self, lines):
yield _device, _pid, _mib
def try_parse(self, lines, pid, device):
- """ try parse mem from cached lid directly.
- Returns:
- -1 means failed.
- others means successd and its memory.
+ """
+ Attempts to parse memory usage (in MiB) for a process running on a specific device using the cached line ID.
+ Args:
+ lines (List[str]): List of strings to parse.
+ pid (int): Process ID to look for.
+ device (int or str or torch.device): Device ID to look for.
+
+ Returns:
+ int: Memory usage in MiB for the specified process and device if found, -1 otherwise.
"""
+
lid = self.line_mem[device][pid]
if isinstance(lid, dict):
return -1
@@ -51,6 +69,17 @@ def try_parse(self, lines, pid, device):
return -1
def re_parse(self, lines, pid, device):
+ """
+ Parses memory usage (in MiB) for a process running on a specific device by searching through the list of strings.
+
+ Args:
+ lines (List[str]): List of strings to parse.
+ pid (int): Process ID to look for.
+ device (int or str or torch.device): Device ID to look for.
+
+ Returns:
+ int: Memory usage in MiB for the specified process and device if found, 0 otherwise.
+ """
_res = self.try_parse(lines, pid, device)
if _res != -1:
return _res
@@ -62,11 +91,27 @@ def re_parse(self, lines, pid, device):
return 0
def _get_nvidia_smi(self):
+ """
+ Executes the 'nvidia-smi' command and returns the output as a list of strings.
+
+ Returns:
+ List[str]: List of strings representing the output of the 'nvidia-smi' command.
+ """
proc = subprocess.Popen(['nvidia-smi'], stdout=subprocess.PIPE)
lines = proc.stdout.readlines()
return [i.decode() for i in lines]
def _device_equal(self, da, db):
+ """
+ Compares two device IDs or names for equality.
+
+ Args:
+ da (int or str or torch.device): First device ID or name to compare.
+ db (int or str or torch.device): Second device ID or name to compare.
+
+ Returns:
+ bool: True if the two device IDs or names are equal, False otherwise.
+ """
if isinstance(da, (int, str)):
da = torch.device(da)
if isinstance(db, (int, str)):
@@ -74,7 +119,15 @@ def _device_equal(self, da, db):
return da == db
def get_device_release_mem(self, device):
- """ get device memory left."""
+ """
+ Returns the amount of free memory (in MiB) on a specified device.
+
+ Args:
+ device (int or str or torch.device): Device ID or name to look up.
+
+ Returns:
+ int: Amount of free memory (in MiB) on the specified device.
+ """
s_pid = os.getpid()
total = self.get_device_mem(device)
for _device, _pid, _mib in self._parse_device_pid_mem_pair(self._get_nvidia_smi()):
@@ -84,15 +137,27 @@ def get_device_release_mem(self, device):
return total
def get_device_mem(self, device):
- """ returns device total memory(unit: MB) """
+ """
+ Returns the total amount of memory (in MiB) on a specified device.
+
+ Args:
+ device (int or str or torch.device): Device ID or name to look up.
+
+ Returns:
+ int: Total amount of memory (in MiB) on the specified device.
+ """
return torch.cuda.get_device_properties(device).total_memory // (1024 * 1024)
def get_pid_device_mem(self, pid, device):
"""
- 尽可能有效率的得到进程在某设备下占用的显存(通过命令行程序调用获取)
- :param pid:
- :param device:
- :return:
+ Attempts to obtain the memory usage (in MiB) for a specific process running on a specific device.
+
+ Args:
+ pid (int): Process ID to look up.
+ device (int or str or torch.device): Device ID or name to look up.
+
+ Returns:
+ int: Memory usage in MiB for the specified process and device if found, -1 otherwise.
"""
if isinstance(device, torch.device):
device = device.index
@@ -115,17 +180,20 @@ def get_pid_device_mem(self, pid, device):
class memory(object):
- r"""
- 优雅的抢卡
+ """
+ A graceful memory allocator that optimizes GPU memory allocation by incrementally increasing the memory
+ footprint to minimize fragmentation.
+
Args:
- memory: 需要占用的内存,以 MB 为单位
- device: 需要占用内存的设备
- hold:
- unit:
- Example::
+ memory (int): Memory size to be allocated in MB.
+ device (str or int, optional): Device to allocate memory on. Defaults to the current CUDA device.
+ hold (bool, optional): Whether to hold the memory after allocation. Defaults to False.
+ invade (bool, optional): Whether to use aggressive memory allocation. Defaults to False.
+
+ Examples:
>>> import lumo
>>> with lumo.memory(5000):
- ... y = x * 2
+ ... y = x * 2
>>> @lumo.memory(1024)
... def doubler(x):
@@ -134,11 +202,21 @@ class memory(object):
>>> lumo.memory(10000).start()
... # do something
- Why use nvidia-smi to get memory useage? see:
+ References:
+ To get GPU memory usage, we use nvidia-smi. Refer to this link for details:
https://github.com/pytorch/pytorch/issues/12873
"""
def __init__(self, memory, device=None, hold=False, invade=False) -> None:
+ """
+ Initialize the memory allocator.
+
+ Args:
+ memory (int): Memory size to be allocated in MB.
+ device (str or int, optional): Device to allocate memory on. Defaults to the current CUDA device.
+ hold (bool, optional): Whether to hold the memory after allocation. Defaults to False.
+ invade (bool, optional): Whether to use aggressive memory allocation. Defaults to False.
+ """
super().__init__()
if device is None:
device = torch.cuda.current_device()
@@ -155,6 +233,14 @@ def __init__(self, memory, device=None, hold=False, invade=False) -> None:
self.last_success = _memer.get_pid_device_mem(_pid, self.device)
def copy(self, pid: int, wait: bool = True):
+ """
+ Copy memory allocation parameters from another process.
+
+ Args:
+ pid (int): Process ID to copy memory parameters from.
+ wait (bool, optional): Whether to wait until the other process has finished before starting allocation.
+ Defaults to True.
+ """
self.need = _memer.get_pid_device_mem(pid, self.device)
if wait:
self.wait(pid)
@@ -162,16 +248,25 @@ def copy(self, pid: int, wait: bool = True):
self.start()
def wait(self, pid):
+ """
+ Wait for the other process to finish before starting allocation.
+
+ Args:
+ pid (int): Process ID to wait for.
+ """
while _memer.get_pid_device_mem(pid, self.device) > 0:
time.sleep(0.5)
self.start()
def immediately(self, pre_init=False):
"""
- 等待,直到内存有空间后,开始申请相应显存,优雅,礼貌,推荐
+ Wait until there is enough memory available, then allocate the necessary memory.
+ This is the recommended way to allocate memory.
+
Args:
- pre_init: 是否初始化 CUDA(这将在一开始消耗一定显存),默认为 False,即不抢占任何内存,
- 直到设备释放足够空间后开始抢占。
+ pre_init (bool, optional): Whether to initialize CUDA (which will consume a certain amount of memory)
+ before allocating. Defaults to False, meaning memory will only be allocated after enough memory is
+ released by other processes.
"""
while True:
_left = _memer.get_device_release_mem(self.device)
@@ -202,7 +297,16 @@ def immediately(self, pre_init=False):
_left), end='\r')
def _malloc(self, size, init=False):
- """ unit: mb """
+ """
+ Allocate memory of the given size (in MB) on the specified device.
+
+ Args:
+ size (int): Memory size to be allocated, in MB.
+ init (bool, optional): Whether to initialize CUDA. Defaults to False.
+
+ Returns:
+ bool: True if the memory allocation was successful, False otherwise.
+ """
try:
tmp = torch.rand(size, 1048576 // 4, device=self.device)
if not init:
@@ -212,17 +316,33 @@ def _malloc(self, size, init=False):
return False
def end(self):
+ """
+ Release allocated memory and empty the CUDA cache.
+ """
del self.mem[:]
torch.cuda.empty_cache()
def start(self, immediately=True):
+ """
+ Start memory allocation.
+
+ Args:
+ immediately (bool, optional): Whether to use the recommended memory allocation method.
+ Defaults to True.
+ """
if immediately:
self.immediately()
else:
self.invade()
def invade(self, unit=5):
- """一点一点的侵占,有多少占用多少,直到申请满为止,比较粗鲁,不友好,不推荐"""
+ """
+ Aggressively allocate memory, increasing the memory footprint until the necessary amount is allocated.
+ This method is not recommended.
+
+ Args:
+ unit (int, optional): Incremental size of memory allocation (in MB). Defaults to 5.
+ """
try:
while self.last_success < self.need:
res = self._malloc(unit + self.acc)
@@ -243,15 +363,35 @@ def invade(self, unit=5):
print('\nabort.')
def __enter__(self):
+ """
+ Start memory allocation when entering the 'with' block.
+ """
self.start(immediately=not self.is_invade)
def __exit__(self, *args):
+ """
+ Release memory and empty the CUDA cache when exiting the 'with' block.
+
+ Returns:
+ bool: Always returns True.
+ """
self.end()
return True
def __call__(self, func):
+ """
+ Decorator to use with functions that require memory allocation.
+
+ Args:
+ func (callable): The function to decorate.
+
+ Returns:
+ callable: The decorated function.
+ """
+
@functools.wraps(func)
def decorate_no_grad(*args, **kwargs):
+ """decorate"""
with self:
return func(*args, **kwargs)
@@ -259,6 +399,9 @@ def decorate_no_grad(*args, **kwargs):
@staticmethod
def hold_current():
+ """
+ Hold the currently allocated memory.
+ """
count = torch.cuda.device_count()
mems = [_memer.get_pid_device_mem(_pid, i) for i in range(count)]
for i, mem in enumerate(mems):
diff --git a/src/lumo/utils/random.py b/src/lumo/utils/random.py
index 96accece..f6a0111d 100644
--- a/src/lumo/utils/random.py
+++ b/src/lumo/utils/random.py
@@ -11,10 +11,28 @@
def int_time():
+ """
+ Get an integer derived from the fractional digits of the current timestamp.
+
+ Returns:
+ int: The fractional digits of `time.time()` as an integer, usable as a quick seed.
+ """
return int(str(time.time()).split(".")[-1])
def hashseed(hashitem: Union[int, str]):
+ """
+ Generate a hash seed from a given integer or string.
+
+ Args:
+ hashitem (Union[int, str]): The integer or string to generate the hash seed from.
+
+ Returns:
+ int: The hash seed.
+
+ Raises:
+ AssertionError: If the given `hashitem` is not an integer or a string.
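+
+ Example (a sketch; the concrete seed value depends on the hash):
+
+ seed = hashseed('my-experiment')
+ fix_seed(seed)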
+ """
if not isinstance(hashitem, (int, str)):
raise AssertionError()
@@ -52,6 +70,9 @@ def fix_seed(seed=10, cuda=True):
def fix_cuda():
+ """
+ Set deterministic and reproducible configuration for CUDA.
+ """
if torch.cuda.is_available():
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
@@ -59,6 +80,17 @@ def fix_cuda():
def get_state(cuda=True):
+ """
+ Get the current state of the random number generators.
+
+ Args:
+ cuda (bool): Whether to get the CUDA state if PyTorch is using CUDA.
+
+ Returns:
+ dict: A dictionary containing the current states of the random number generators for
+ numpy, pytorch, pytorch.cuda, and python's built-in `random` module.
+
+ """
return {
"numpy": np.random.get_state(),
"torch": torch.random.get_rng_state(),
@@ -69,12 +101,11 @@ def get_state(cuda=True):
def set_state(state_dict, cuda=True):
"""
- Set random state of built-in random, numpy, torch, torch.cuda
+ Set the random state for NumPy, PyTorch, PyTorch CUDA, and Python's built-in `random` module.
Args:
- state_dict: a dict got from `lumo.utils.random.get_state()`
-
- Returns:
+ state_dict (dict): A dictionary containing the desired states of the random number generators for NumPy, PyTorch, PyTorch CUDA, and Python's built-in `random` module.
+ cuda (bool): Whether to set the CUDA state if PyTorch is using CUDA.
"""
random.setstate(state_dict["random"])
diff --git a/src/lumo/utils/repository.py b/src/lumo/utils/repository.py
index d65378f8..6466ab74 100644
--- a/src/lumo/utils/repository.py
+++ b/src/lumo/utils/repository.py
@@ -8,10 +8,20 @@
import git
from git import Repo, Commit
from joblib import hash
-from .filelock2 import Lock
+
+from .filelock import Lock
def dev_branch():
+ """
+ Returns the value of the 'dev_branch' key from the global configuration dictionary 'glob' in the 'lumo.proc.config' module.
+
+ Returns:
+ str: The configured dev branch name. Defaults to 'lumo_experiments' if the key is not present.
+
+ """
from lumo.proc.config import glob
return glob.get('dev_branch', 'lumo_experiments')
@@ -21,10 +31,26 @@ def dev_branch():
class branch:
"""
- 用于配合上下文管理切换 git branch
+ A context manager class for switching git branches in a given repository.
+
+ Example usage:
+ with branch(repo, branch):
+ repo.index.commit('...')
+
+ Args:
+ repo (Repo): The repository object for which the branch will be switched.
+ branch (str): The name of the branch to switch to.
+
+
+ Notes:
+ This class provides a context manager that switches the current branch
+ to the given branch when entering the context and switches back to the original branch when exiting the context.
+
+ If the given branch does not exist in the repository, it will be created.
+
+ A lock is obtained on the repository to ensure that
+ only one instance of this class can switch branches at a time for a given repository.
- with branch(repo, branch):
- repo.index.commit('...')
"""
def __init__(self, repo: Repo, branch: str):
@@ -60,6 +86,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def check_have_commit(repo):
+ """
+ Checks if the given repository has any commits.
+
+ If there are no commits, creates an initial commit that adds all files in the repository and has the message "initial commit".
+ """
if len(repo.heads) == 0:
repo.git.add('.')
repo.index.commit('initial commit')
@@ -124,7 +155,7 @@ def git_commit(repo=None, key=None, branch_name=None, info: str = None, filter_f
# print(diff_uncommit)
if filter_files is not None:
- diff_from_branches = [i.a_path for i in diff_from_branches if i.a_path in filter_files]
+ diff_from_branches = [i for i in diff_from_branches if i.a_path in filter_files]
if len(diff_from_branches) == 0 and len(diff_uncommit) == 0 and len(repo.untracked_files) == 0:
commit_ = exp_head_commit
@@ -154,7 +185,14 @@ def git_commit(repo=None, key=None, branch_name=None, info: str = None, filter_f
return commit_
-def git_checkout(repo=None, commit_hex=None, commit: Commit = None):
+def git_archive(repo=None, commit_hex=None, commit: Commit = None):
+ """
+ Archives the given commit into a tar file, similar to `git archive -o <file>`.
+
+ Returns:
+ An Experiment representing this archive operation.
+ """
+ from lumo.exp import Experiment
if repo is None:
repo = load_repo()
@@ -163,22 +201,36 @@ def git_checkout(repo=None, commit_hex=None, commit: Commit = None):
old_path = os.getcwd()
os.chdir(commit.tree.abspath)
+ exp = Experiment('GitArchive')
+ fn = exp.mk_bpath(f'{commit.hexsha[:8]}.tar')
- # with branch(commit.repo, LUMO_BRANCH) as new_branch:
- repo.git.checkout('-b', commit.hexsha[:8], commit.hexsha)
+ exp.dump_info('git_archive', {'file': fn,
+ 'test_name': exp.test_name,
+ 'commit_hex': commit.hexsha[:8]})
+ exp.dump_string('archive_fn', fn)
+ with open(fn, 'wb') as w:
+ repo.archive(w, commit.hexsha)
os.chdir(old_path)
- return commit.hexsha[:8]
+ return exp
-def git_archive(repo=None, commit_hex=None, commit: Commit = None):
+def git_checkout(repo=None, commit_hex=None, commit: Commit = None):
"""
- git archive -o
+ Checkout a specific commit in a Git repository.
+
+ Args:
+ repo (git.Repo, optional): The Git repository to use. Defaults to None, in which case the repository is loaded using `load_repo()`.
+ commit_hex (str, optional): The hash of the commit to check out. Defaults to None.
+ commit (git.Commit, optional): The commit object to check out. Defaults to None.
Returns:
- An Experiment represents this archive operation
+ str: The abbreviated hash of the checked-out commit.
+
+ Raises:
+ git.InvalidGitRepositoryError: If the specified repository is invalid or not found.
+ git.BadName: If the specified commit reference is invalid or not found.
"""
- from lumo.exp import Experiment
if repo is None:
repo = load_repo()
@@ -187,22 +239,25 @@ def git_archive(repo=None, commit_hex=None, commit: Commit = None):
old_path = os.getcwd()
os.chdir(commit.tree.abspath)
- exp = Experiment('GitArchive')
- fn = exp.blob_file(f'{commit.hexsha[:8]}.tar')
- exp.dump_info('git_archive', {'file': fn,
- 'test_name': exp.test_name,
- 'commit_hex': commit.hexsha[:8]})
- exp.dump_string('archive_fn', fn)
- with open(fn, 'wb') as w:
- repo.archive(w, commit.hexsha)
+ # with branch(commit.repo, LUMO_BRANCH) as new_branch:
+ repo.git.checkout('-b', commit.hexsha[:8], commit.hexsha)
os.chdir(old_path)
- return exp
+ return commit.hexsha[:8]
@lru_cache(1)
def git_enable():
+ """
+ Check if Git is installed and a repository is present.
+
+ Returns:
+ bool: True if Git is installed and a repository is present; False otherwise,
+ including when the `gitpython` library is not installed.
+ """
try:
import git
except ImportError:
@@ -220,18 +275,19 @@ def git_enable():
def git_dir(root='./'):
"""
git repository directory
- git rev-parse --git-dir
+ git rev-parse --show-toplevel
Args:
root:
Returns:
+ The repository's top-level directory. Note that the previous command, `git rev-parse --git-dir`, cannot find the correct path when the repository is a submodule inside another repository.
"""
if git_enable():
from git import Git
cur = os.getcwd()
os.chdir(root)
- res = Git().execute(['git', 'rev-parse', '--git-dir'])
- res = os.path.abspath(os.path.dirname(res))
+ res = Git().execute(['git', 'rev-parse', '--show-toplevel'])
+ res = os.path.abspath(res)
os.chdir(cur)
return res
else:
diff --git a/src/lumo/utils/safe_io.py b/src/lumo/utils/safe_io.py
index b67085fd..b87d5881 100644
--- a/src/lumo/utils/safe_io.py
+++ b/src/lumo/utils/safe_io.py
@@ -17,39 +17,75 @@
def dump_json(obj, fn):
+ """
+ Dumps the given object to a JSON file at the given file path.
+
+ Args:
+ obj: The object to be dumped to JSON.
+ fn (str): The file path to which the JSON data will be written.
+
+ Notes:
+ The JSON data will be written with an indentation of 2 spaces.
+ """
with open(fn, 'w', encoding='utf-8') as w:
json.dump(obj, w, indent=2)
def dump_yaml(obj, fn):
+ """
+ Dumps the given object to a YAML file at the given file path.
+
+ Args:
+ obj: The object to be dumped to YAML.
+ fn (str): The file path to which the YAML data will be written.
+
+ Notes:
+ The YAML data will be written with default formatting options.
+ """
import yaml
with open(fn, 'w', encoding='utf-8') as w:
yaml.safe_dump(obj, w)
- return fn
def dump_state_dict(obj, fn):
+ """Saves a PyTorch state dictionary object to disk."""
torch.save(obj, fn)
- return fn
def load_json(fn):
- with open(fn, 'r', encoding='utf-8') as r:
- return json.load(r)
+ """
+ Loads JSON data from the given file path and returns the resulting object.
+
+ Args:
+ fn (str): The file path to read from.
+
+ Returns:
+ The object deserialized from the JSON file.
+
+ Raises:
+ ValueError: If the file contains invalid JSON.
+ """
+ try:
+ with open(fn, 'r', encoding='utf-8') as r:
+ return json.load(r)
+ except json.JSONDecodeError as e:
+ raise ValueError(f'Error in file {fn}') from e
def load_yaml(fn):
+ """Loads YAML data from the given file path and returns the resulting object."""
import yaml
with open(fn, 'r', encoding='utf-8') as r:
return yaml.safe_load(r)
def load_state_dict(fn: str, map_location='cpu'):
+ """Loads a PyTorch model checkpoint from the specified file path and returns its state dictionary."""
ckpt = torch.load(fn, map_location=map_location)
return ckpt
def load_text(fn):
+ """Loads text data from the given file path and returns it as a single string."""
if not os.path.exists(fn):
return ''
with open(fn, 'r', encoding='utf-8') as r:
@@ -57,6 +93,16 @@ def load_text(fn):
def dump_text(string: str, fn, append=False):
+ """Write the given string to a file.
+
+ Args:
+ string (str): The string to write.
+ fn (str): The filename to write to.
+ append (bool, optional): If True, append the string to the file. Otherwise, overwrite the file. Defaults to False.
+
+ Returns:
+ str: The filename that was written to.
+ """
mode = 'w'
if append:
mode = 'a'
@@ -66,6 +112,16 @@ def dump_text(string: str, fn, append=False):
def safe_getattr(self, key, default=None):
+ """Get an attribute of an object safely.
+
+ Args:
+ self (object): The object to get the attribute from.
+ key (str): The name of the attribute to get.
+ default (object, optional): The value to return if the attribute does not exist. Defaults to None.
+
+ Returns:
+ object: The value of the attribute if it exists, otherwise the default value.
+ """
try:
return getattr(self, key, default)
except:
@@ -73,6 +129,19 @@ def safe_getattr(self, key, default=None):
def dump_pkl(obj, file, make_path=True, protocol=None, *, fix_imports=True):
+ """Save an object to a pickle file.
+
+ Args:
+ obj (object): The object to save.
+ file (str or FileIO): The filename or file object to save to.
+ make_path (bool, optional): If True and file is a filename, create the directory path if it does not exist. Defaults to True.
+ protocol (int, optional): The pickle protocol to use. Defaults to None.
+ fix_imports (bool, optional): Whether to fix Python 2 to 3 pickle incompatibilities. Defaults to True.
+
+ Raises:
+ NotImplementedError: If the file type is not supported.
+
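+ Example (a round-trip sketch):
+
+ dump_pkl({'a': 1}, '/tmp/obj.pkl')
+ assert load_pkl('/tmp/obj.pkl') == {'a': 1}
+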
+ """
if isinstance(file, str):
if make_path:
os.makedirs(os.path.dirname(os.path.abspath(file)), exist_ok=True)
@@ -82,10 +151,25 @@ def dump_pkl(obj, file, make_path=True, protocol=None, *, fix_imports=True):
elif isinstance(file, FileIO):
_pickle.dump(obj, file, protocol=protocol, fix_imports=fix_imports)
else:
- raise NotImplementedError()
+ raise NotImplementedError("File type not supported.")
def load_pkl(file, *, fix_imports=True, encoding="ASCII", errors="strict"):
+ """Load an object from a pickle file.
+
+ Args:
+ file (str or FileIO): The filename or file object to load from.
+ fix_imports (bool, optional): Whether to fix Python 2 to 3 pickle incompatibilities. Defaults to True.
+ encoding (str, optional): The character encoding to use. Defaults to "ASCII".
+ errors (str, optional): The error handling scheme to use. Defaults to "strict".
+
+ Returns:
+ object: The object that was loaded from the file.
+
+ Raises:
+ NotImplementedError: If the file type is not supported.
+
+ """
if isinstance(file, str):
file = open(file, 'rb')
res = _pickle.load(file, fix_imports=fix_imports, encoding=encoding, errors=errors)
@@ -94,11 +178,27 @@ def load_pkl(file, *, fix_imports=True, encoding="ASCII", errors="strict"):
elif isinstance(file, FileIO):
return _pickle.load(file, fix_imports=fix_imports, encoding=encoding, errors=errors)
else:
- raise NotImplementedError()
+ raise NotImplementedError("File type not supported.")
@contextmanager
def cached(fn):
+ """
+ A context manager that caches the output of a computation to a file.
+
+ Args:
+ fn (str): The file path to which the cached data will be written.
+
+ Yields:
+ str: The file path of the cache file.
+
+ Examples:
+
+ with cached('a.txt') as cache_fn:
+ with open(cache_fn, 'w') as w:
+ w.write('123')
+
+ """
import shutil
cache_fn = f'{fn}.lumo_cache'
try:
diff --git a/src/lumo/utils/screen.py b/src/lumo/utils/screen.py
index fb41dc4b..a994a980 100644
--- a/src/lumo/utils/screen.py
+++ b/src/lumo/utils/screen.py
@@ -10,10 +10,23 @@
def get_consolo_width():
+ """Returns the width of the current console window"""
return shutil.get_terminal_size().columns - 1 # -1 for windows consolo
def support_multiline():
+ """
+ Checks if the current environment supports multiline output in line.
+ Notes:
+ This function checks if the current environment supports multiline output.
+ It returns True if any of the following conditions are met:
+
+ - The `jupyter_core` module is available (implying that the code is being run in a Jupyter notebook or JupyterLab).
+ - The width of the console is reported as 0 by `shutil.get_terminal_size()`, which can occur in some non-standard environments or configurations.
+ - The `PYCHARM_HOSTED` environment variable is set, which indicates that the code is being run in PyCharm's integrated console.
+
+ If none of these conditions are met, the function returns False.
+ """
if "jupyter_core" in sys.modules or shutil.get_terminal_size((0, 0)).columns == 0 or "PYCHARM_HOSTED" in os.environ:
return True
else:
@@ -39,10 +52,15 @@ def _is_jupyter() -> bool: # pragma: no cover
class ScreenStr:
"""
- A ScreenStr start with '\r' won't overflow, any string outside the screen width will be cut.
+ A class representing a string that can be displayed on a console screen.
+
+ Attributes:
+ content (str): The string content to be displayed on the screen.
+ leftoffset (int): The number of characters to shift the content to the left.
Notes:
- If output consolo support multiline(like pycharm or jupyter notebook) return, all string will be represented.
+ A ScreenStr starting with '\r' will not overflow, and any string longer than the screen width will be cut off.
+ If the console supports multiline output (like PyCharm or Jupyter notebook), the full string is shown.
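+
+ Example (a sketch; the output is trimmed to the terminal width):
+
+ print(ScreenStr('\r' + 'Long Text;' * 40, leftoffset=10), end='')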
"""
t = 0
dt = 0.7
@@ -57,23 +75,33 @@ class ScreenStr:
multi_mode = support_multiline()
def __init__(self, content="", leftoffset=0) -> None:
+ """Initializes a new instance of the ScreenStr class."""
self.content = content
ScreenStr.left = leftoffset
def __repr__(self) -> str:
+ """Returns the string representation of the ScreenStr object."""
if ScreenStr.multi_mode:
return self.content
return self._screen_str()
+ def __len__(self) -> int:
+ """Returns the length of the string content."""
+ txt = self.content.encode("gbk", errors='ignore')
+ return len(txt)
+
def tostr(self):
+ """Returns the string content."""
return self.content
@classmethod
def set_speed(cls, dt: float = 0.05):
+ """Sets the speed of the text scrolling animation."""
cls.dt = dt
@classmethod
def deltatime(cls):
+ """Calculates the time elapsed since the last update."""
if cls.last == 0:
cls.last = time.time()
return 0
@@ -85,6 +113,7 @@ def deltatime(cls):
@classmethod
def cacu_offset_(cls, out_width):
+ """Calculates the offset for scrolling the text."""
delta = cls.deltatime()
cls.t += delta * cls.dt
@@ -102,11 +131,8 @@ def cacu_offset_(cls, out_width):
a = 1
- def __len__(self) -> int:
- txt = self.content.encode("gbk", errors='ignore')
- return len(txt)
-
def _decode_sub(self, txt, left, right):
+ """Decodes a part of a byte string to a Unicode string."""
try:
txt = txt[left:right].decode("gbk", errors='ignore')
except:
@@ -122,11 +148,13 @@ def _decode_sub(self, txt, left, right):
@staticmethod
def consolo_width():
+ """Returns the width of the console."""
width = get_consolo_width()
return width
@staticmethod
def split(txt, len):
+ """Splits a string into two parts."""
try:
return txt[:len], txt[len:]
except:
@@ -136,6 +164,7 @@ def split(txt, len):
return txt[:len - 1], txt[len - 1:]
def _screen_str(self, margin="..."):
+ """Returns the string content formatted for display on the screen."""
width = self.consolo_width()
txt = self.content.encode("gbk", errors='ignore').strip()
@@ -163,9 +192,36 @@ def _screen_str(self, margin="..."):
class inlinetqdm(tqdm):
+ """
+ A subclass of `tqdm` that formats progress bar updates as a single line.
+
+ This subclass additionally provides:
+
+ - `full_str`: Returns a formatted string representing the full progress bar,
+ including any additional information (such as elapsed time or estimated remaining time).
+ - `__str__`: Renders the progress bar as a single line trimmed to the screen width.
+
+ """
def full_str(self):
+ """
+ Returns a formatted string representing the full progress bar, including the progress bar itself and any additional information (such as elapsed time or estimated remaining time).
+
+ Returns:
+ str: A formatted string representing the full progress bar.
+ """
return self.format_meter(**self.format_dict)
def __str__(self):
+ """
+ Overrides the `__str__` method of the `tqdm` class to display the progress bar as a single-line string.
+
+
+ str: A single-line string representation of the progress bar.
+ """
return ScreenStr(self.full_str())._screen_str()
diff --git a/src/lumo/utils/subprocess.py b/src/lumo/utils/subprocess.py
new file mode 100644
index 00000000..5ed5e9f2
--- /dev/null
+++ b/src/lumo/utils/subprocess.py
@@ -0,0 +1,56 @@
+import os
+import subprocess
+import select
+import signal
+
+
+def run_command(command, cwd=None, env=None):
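+ """
+ Run a shell command, streaming its stdout/stderr to the console as it runs.
+
+ Args:
+ command (str): The command to execute (passed through the shell).
+ cwd (str, optional): Working directory for the command.
+ env (dict, optional): Environment variables for the child process.
+
+ Returns:
+ int: The return code of the process. On KeyboardInterrupt, SIGINT is forwarded to the child and its remaining output is drained before returning.
+ """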
+ proc = subprocess.Popen(command,
+ cwd=cwd,
+ env=env,
+ shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ try:
+ while proc.poll() is None:
+ # Wait for output from the process
+ rlist, _, _ = select.select([proc.stdout, proc.stderr], [], [], 0.1)
+ for stream in rlist:
+ line = stream.readline().decode('utf-8')
+ if line:
+ print(line, end='')
+
+ # Read the remaining output
+ for stream in [proc.stdout, proc.stderr]:
+ while True:
+ line = stream.readline().decode('utf-8')
+ if not line:
+ break
+ print(line, end='')
+
+ # Get the return code of the process
+ return_code = proc.wait()
+
+ # Raise an exception if the process returned a non-zero return code
+ # if return_code != 0:
+ # raise subprocess.CalledProcessError(return_code, command)
+ except KeyboardInterrupt:
+ os.kill(proc.pid, signal.SIGINT)
+
+ while proc.poll() is None:
+ # Wait for output from the process
+ rlist, _, _ = select.select([proc.stdout, proc.stderr], [], [], 0.1)
+ for stream in rlist:
+ line = stream.readline().decode('utf-8')
+ if line:
+ print(line, end='')
+
+ # Read the remaining output
+ for stream in [proc.stdout, proc.stderr]:
+ while True:
+ line = stream.readline().decode('utf-8')
+ if not line:
+ break
+ print(line, end='')
+
+ # Get the return code of the process
+ return_code = proc.wait()
+ return return_code
diff --git a/tests/contrib/test_functional.py b/tests/contrib/test_functional.py
index 0c79233c..d40ae84d 100644
--- a/tests/contrib/test_functional.py
+++ b/tests/contrib/test_functional.py
@@ -219,3 +219,5 @@ def test_contrastive_loss():
cs = InstanceLoss(4, 0.7, 'cpu')
assert (cs.forward(a, b) - contrastive_loss(a, b, temperature=0.7, inbatch_neg=True)) ** 2 < 1e-10
+
+
\ No newline at end of file
diff --git a/tests/contrib/test_module.py b/tests/contrib/test_module.py
index e69de29b..92e8c663 100644
--- a/tests/contrib/test_module.py
+++ b/tests/contrib/test_module.py
@@ -0,0 +1,18 @@
+from lumo.contrib.module.memoty_bank import MemoryBank
+from accelerate import Accelerator
+import torch
+import pytest
+
+
+def test_memory_bank():
+ bank = MemoryBank()
+ bank.register('test', 32, 512)
+ res = [torch.rand(128, 32) for i in range(8)]
+ acce = Accelerator()
+ for r in res:
+ bank.push('test', r)
+
+ assert (bank['test'] == torch.cat(res[4:], dim=0)).all()
+ bank.requires_grad_(False)
+ # assert (bank['test'][:128] == res[-1]).all()
+ # assert (bank['test'][128:] == torch.cat(res[1:4], dim=0)).all()
diff --git a/tests/core/test_attr.py b/tests/core/test_attr.py
index 0d250769..76fe6880 100644
--- a/tests/core/test_attr.py
+++ b/tests/core/test_attr.py
@@ -1,7 +1,8 @@
-from lumo.core.attr import Attr as attr, set_item_iterative, get_item_iterative
import numpy as np
import torch
+from lumo.core.attr import Attr as attr, set_item_iterative
+
class NAttr(attr):
@@ -22,11 +23,18 @@ def get_res():
res.e = torch.tensor([2, 3, 4]).float()
res.f = np.array(2)
res.g = np.array([2, 3, 4])
+
return res
def test_replace():
res = get_res()
+ print(res)
+ assert (
+ str(res) == "Attr([('a', 1), ('nn', None), ('kk', Attr([('k', None)])), ('st', 'NAttr()'), "
+ "('b', [2, 3, 4]), ('c', Attr([('a', 1), ('b', [5, 6, 7]), ('c', Attr([('d', [8, 9])]))])), "
+ "('d', tensor(1.)), ('e', tensor([2., 3., 4.])), ('f', array(2)), ('g', array([2, 3, 4]))])"
+ )
res.update(a=6, b=[4, 5])
res['c.c.e.f'] = 5
assert res.a == 6
@@ -38,7 +46,7 @@ def test_replace():
def test_get_set():
- res = {}
+ res = attr()
set_item_iterative(res, ['a', 'b', 'c'], 4)
assert isinstance(res['a'], dict)
assert isinstance(res['a']['b'], dict)
diff --git a/tests/core/test_disk.py b/tests/core/test_disk.py
index c8ba42d2..a06bed89 100644
--- a/tests/core/test_disk.py
+++ b/tests/core/test_disk.py
@@ -1,3 +1,5 @@
+import os.path
+
from lumo.core.disk import Metrics, TableRow
import numpy as np
import tempfile
@@ -5,15 +7,10 @@
from lumo.utils import safe_io as IO
glob['metric_root'] = tempfile.mkdtemp()
-from lumo.utils.fmt import strftime
def test_table_row():
- row = TableRow('lumo-test', 'core', strftime())
-
- # test update
- row.update('a', 'b')
- assert row['a'] == 'b'
+ row = TableRow(os.path.join(tempfile.mkdtemp(), 'lumo-test'))
# test update_metric
## test max
@@ -38,6 +35,5 @@ def test_table_row():
# test storage
row.flush()
storage = IO.load_pkl(row.fpath)
- assert storage['a'] == row['a']
assert storage['metric']['acc'] == row['metric']['acc']
assert (storage['metric']['clsAcc'] == row['metric']['clsAcc']).all()
diff --git a/tests/core/test_interp.py b/tests/core/test_interp.py
index a1535c6c..671338f5 100644
--- a/tests/core/test_interp.py
+++ b/tests/core/test_interp.py
@@ -1,4 +1,5 @@
from lumo.core.interp import *
+import numpy as np
from torch.optim.sgd import SGD
from torch import nn
from lumo import BaseParams
@@ -58,3 +59,74 @@ def test_scheduler_is_attr():
assert abs(phcos.get(9.99) - 1) < 1e-5
assert phcos.get(10) == 0
assert phcos.get(15) == 0.5
+
+
+def test_period_linear():
+ start = 0
+ end = 10
+ period = 5
+ left = 0
+ constant = False
+ cur = 2
+
+ expected = 4.0
+ result = PeriodLinear.interp(cur, start, end, left, period, constant)
+ assert result == expected
+
+
+def test_power_decay():
+ start = 1
+ decay_steps = 5
+ decay_rate = 0.5
+ end = None
+ cur = 10
+
+ expected = 0.25
+ power_decay = PowerDecay(start, decay_steps, decay_rate, end)
+ result = power_decay(cur)
+ assert result == expected
+
+
+def test_power_decay2():
+ start = 1
+ schedules = [5, 10]
+ gammas = [0.5, 0.2]
+ cur = 12
+
+ expected = 0.1
+ power_decay2 = PowerDecay2(start, schedules, gammas)
+ result = power_decay2(cur)
+ assert result == expected
+
+
+# def test_ABCContinuous():
+# # Test ABCContinuous class
+# abc = ABCContinuous(start=1, end=2, left=0, right=10)
+# assert abc(0) == 1
+# assert abc(10) == 2
+# assert abc(5) == abc.interp(5, start=1, end=2, left=0, right=10)
+
+
+def test_Exp():
+ # Test Exp class
+ exp = Exp(start=1, end=2, left=0, right=10)
+ assert np.isclose(exp(0), 1)
+ assert np.isclose(exp(10), 2)
+ assert np.isclose(exp(5), 1.078716025, rtol=1e-5)
+ assert np.isclose(exp(8), 1.366531851)
+ assert np.isclose(exp(9.5), 1.7784638857)
+
+
+def test_Log():
+ # Test Log class
+ log = Log(start=1, end=2, left=0, right=10)
+ assert np.isclose(log(0), 1)
+ assert np.isclose(log(10), 2)
+ assert np.isclose(log(5), 1.921283974)
+
+
+def test_Constant():
+ # Test Constant class
+ const = Constant(value=0.5)
+ assert const(0) == 0.5
+ assert const(10) == 0.5
diff --git a/tests/core/test_meter.py b/tests/core/test_meter.py
index be5f8c14..aec47f09 100644
--- a/tests/core/test_meter.py
+++ b/tests/core/test_meter.py
@@ -1,4 +1,53 @@
-from lumo.core.meter import ReduceItem
+from collections import OrderedDict
+
+import pytest
+
+from lumo.core.meter import ReduceItem, Meter
+
+
+def test_meter():
+ m = Meter()
+ m['loss'] = 0.5
+ m['accuracy'] = 0.8
+
+ # test __getitem__ method
+ assert m['loss'] == 0.5
+
+ # test __setitem__ method
+ m['loss'] = 0.2
+ assert m['loss'] == 0.2
+
+ # test __repr__ method
+ assert repr(m) == 'loss: 0.2 | accuracy: 0.8'
+
+ # test keys method
+ assert set(m.keys()) == {'loss', 'accuracy'}
+
+ # test todict method
+ assert m.todict() == {'loss': 0.2, 'accuracy': 0.8}
+
+ # test sorted method
+ sorted_m = m.sorted()
+ assert isinstance(sorted_m, Meter)
+ assert set(sorted_m.keys()) == {'accuracy', 'loss'}
+ assert repr(sorted_m) == 'accuracy: 0.8 | loss: 0.2'
+
+ # test update method
+ m.update({'loss': 0.1, 'precision': 0.9})
+ assert set(m.keys()) == {'loss', 'accuracy', 'precision'}
+ assert m.todict() == {'loss': 0.1, 'accuracy': 0.8, 'precision': 0.9}
+
+ # test from_dict method
+ m2 = Meter.from_dict(OrderedDict([('loss', 0.1), ('accuracy', 0.8), ('precision', 0.9)]))
+ assert set(m2.keys()) == {'loss', 'accuracy', 'precision'}
+ assert m2.todict() == {'loss': 0.1, 'accuracy': 0.8, 'precision': 0.9}
+
+ # test scalar_items method
+ m3 = Meter()
+ m3['loss'] = 0.5
+ m3['accuracy'] = '80%'
+ m3['precision'] = [0.9, 0.8]
+ assert set(m3.scalar_items()) == {('loss', 0.5), ('accuracy', '80%')}
def test_avg_item_mean():
diff --git a/tests/core/test_params.py b/tests/core/test_params.py
index 5510a78f..fa449f74 100644
--- a/tests/core/test_params.py
+++ b/tests/core/test_params.py
@@ -1,9 +1,8 @@
import json
import tempfile
-from omegaconf import DictConfig
-from lumo.core.raises import BoundCheckError
from lumo import BaseParams
+from lumo.core.raises import BoundCheckError
def test_params():
@@ -48,11 +47,13 @@ def get_res():
def test_argv():
params = get_res()
- params.from_args(['--a', '1', '--d.c.d=2'])
+ params.from_args(['--a', '1', '--d.c.d=2', '--kk.c=3'])
+ print(params)
assert params.a == 1
assert params.d.c.d == 2
- assert isinstance(params.kk, DictConfig)
- assert isinstance(params.d.c, DictConfig)
+ assert isinstance(params.kk, MyParams)
+ assert params.kk.c == 3
+ assert isinstance(params.d.c, BaseParams)
def test_dict():
@@ -67,10 +68,22 @@ def test_json():
fn = tempfile.mktemp()
with open(fn, 'w') as w:
json.dump({'c': {'a': 2}}, w)
+ res.c.a = 4
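+ # from_json should overwrite the in-memory value set above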
res.from_json(fn)
assert res.c.a == 2
+def test_yaml():
+ res = get_res()
+ fn = tempfile.mktemp('.yaml')
+ res.a = 3
+ res.to_yaml(fn)
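+ # explicit CLI args should take precedence over values loaded via --config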
+ res.from_args([f'--config={fn}', '--a=2'])
+ assert res.a == 2
+ res.from_args([f'--config={fn}'])
+ assert res.a == 3
+
+
def test_copy():
res = get_res()
copy = res.copy()
diff --git a/tests/core/test_record.py b/tests/core/test_record.py
new file mode 100644
index 00000000..3e517294
--- /dev/null
+++ b/tests/core/test_record.py
@@ -0,0 +1,49 @@
+import pytest
+
+from lumo import Record, Meter
+
+
+@pytest.fixture
+def record():
+ return Record(stage='test')
+
+
+def test_stage(record):
+ assert record.stage == 'test'
+
+
+def test_record(record):
+ for i in range(10):
+ m = Meter()
+ m.sum.C = 512
+ record.record(m)
+ record.record({'loss': 0.5, 'accuracy': 0.8})
+ assert record._agg['loss'].res == 0.5
+ assert record._agg['accuracy'].res == 0.8
+
+
+def test_record_meter(record):
+ for i in range(10):
+ m = Meter()
+ m.sum.C = 512
+ record.record(m)
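+ # the 'sum' reduce accumulates over all 10 recorded meters: 512 * 10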
+ assert record.agg()['C'] == 512 * 10
+
+
+def test_clear(record):
+ record.record({'loss': 0.5, 'accuracy': 0.8})
+ record.clear()
+ assert len(record._agg) == 0
+ assert len(record._cache) == 0
+
+
+def test_flush(record):
+ record.record({'loss': 0.5, 'accuracy': 0.8})
+ record.flush()
+ assert len(record._cache) == 0
+
+
+def test_str(record):
+ record.record({'loss': 0.5, 'accuracy': 0.8})
+ assert str(record) == 'loss=0.5, accuracy=0.8'
diff --git a/tests/exp/test_watcher.py b/tests/exp/test_watcher.py
new file mode 100644
index 00000000..35ad087a
--- /dev/null
+++ b/tests/exp/test_watcher.py
@@ -0,0 +1,27 @@
+from lumo import Trainer, TrainerParams
+from lumo.exp.watch import Watcher
+from lumo.proc.config import debug_mode
+
+
+class MyTrainer(Trainer):
+ pass
+
+
+def trainer():
+ params = TrainerParams()
+ t = MyTrainer(params)
+ return t
+
+
+def test_exp():
+ debug_mode()
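+ # debug_mode() presumably sandboxes lumo's storage so the runs leave no residue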
+ for i in range(10):
+ t = trainer()
+ t.train()
+ print(t.exp.heartbeat_fn)
+
+ w = Watcher()
+ df = w.load()
+ print(df.columns)
+ # print(sorted(list(df['test_name'])))
+ assert len(df) == 10
diff --git a/tests/proc/test_dist.py b/tests/proc/test_dist.py
new file mode 100644
index 00000000..1cda9268
--- /dev/null
+++ b/tests/proc/test_dist.py
@@ -0,0 +1,24 @@
+from lumo.proc.dist import *
+
+
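+# these helpers read the standard torchrun env vars (LOCAL_RANK, WORLD_SIZE)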
+def test_local_rank(monkeypatch):
+ monkeypatch.setenv('LOCAL_RANK', '0')
+ assert local_rank() == 0
+
+ monkeypatch.setenv('LOCAL_RANK', '1')
+ assert local_rank() == 1
+
+
+def test_world_size(monkeypatch):
+ monkeypatch.setenv('WORLD_SIZE', '4')
+ assert world_size() == 4
+
+
+def test_is_dist(monkeypatch):
+ monkeypatch.setenv('LOCAL_RANK', '0')
+ assert is_dist()
+
+
+def test_is_main(monkeypatch):
+ monkeypatch.setenv('LOCAL_RANK', '0')
+ assert is_main()
diff --git a/tests/proc/test_proc.py b/tests/proc/test_proc.py
new file mode 100644
index 00000000..6335e475
--- /dev/null
+++ b/tests/proc/test_proc.py
@@ -0,0 +1,47 @@
+from lumo.proc.path import *
+
+
+def test_home():
+ assert home() == os.path.expanduser("~")
+
+
+# def test_cache_dir():
+# CACHE_ROOT = glob.get('cache_dir', None)
+# expected = CACHE_ROOT or os.path.join(home(), '.lumo/cache')
+# assert cache_dir() == expected
+
+
+def test_libhome():
+ LIBHOME = glob.get('home', None)
+ expected = LIBHOME or os.path.join(home(), '.lumo')
+ assert libhome() == expected
+
+
+def test_exproot():
+ EXP_ROOT = glob.get('exp_root', None)
+ expected = EXP_ROOT or os.path.join(libhome(), 'experiments')
+ assert exproot() == expected
+
+
+def test_progressroot():
+ PROGRESS_ROOT = glob.get('progress_root', None)
+ expected = PROGRESS_ROOT or os.path.join(cache_dir(), 'progress')
+ assert progressroot() == expected
+
+
+def test_blobroot():
+ BLOB_ROOT = glob.get('blob_root', None)
+ expected = BLOB_ROOT or os.path.join(libhome(), 'blob')
+ assert blobroot() == expected
+
+
+def test_metricroot():
+ METRIC_ROOT = glob.get('metric_root', None)
+ expected = METRIC_ROOT or os.path.join(libhome(), 'metrics')
+ assert metricroot() == expected
+
+
+def test_local_dir():
+ # local_dir() depends on git_dir(), so it can only be exercised inside a
+ # real git repository; kept as a placeholder for now
+ pass
diff --git a/tests/trainer/test_builder.py b/tests/trainer/test_builder.py
index 5e01ff16..071fd1f1 100644
--- a/tests/trainer/test_builder.py
+++ b/tests/trainer/test_builder.py
@@ -1,4 +1,6 @@
-from lumo import DatasetBuilder, DataLoaderSide
+import pytest
+
+from lumo import DatasetBuilder, DataLoaderSide, DataModule, ParamsType, TrainStage
def global_check(dic):
@@ -10,18 +10,56 @@ def global_check(dic):
def create_dataset_builder():
builder = (
DatasetBuilder()
+ .add_idx('id')
.add_input(name='xs', source=range(1000))
+ .add_input(name='axs', source=range(1000), transform=lambda x: x - 1)
.add_input(name='ys', source=range(1, 1001))
- .add_output(name='xs', outkey='xs1')
+ .add_output(name='xs', outkey='xs1', transform=lambda x: x + 1)
.add_output(name='xs', outkey='xs2')
+ .add_output(name='axs', outkey='xs3')
.add_output(name='ys', outkey='ys1')
- .set_output_transform('xs1', lambda x: x + 1)
.set_output_transform('ys1', lambda x: x - 1)
.add_global_transform(global_check)
)
return builder
+def test_iter_data():
+ builder = (
+ DatasetBuilder()
+ .add_idx('id')
+ .add_input('xs', iter(range(20)))
+ .add_input('ys', iter(range(20)))
+ .add_output('xs', 'xs1')
+ .add_output('xs', 'xs2', transform=lambda x: x + 1)
+ .add_output('ys', 'ys')
+ )
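+ # an iterator-backed builder supports neither random access nor len()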
+ with pytest.raises(TypeError):
+ builder[0]
+
+ with pytest.raises(TypeError):
+ len(builder)
+
+ for i, sample in enumerate(builder):
+ assert isinstance(sample, dict)
+ assert sample['xs1'] == i
+ assert sample['xs2'] == i + 1
+ assert sample['ys'] == i
+ assert sample['id'] == i
+
+ builder.chain()
+ for i, (xs1, xs2, ys) in enumerate(builder):
+ assert xs1 == i
+ assert xs2 == i + 1
+ assert ys == i
+
+
def test_builder_base():
builder = create_dataset_builder()
@@ -39,12 +77,21 @@ def test_builder_base():
assert len(builder) == 1000
- sub_builder = builder.subset(range(500), copy=True)
+ sub_builder = builder.subset(range(20), copy=True)
assert len(builder) == 1000
- assert len(sub_builder) == 500
- assert sub_builder[499]['xs1'] == 500
- assert sub_builder[499]['ys1'] == 499
- assert sub_builder[499]['xs2'] == 499
+ assert builder != sub_builder
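+ # without copy=True, subset narrows the original builder in place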
+ new_builder = builder.subset(range(500))
+ assert len(builder) == 500
+ assert builder == new_builder
+
+ assert len(sub_builder) == 20
+
+ for i, sample in enumerate(sub_builder):
+ assert sample['id'] == i
+ assert sample['xs1'] == i + 1
+ assert sample['ys1'] == i
+ assert sample['xs2'] == i
+ assert sample['xs3'] == i - 1
dic = sub_builder.inputs
assert 'xs' in dic
@@ -57,6 +104,14 @@ def test_builder_base():
str(sub_builder)
+ sub_builder.zip()
+ assert isinstance(sub_builder[0], dict)
+ sub_builder.chain()
+ assert isinstance(sub_builder[0], list)
+ sub_builder.item()
+ print(sub_builder[0])
+ # assert isinstance(sub_builder[0], dict)
+
def test_side():
sup = create_dataset_builder()
@@ -81,3 +136,26 @@ def test_side():
assert 'xs2' in sup
assert 'ys1' in sup
assert un['xs1'].shape[0] == 32
+
+
+class MyDataModule(DataModule):
+
+ def idataloader(self, params: ParamsType = None, stage: TrainStage = None):
+ super().idataloader(params, stage)
+ sup = create_dataset_builder()
+ un = create_dataset_builder()
+
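+ # pair a cycled supervised loader (batch 128) with a smaller unsupervised one (batch 32)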
+ dl = (
+ DataLoaderSide()
+ .add('sup', sup.DataLoader(batch_size=128, drop_last=True), cycle=True)
+ .add('un', un.DataLoader(batch_size=32, drop_last=True))
+ .zip()
+ )
+ self.regist_dataloader_with_stage(stage, dl)
+
+
+def test_dm_dataloader():
+ dm = MyDataModule()
+ loader = dm.train_dataloader
+ assert dm.train_dataset == loader.dataset
+ assert isinstance(dm.train_dataloader, DataLoaderSide)
diff --git a/tests/trainer/test_finder.py b/tests/trainer/test_finder.py
deleted file mode 100644
index 2588bae8..00000000
--- a/tests/trainer/test_finder.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import random
-
-from lumo import Trainer, ParamsType, TrainerParams, Experiment
-from lumo.exp import finder
-from lumo.proc.config import debug_mode
-
-
-class ATrainer(Trainer):
-
- def icallbacks(self, params: ParamsType):
- super().icallbacks(params)
-
-
-class BTrainer(Trainer):
-
- def icallbacks(self, params: ParamsType):
- super().icallbacks(params)
-
-
-def test_finder():
- debug_mode()
-
- for i in range(5):
- params = TrainerParams()
- params.epoch = i
- params.rnd = random.random()
- ATrainer(params).train()
- BTrainer(params).train()
-
- all_tests = finder.list_all()
- assert len(all_tests) == 2
- assert ATrainer.generate_exp_name() in all_tests
- assert BTrainer.generate_exp_name() in all_tests
- assert len(all_tests[ATrainer.generate_exp_name()]) == 5
- assert len(all_tests[BTrainer.generate_exp_name()]) == 5
-
- assert isinstance(all_tests[ATrainer.generate_exp_name()][0], Experiment)
- for exp in all_tests[ATrainer.generate_exp_name()]:
- params = TrainerParams().from_yaml(exp.properties['params.yaml'])
- assert params.hash() == exp.properties['params_hash']
- assert finder.find_path_from_test_name(exp.test_name) == exp.test_root
diff --git a/tests/trainer/test_saver.py b/tests/trainer/test_saver.py
index f7cd2202..766abcc4 100644
--- a/tests/trainer/test_saver.py
+++ b/tests/trainer/test_saver.py
@@ -1,3 +1,5 @@
+import os
+
from lumo.trainer.saver import Saver
import time
import shutil
@@ -17,10 +19,11 @@ def test_save_load():
else:
saver.save_checkpoint(i, {'step': i}, {'meta_step': i}, max_keep=max_keep, is_best=(i % 5) == 0)
+ print(os.listdir(save_root))
assert len(saver.list_models()) == (epoch)
state = saver.load_model(best_if_exist=True, with_meta=True)
- print(state)
+ # print(state)
assert state[0]['step'] == 5 and state[1]['meta_step'] == 5
state = saver.load_model(2)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 35415bbb..38c015cd 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1,20 +1,18 @@
-from typing import Union, Optional, Sequence, Mapping, Any
-
-from lumo.proc.config import debug_mode
-from lumo.utils.repository import git_dir
import os
+from typing import Union, Optional, Sequence, Mapping
-import tempfile
-
+import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from lumo import ParamsType, TrainerParams
-from lumo import Trainer, DataModule, Meter, TrainStage, MetricType, Record, DatasetBuilder
+from lumo import Trainer, DataModule, TrainStage, MetricType, Record, DatasetBuilder
from lumo.data.loader import DataLoaderType
-from lumo.trainer import callbacks
from lumo.proc import glob
+from lumo.proc.config import debug_mode
+from lumo.trainer import callbacks
+from lumo.utils.repository import git_dir
def create_dataset_builder():
@@ -97,6 +95,8 @@ def train_epoch(self, loader: DataLoaderType, params: ParamsType = None, limit_s
limit_global_steps=None) -> Record:
assert self.context == 'train_epoch'
assert self.contexts[-2] == 'train'
+ if params.get('raise_exp', False):
+ raise ValueError('raised by test')
return super().train_epoch(loader, params, limit_step, limit_global_steps)
def remove_callback(self, cur):
@@ -141,19 +141,24 @@ def test_trainer():
params.epoch = 2
debug_mode()
+ dm = MyDataModule()
# glob['HOOK_FINALREPORT'] = False
- trainer = CBTrainer(params, dm=MyDataModule())
+ trainer = CBTrainer(params, dm=dm)
trainer.train()
trainer.test()
trainer.evaluate()
trainer.logger.info(trainer.lf.functions)
trainer.exp.end()
+ assert dm.train_dataset == dm._parse_dataset(dm.train_dataloader)
+ assert dm.val_dataset == dm._parse_dataset(dm.val_dataloader)
+ assert dm.test_dataset == dm._parse_dataset(dm.test_dataloader)
+
# test trainer experiment
exp = trainer.exp
- assert exp.exp_root == os.path.join(glob['exp_root'], trainer.generate_exp_name())
- assert exp.lib_root == glob['home']
- assert exp.blob_root == os.path.join(glob['blob_root'], trainer.generate_exp_name(), exp.test_name)
+ assert exp.exp_dir == os.path.join(glob['exp_root'], trainer.generate_exp_name())
+ assert exp.info_dir == os.path.join(glob['exp_root'], trainer.generate_exp_name(), exp.test_name)
+ assert exp.blob_dir == os.path.join(glob['blob_root'], trainer.generate_exp_name(), exp.test_name)
assert exp.project_root == git_dir()
# how to test writer?
_ = trainer.safe_writer
@@ -162,5 +167,49 @@ def test_trainer():
raise AssertionError(str(trainer.callback_function - trainer.lf.functions))
+def test_trainer_params():
+ params = TrainerParams()
+ params.optim = params.OPTIM.create_optim('SGD', lr=0.9)
+ params.optim.lr = 3
+ module = nn.Linear(10, 10)
+ optim = params.optim.build(module.parameters())
+ assert isinstance(optim, torch.optim.Optimizer)
+ # the later attribute assignment should win over the value passed to create_optim
+ assert optim.defaults['lr'] == 3
+
+
+class MyTrainer(Trainer):
+ pass
+
+
+def test_trainer_state_dict():
+ trainer = MyTrainer(TrainerParams())
+ device_a = trainer.device_a = torch.device('cpu')
+ ndarray_a = trainer.ndarray_a = np.array([1, 2, 3])
+ tensor_a = trainer.tensor_a = torch.tensor([1, 2, 3])
+ module = trainer.module = nn.Linear(10, 10)
+ optim_a = trainer.optim_a = TrainerParams.OPTIM.create_optim('SGD', lr=0.9).build(trainer.module.parameters())
+
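+ # state_dict groups attributes by kind: optims / models / thtensor / nptensor (and devices)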
+ state_dict = trainer.state_dict()
+ # assert state_dict['devices']['device_a'] == trainer.device_a
+ assert state_dict['optims']['optim_a'] == trainer.optim_a.state_dict()
+ assert all([(i == j).all()
+ for i, j in zip(state_dict['models']['module'].values(), trainer.module.state_dict().values())])
+ assert (state_dict['thtensor']['tensor_a'] == trainer.tensor_a).all()
+ assert (state_dict['nptensor']['ndarray_a'] == trainer.ndarray_a).all()
+
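+ # overwrite every attribute, then restore them from the saved checkpoint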
+ fn = trainer.save_state_dict()
+ trainer.ndarray_a = np.array([3, 2, 1])
+ trainer.tensor_a = torch.tensor([3, 2, 1])
+ trainer.module = nn.Linear(10, 10)
+ trainer.optim_a = TrainerParams.OPTIM.create_optim('SGD', lr=0.9).build(trainer.module.parameters())
+
+ trainer.load_state_dict(torch.load(fn, map_location='cpu'))
+ # verify the restored attributes, not just the previously captured state_dict
+ assert trainer.optim_a.state_dict() == optim_a.state_dict()
+ assert all((i == j).all()
+ for i, j in zip(trainer.module.state_dict().values(), module.state_dict().values()))
+ assert (trainer.tensor_a == tensor_a).all()
+ assert (trainer.ndarray_a == ndarray_a).all()
+
+
if __name__ == '__main__':
test_trainer()
diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py
new file mode 100644
index 00000000..fc52be61
--- /dev/null
+++ b/tests/utils/test_io.py
@@ -0,0 +1,101 @@
+import yaml
+from lumo.utils.safe_io import *
+
+
+def test_dump_json(tmpdir):
+ # Test that dump_json creates a valid JSON file
+ obj = {"a": 1, "b": 2}
+ fn = os.path.join(str(tmpdir), "test.json")
+ dump_json(obj, fn)
+ with open(fn, "r") as f:
+ data = json.load(f)
+ assert data == obj
+
+
+def test_dump_yaml(tmpdir):
+ # Test that dump_yaml creates a valid YAML file
+ obj = {"a": 1, "b": 2}
+ fn = os.path.join(str(tmpdir), "test.yaml")
+ dump_yaml(obj, fn)
+ with open(fn, "r") as f:
+ data = yaml.safe_load(f)
+ assert data == obj
+
+
+def test_dump_state_dict(tmpdir):
+ # Test that dump_state_dict creates a valid state dict file
+ obj = {"a": torch.randn(3, 3)}
+ fn = os.path.join(str(tmpdir), "test.pt")
+ dump_state_dict(obj, fn)
+ data = torch.load(fn)
+ assert (data['a'] == obj['a']).all()
+
+
+def test_load_json(tmpdir):
+ # Test that load_json reads a valid JSON file
+ obj = {"a": 1, "b": 2}
+ fn = os.path.join(str(tmpdir), "test.json")
+ with open(fn, "w") as f:
+ json.dump(obj, f)
+ data = load_json(fn)
+ assert data == obj
+
+
+def test_load_yaml(tmpdir):
+ # Test that load_yaml reads a valid YAML file
+ obj = {"a": 1, "b": 2}
+ fn = os.path.join(str(tmpdir), "test.yaml")
+ with open(fn, "w") as f:
+ yaml.safe_dump(obj, f)
+ data = load_yaml(fn)
+ assert data == obj
+
+
+def test_load_state_dict(tmpdir):
+ # Test that load_state_dict reads a valid state dict file
+ obj = {"a": torch.randn(3, 3)}
+ fn = os.path.join(str(tmpdir), "test.pt")
+ torch.save(obj, fn)
+ data = load_state_dict(fn)
+ assert (data['a'] == obj['a']).all()
+
+
+def test_dump_text(tmpdir):
+ # Test that dump_text creates a valid text file
+ string = "hello\nworld"
+ fn = os.path.join(str(tmpdir), "test.txt")
+ dump_text(string, fn)
+ with open(fn, "r") as f:
+ data = f.read()
+ assert data == string
+
+
+def test_load_text(tmpdir):
+ # Test that load_text reads a valid text file
+ string = "hello\nworld"
+ fn = os.path.join(str(tmpdir), "test.txt")
+ with open(fn, "w") as f:
+ f.write(string)
+ data = load_text(fn)
+ assert data == string
+
+
+def test_dump_pkl(tmpdir):
+ # Test that dump_pkl creates a valid pickle file
+ obj = {"a": 1, "b": 2}
+ fn = os.path.join(str(tmpdir), "test.pkl")
+ dump_pkl(obj, fn)
+ data = load_pkl(fn)
+ assert data == obj
+
+
+def test_cached(tmpdir):
+ # Test that cached context manager works correctly
+ with cached(os.path.join(str(tmpdir), "test.txt")) as cache_fn:
+ # Write some data to the cache file
+ with open(cache_fn, "w") as f:
+ f.write("hello")
+ # Check that the cache file exists
+ assert os.path.isfile(cache_fn)
+ # Check that the cache file is deleted after exiting the context manager
+ assert not os.path.isfile(cache_fn)