From 0ca4fb448c798b286bf093b22ef903836c0e99c9 Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 13:12:26 +0800 Subject: [PATCH 01/99] Add docstring --- src/lumo/data/collate.py | 7 ++++++- src/lumo/data/datamodule.py | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/lumo/data/collate.py b/src/lumo/data/collate.py index 56eb985..e56fbcb 100644 --- a/src/lumo/data/collate.py +++ b/src/lumo/data/collate.py @@ -1,5 +1,5 @@ """ - +collate.py provided some useful classes and functions to act as the collect_fn in dataloader """ from typing import Any, Mapping, Sequence from lumo.core.params import ParamsType @@ -9,6 +9,11 @@ class CollateBase: + """ + Base Collate Fn, a common abstract of collate_fn + + DataLoader(collate_fn=YourCollateClass(...)) + """ @classmethod def from_collate(cls, collate_fn, params: ParamsType = None): diff --git a/src/lumo/data/datamodule.py b/src/lumo/data/datamodule.py index e858b1a..98bb5d2 100644 --- a/src/lumo/data/datamodule.py +++ b/src/lumo/data/datamodule.py @@ -8,6 +8,10 @@ class DataModule: + """ + Used in `Trainer` to easy access `DataLoader`s for different stage(train/test/eval/others). + """ + def __init__(self, params: ParamsType = None): self._prop = {} self.params = params From 392e81b591462d65dc872125b6db205b51407e8d Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 13:15:32 +0800 Subject: [PATCH 02/99] remove __main__ for data module --- src/lumo/data/__main__.py | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 src/lumo/data/__main__.py diff --git a/src/lumo/data/__main__.py b/src/lumo/data/__main__.py deleted file mode 100644 index 1ab0e4a..0000000 --- a/src/lumo/data/__main__.py +++ /dev/null @@ -1,15 +0,0 @@ -""" - -cd {{repo_root}} -python -m lumo.data supervised - -python -m lumo.data selfsupervised - -python -m lumo.data semisupervised -from datasets.objdect - -python -m lumo.data https://github.com/example/dataset objdect - - -from datasets.objdect -""" From 23637abc9066e1e7ccc4b898b4af1b04b9625fcf Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 15:27:03 +0800 Subject: [PATCH 03/99] Add docstring for utils module, powered by chatgpt --- src/lumo/utils/ast.py | 11 ++++++ src/lumo/utils/exithook.py | 5 ++- src/lumo/utils/filelock.py | 75 +++++++++++++----------------------- src/lumo/utils/filelock2.py | 73 ----------------------------------- src/lumo/utils/fmt.py | 23 +++++++++++ src/lumo/utils/hash.py | 8 ---- src/lumo/utils/random.py | 39 +++++++++++++++++-- src/lumo/utils/repository.py | 38 ++++++++++++++++-- src/lumo/utils/safe_io.py | 41 +++++++++++++++++++- src/lumo/utils/screen.py | 22 +++++++++++ 10 files changed, 194 insertions(+), 141 deletions(-) delete mode 100644 src/lumo/utils/filelock2.py delete mode 100644 src/lumo/utils/hash.py diff --git a/src/lumo/utils/ast.py b/src/lumo/utils/ast.py index 9798ee8..b23c4b0 100644 --- a/src/lumo/utils/ast.py +++ b/src/lumo/utils/ast.py @@ -3,6 +3,17 @@ def analyse_module_dependency(module, mem=None, root=None): + """ + Recursively analyse the dependencies of a module and return a dictionary of modules and their associated files. + + Args: + module (module): The module to analyse. + mem (dict, optional): A dictionary of previously analysed modules and their associated files. + root (str, optional): The root directory to use as a reference when determining whether a file is a dependency. + + Returns: + dict: A dictionary of modules and their associated files. + """ if mem is None: mem = {} diff --git a/src/lumo/utils/exithook.py b/src/lumo/utils/exithook.py index 4d18702..2c906df 100644 --- a/src/lumo/utils/exithook.py +++ b/src/lumo/utils/exithook.py @@ -1,15 +1,17 @@ """ -Do something you want before the programe exit. +This module provides functions to wrap or replace the default exception handler of Python's sys module. """ import sys from functools import wraps def replace(func): + """Replace the default exception handler with the provided function.""" sys.excepthook = func def wrap_after(func): + """Wrap the default exception handler, executing the provided function after it.""" old = sys.excepthook def outer(fun): @@ -24,6 +26,7 @@ def inner(*args, **kwargs): def wrap_before(func): + """Wrap the default exception handler, executing the provided function before it.""" old = sys.excepthook def outer(fun): diff --git a/src/lumo/utils/filelock.py b/src/lumo/utils/filelock.py index 8f75202..6c20b23 100644 --- a/src/lumo/utils/filelock.py +++ b/src/lumo/utils/filelock.py @@ -1,58 +1,35 @@ -import time - +from filelock import Timeout, FileLock import os -import random - -from lumo.utils.exithook import wrap_before class Lock: - def __init__(self, name, sleep=1): - from lumo.proc.path import cache_dir - self.file = os.path.join(cache_dir(), f"LUMO_LOCK_{name}") - self.sleep = sleep - wrap_before(self.clear) + """ + A class for obtaining and releasing file-based locks using FileLock. - def clear(self, *_, **__): - self.release() + Args: + name (str): The name of the lock. - def abtain(self): - mulp = 1 - while True: - mulp += 1 - if mulp > 10: - raise TimeoutError(f'Can not abtain resource of {self.file}') - - while os.path.exists(self.file): - time.sleep(random.randint(mulp, mulp ** 2)) - mulp += 1 - if mulp > 10: - raise TimeoutError(f'Can not abtain resource of {self.file}') - - while True: - flag = f'{os.getpid()}' - with open(self.file, 'w') as w: - w.write(flag) - - if os.path.exists(self.file): - with open(self.file, 'r') as r: - lock_flag = r.read() - if flag == lock_flag: - return True - mulp += 1 - time.sleep(random.randint(mulp, mulp ** 2)) - if mulp > 10: - raise TimeoutError(f'Can not abtain resource of {self.file}') + Attributes: + fn (str): The file path of the lock file. + lock (FileLock): The FileLock object used for obtaining and releasing the lock. - def release(self): - try: - os.remove(self.file) - except: - pass - return True + Example: + lock = Lock('my_lock') + lock.abtain() + # critical section + lock.release() + """ - def __enter__(self): - self.abtain() + def __init__(self, name): + """Initialize the lock file path and FileLock object""" + from lumo.proc.path import cache_dir + self.fn = os.path.join(cache_dir(), f"LUMO_LOCK_{name}") + self.lock = FileLock(self.fn) + + def abtain(self): + """Acquire the lock""" + self.lock.acquire() - def __exit__(self, exc_type, exc_val, exc_tb): - self.release() + def release(self): + """Release the lock""" + self.lock.release() diff --git a/src/lumo/utils/filelock2.py b/src/lumo/utils/filelock2.py deleted file mode 100644 index f18916c..0000000 --- a/src/lumo/utils/filelock2.py +++ /dev/null @@ -1,73 +0,0 @@ -import time - -import os -import random -from joblib import hash -import mmap -from lumo.utils.exithook import wrap_before - - -class Lock: - def __init__(self, name, sleep=1, group=None): - from lumo.proc.path import cache_dir - if group is None: - group = "" - - self.size = 10000 - self.flag = f'{os.getpid()}' - self.flagsize = len(self.flag) - self.pos = int(hash(name), 16) % (self.size - len(self.flag)) - self.file = os.path.join(cache_dir(), f'LUMO_LOCK_V3_{group}') - if not os.path.exists(self.file): - with open(self.file, 'w') as w: - w.write('0' * self.size) - - self.sleep = sleep - self.fhdl = None - wrap_before(self.clear) - - def clear(self, *_, **__): - self.release() - - def abtain(self): - mulp = 0 - self.fhdl = r = open(self.file, 'r+b') - mm = mmap.mmap(r.fileno(), 0, access=mmap.ACCESS_WRITE) - flag = f'{os.getpid()}' - while True: - mulp += 1 - if mulp > 10: - raise TimeoutError(f'Can not abtain resource of {self.file}') - - if mm[self.pos:self.pos + len(flag)].decode() != '0' * len(flag): - time.sleep(random.randint(mulp, mulp ** 2)) - continue - - mm[self.pos:self.pos + len(flag)] = flag.encode() - mm.flush() - - if mm[self.pos:self.pos + len(flag)].decode() != flag: - mm.close() - time.sleep(random.randint(mulp, mulp ** 2)) - continue - - return True - - def release(self): - if self.fhdl is None: - return - - r = self.fhdl - mm = mmap.mmap(r.fileno(), 0, access=mmap.ACCESS_WRITE) - flag = b'0' * self.flagsize - mm[self.pos:self.pos + len(flag)] = flag - mm.flush() - r.close() - self.fhdl = None - return True - - def __enter__(self): - self.abtain() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.release() diff --git a/src/lumo/utils/fmt.py b/src/lumo/utils/fmt.py index 480c812..5c8458c 100644 --- a/src/lumo/utils/fmt.py +++ b/src/lumo/utils/fmt.py @@ -11,18 +11,21 @@ def to_ndarray(item): + """Convert a PyTorch tensor or any other array-like object to a NumPy ndarray.""" if isinstance(item, torch.Tensor): item = item.detach().cpu() return np.array(item) def detach(item): + """Detach a PyTorch tensor and convert it to a NumPy ndarray.""" if isinstance(item, torch.Tensor): item = item.detach().cpu().numpy() return item def validate_scalar_shape(ndarray, name=''): + """Validate that a given numpy array is a scalar.""" if ndarray.ndim != 0: raise ValueError( "Expected scalar value for %r but got %r" % (name, ndarray) @@ -31,6 +34,7 @@ def validate_scalar_shape(ndarray, name=''): def is_scalar(ndarray: np.ndarray): + """Check whether a numpy array is a scalar.""" return ndarray.size == 1 @@ -42,6 +46,7 @@ def strftime(fmt='%y-%m-%d-%H%M%S', dateobj: datetime = None): def strptime(fmt='%y-%m-%d-%H%M%S', datestr: str = None): + """Convert a string to a datetime object using the specified format.""" return datetime.strptime(datestr, fmt) @@ -52,12 +57,30 @@ def strptime(fmt='%y-%m-%d-%H%M%S', datestr: str = None): def to_filename(basename): + """Replace invalid characters in a basename with an underscore.""" return re.sub(_invalid_fc, '_', basename) def can_be_filename(basename): + """Check whether a basename can be converted to a valid filename.""" return re.search(_invalid_fc, basename) is None def indent_print(text, indent=' '): + """Prints the specified text with a given indentation.""" print(textwrap.indent(text, indent)) + + +def format_second(sec: int) -> str: + """Formats a duration given in seconds into a human-readable string.""" + sec, ms = divmod(sec, 1) + if sec > 60: + min, sec = divmod(sec, 60) + if min > 60: + hour, min = divmod(min, 60) + fmt = "{}h{}m{}s".format(hour, min, int(sec)) + else: + fmt = "{}m{}s".format(min, int(sec)) + else: + fmt = "{}s".format(int(sec)) + return fmt diff --git a/src/lumo/utils/hash.py b/src/lumo/utils/hash.py deleted file mode 100644 index 75f173c..0000000 --- a/src/lumo/utils/hash.py +++ /dev/null @@ -1,8 +0,0 @@ -from hashlib import md5 - - -def hash_iter(*object: str): - hasher = md5() - for i in object: - hasher.update(i.encode()) - return hasher.hexdigest() diff --git a/src/lumo/utils/random.py b/src/lumo/utils/random.py index 96accec..f6a0111 100644 --- a/src/lumo/utils/random.py +++ b/src/lumo/utils/random.py @@ -11,10 +11,28 @@ def int_time(): + """ + Get the current time as an integer. + + Returns: + int: The current time as an integer. + """ return int(str(time.time()).split(".")[-1]) def hashseed(hashitem: Union[int, str]): + """ + Generate a hash seed from a given integer or string. + + Args: + hashitem (Union[int, str]): The integer or string to generate the hash seed from. + + Returns: + int: The hash seed. + + Raises: + AssertionError: If the given `hashitem` is not an integer or a string. + """ if not isinstance(hashitem, (int, str)): raise AssertionError() @@ -52,6 +70,9 @@ def fix_seed(seed=10, cuda=True): def fix_cuda(): + """ + Set deterministic and reproducible configuration for CUDA. + """ if torch.cuda.is_available(): torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True @@ -59,6 +80,17 @@ def fix_cuda(): def get_state(cuda=True): + """ + Get the current state of the random number generators. + + Args: + cuda (bool): Whether to get the CUDA state if PyTorch is using CUDA. + + Returns: + dict: A dictionary containing the current states of the random number generators for + numpy, pytorch, pytorch.cuda, and python's built-in `random` module. + + """ return { "numpy": np.random.get_state(), "torch": torch.random.get_rng_state(), @@ -69,12 +101,11 @@ def get_state(cuda=True): def set_state(state_dict, cuda=True): """ - Set random state of built-in random, numpy, torch, torch.cuda + Set the random state for NumPy, PyTorch, PyTorch CUDA, and Python's built-in `random` module. Args: - state_dict: a dict got from `lumo.utils.random.get_state()` - - Returns: + state_dict (dict): A dictionary containing the desired states of the random number generators for NumPy, PyTorch, PyTorch CUDA, and Python's built-in `random` module. + cuda (bool): Whether to set the CUDA state if PyTorch is using CUDA. """ random.setstate(state_dict["random"]) diff --git a/src/lumo/utils/repository.py b/src/lumo/utils/repository.py index d65378f..0bebbef 100644 --- a/src/lumo/utils/repository.py +++ b/src/lumo/utils/repository.py @@ -8,10 +8,19 @@ import git from git import Repo, Commit from joblib import hash -from .filelock2 import Lock +from .filelock import Lock def dev_branch(): + """ + Returns the value of the 'dev_branch' key from the global configuration dictionary 'glob' in the 'lumo.proc.config' module. + + If the key is not present in the dictionary, returns the default value of 'lumo_experiments'. + + Returns: + str: The value of the 'dev_branch' key from the global configuration dictionary. By default is `lumo_experiments`. + + """ from lumo.proc.config import glob return glob.get('dev_branch', 'lumo_experiments') @@ -21,10 +30,26 @@ def dev_branch(): class branch: """ - 用于配合上下文管理切换 git branch + A context manager class for switching git branches in a given repository. + + Example usage: + with branch(repo, branch): + repo.index.commit('...') + + Args: + repo (Repo): The repository object for which the branch will be switched. + branch (str): The name of the branch to switch to. + + + Notes: + This class provides a context manager that switches the current branch + to the given branch when entering the context and switches back to the original branch when exiting the context. + + If the given branch does not exist in the repository, it will be created. + + A lock is obtained on the repository to ensure that + only one instance of this class can switch branches at a time for a given repository. - with branch(repo, branch): - repo.index.commit('...') """ def __init__(self, repo: Repo, branch: str): @@ -60,6 +85,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): def check_have_commit(repo): + """ + Checks if the given repository has any commits. + + If there are no commits, creates an initial commit that adds all files in the repository and has the message "initial commit". + """ if len(repo.heads) == 0: repo.git.add('.') repo.index.commit('initial commit') diff --git a/src/lumo/utils/safe_io.py b/src/lumo/utils/safe_io.py index b67085f..c0d4684 100644 --- a/src/lumo/utils/safe_io.py +++ b/src/lumo/utils/safe_io.py @@ -17,28 +17,48 @@ def dump_json(obj, fn): + """ + Dumps the given object to a JSON file at the given file path. + + Args: + obj: The object to be dumped to JSON. + fn (str): The file path to which the JSON data will be written. + + Notes: + The JSON data will be written with an indentation of 2 spaces. + """ with open(fn, 'w', encoding='utf-8') as w: json.dump(obj, w, indent=2) def dump_yaml(obj, fn): + """ + Dumps the given object to a YAML file at the given file path. + + Args: + obj: The object to be dumped to YAML. + fn (str): The file path to which the YAML data will be written. + + Notes: + The YAML data will be written with default formatting options. + """ import yaml with open(fn, 'w', encoding='utf-8') as w: yaml.safe_dump(obj, w) - return fn def dump_state_dict(obj, fn): torch.save(obj, fn) - return fn def load_json(fn): + """Loads JSON data from the given file path and returns the resulting object.""" with open(fn, 'r', encoding='utf-8') as r: return json.load(r) def load_yaml(fn): + """Loads YAML data from the given file path and returns the resulting object.""" import yaml with open(fn, 'r', encoding='utf-8') as r: return yaml.safe_load(r) @@ -50,6 +70,7 @@ def load_state_dict(fn: str, map_location='cpu'): def load_text(fn): + """Loads text data from the given file path and returns it as a single string.""" if not os.path.exists(fn): return '' with open(fn, 'r', encoding='utf-8') as r: @@ -99,6 +120,22 @@ def load_pkl(file, *, fix_imports=True, encoding="ASCII", errors="strict"): @contextmanager def cached(fn): + """ + A context manager that caches the output of a computation to a file. + + Args: + fn (str): The file path to which the cached data will be written. + + Yields: + str: The file path of the cache file. + + Examples: + + with cached('a.txt') as cache_fn: + with open(cache_fn, 'w') as w: + w.write('123') + + """ import shutil cache_fn = f'{fn}.lumo_cache' try: diff --git a/src/lumo/utils/screen.py b/src/lumo/utils/screen.py index fb41dc4..8c01622 100644 --- a/src/lumo/utils/screen.py +++ b/src/lumo/utils/screen.py @@ -10,10 +10,23 @@ def get_consolo_width(): + """Returns the width of the current console window""" return shutil.get_terminal_size().columns - 1 # -1 for windows consolo def support_multiline(): + """ + Checks if the current environment supports multiline output in line. + Notes: + This function checks if the current environment supports multiline output. + It returns True if any of the following conditions are met: + + - The `jupyter_core` module is available (implying that the code is being run in a Jupyter notebook or JupyterLab). + - The width of the console is reported as 0 by `shutil.get_terminal_size()`, which can occur in some non-standard environments or configurations. + - The `PYCHARM_HOSTED` environment variable is set, which indicates that the code is being run in PyCharm's integrated console. + + If none of these conditions are met, the function returns False. + """ if "jupyter_core" in sys.modules or shutil.get_terminal_size((0, 0)).columns == 0 or "PYCHARM_HOSTED" in os.environ: return True else: @@ -163,6 +176,15 @@ def _screen_str(self, margin="..."): class inlinetqdm(tqdm): + """ + A subclass of `tqdm` that formats progress bar updates as a single line. + + This subclass provides two additional methods: + + - `full_str`: Returns a formatted string representing the full progress bar, + including the progress bar itself and any additional information (such as elapsed time or estimated remaining time). + + """ def full_str(self): return self.format_meter(**self.format_dict) From 23a8ce35fd806fe2f1b423f24821fff0f7b1ae5d Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 15:27:25 +0800 Subject: [PATCH 04/99] hashed by joblib --- src/lumo/exp/exphook.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lumo/exp/exphook.py b/src/lumo/exp/exphook.py index 41a21b5..669385c 100644 --- a/src/lumo/exp/exphook.py +++ b/src/lumo/exp/exphook.py @@ -121,7 +121,6 @@ def on_start(self, exp: Experiment, *args, **kwargs): from lumo.utils.repository import git_enable, git_commit, git_dir from lumo.utils.ast import analyse_module_dependency - from lumo.utils.hash import hash_iter import inspect if not git_enable(): @@ -149,7 +148,7 @@ def on_start(self, exp: Experiment, *args, **kwargs): except OSError: pass - dep_hash = hash_iter(*dep_source) + dep_hash = hash(dep_source) commit_ = git_commit(key='lumo', info=exp.test_root, filter_files=filter_files) if commit_ is None: From 560c53c09f8b5e86f52de0d5009a34773dd635f7 Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 15:27:34 +0800 Subject: [PATCH 05/99] remove unused code --- src/lumo/exp/__main__.py | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 src/lumo/exp/__main__.py diff --git a/src/lumo/exp/__main__.py b/src/lumo/exp/__main__.py deleted file mode 100644 index dd16f58..0000000 --- a/src/lumo/exp/__main__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -列出某次实验的全部信息 -""" From f7fa88fa9d83082cfc923b2577ed76b4d51adcfb Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 15:30:35 +0800 Subject: [PATCH 06/99] Add docstring powered by chatgpt --- src/lumo/decorators/clsmethod.py | 11 ++- src/lumo/decorators/lazy_required.py | 18 +++++ src/lumo/decorators/process.py | 14 ++++ src/lumo/decorators/regist.py | 104 +++++++++++++++++++++++++-- 4 files changed, 137 insertions(+), 10 deletions(-) diff --git a/src/lumo/decorators/clsmethod.py b/src/lumo/decorators/clsmethod.py index 70b4682..b63c801 100644 --- a/src/lumo/decorators/clsmethod.py +++ b/src/lumo/decorators/clsmethod.py @@ -5,13 +5,18 @@ def clswrap(callable: T) -> T: """ - 带类型提示的 staticmethod() + A decorator for creating a staticmethod with type hints. + Args: - callable: + callable: The function to be wrapped with a staticmethod. Returns: + A new staticmethod that calls the original function. + Notes: - 必须在类中用 + This decorator should be used on a class method. It creates a new staticmethod that calls the original function, + allowing it to be called without an instance of the class. The `callable` argument should have type hints + for the parameters and return type. The resulting staticmethod will also have the same type hints. """ @staticmethod diff --git a/src/lumo/decorators/lazy_required.py b/src/lumo/decorators/lazy_required.py index 8814ffb..f6c6298 100644 --- a/src/lumo/decorators/lazy_required.py +++ b/src/lumo/decorators/lazy_required.py @@ -5,6 +5,16 @@ def is_lib_available(lib): + """ + Check if a library is available to be imported. + + Args: + lib (str): The name of the library to check for. + + Returns: + bool: True if the library is available, False otherwise. + + """ if lib in _lib_memory: return True res = imputil.find_spec(lib) @@ -16,6 +26,10 @@ def is_lib_available(lib): def torch_required(func): + """ + Wrap a function to raise an ImportError if PyTorch is not available. + """ + # Chose a different decorator name than in tests so it's clear they are not the same. @wraps(func) def wrapper(*args, **kwargs): @@ -28,6 +42,10 @@ def wrapper(*args, **kwargs): def lib_required(lib_name): + """ + Wrap a function to raise an ImportError if a required library is not available. + """ + def outer(func): @wraps(func) def wrapper(*args, **kwargs): diff --git a/src/lumo/decorators/process.py b/src/lumo/decorators/process.py index 9c1375b..5a48812 100644 --- a/src/lumo/decorators/process.py +++ b/src/lumo/decorators/process.py @@ -4,6 +4,20 @@ def call_on_main_process_wrap(func) -> Callable: + """ + Wrap a function to only execute on the main process. + + If the current process is the main process, it calls the original func with the passed arguments and returns the result. + If it is not the main process, it does nothing and returns None. + + Args: + func (callable): The function to wrap. + + Returns: + callable: A wrapped function that only executes on the main process. + + """ + def inner(*args, **kwargs): if is_main(): return func(*args, **kwargs) diff --git a/src/lumo/decorators/regist.py b/src/lumo/decorators/regist.py index 7c4bfe8..6829f96 100644 --- a/src/lumo/decorators/regist.py +++ b/src/lumo/decorators/regist.py @@ -4,6 +4,25 @@ def regist_func_to(val: Union[Dict[str, Callable], List[Callable]], name_=None): + """ + A decorator function that registers a function to a dictionary or list. + + Args: + val: A dictionary or list to register the function to. + name_: The key or index of the function in the dictionary or list. If not provided, it defaults to the name + of the function. + + Returns: + The decorated function. + + Notes: + This decorator function is used to register a function to a dictionary or list. If the `val` argument is a + dictionary, the function will be added to the dictionary with a key equal to `name_` or the function name if + `name_` is not provided. If the `val` argument is a list, the function will be appended to the list. The + registered function can then be retrieved and called later. The `name_` argument is optional, but if it is + provided it should be a string. The `val` argument should be either a dictionary or a list of functions. + """ + def wrap(func): if name_ is None: name = func.__name__ @@ -19,27 +38,98 @@ def wrap(func): return wrap -class Register(): - def __init__(self, name): +class Register: + """ + A class for registering functions. + + Args: + name: The name of the register. + + Attributes: + name: The name of the register. + source: An ordered dictionary that holds the registered functions. + + Methods: + __str__(): Returns a string representation of the register. + __repr__(): Returns a string representation of the register. + __getitem__(item): Gets a function from the register by name. + __call__(wrapped, name): A decorator function that adds a function to the register. + regist(name): A method that returns a partial function of __call__ with the register's name. + """ + + def __init__(self, name: str): + """ + Initialize the register. + + Args: + name: The name of the register. + """ self.name = name self.source = OrderedDict() - def __str__(self): + def __str__(self) -> str: + """ + Return a string representation of the register. + + Returns: + A string representation of the register. + """ inner = str([(k, v) for k, v in self.source.items()]) return f"Register({self.name}{inner})" - def __repr__(self): + def __repr__(self) -> str: + """ + Return a string representation of the register. + + Returns: + A string representation of the register. + """ return self.__str__() - def __getitem__(self, item): + def __getitem__(self, item: str): + """ + Get a function from the register by name. + + Args: + item: The name of the function. + + Returns: + The function with the given name, or None if the function is not in the register. + """ return self.source.get(item, None) - def __call__(self, wrapped, name=None): + def __call__(self, wrapped: callable, name: str = None) -> callable: + """ + Add a function to the register. + + Args: + wrapped: The function to be added to the register. + name: The name of the function in the register. If not provided, the function's name will be used. + + Returns: + The original function, unchanged. + + Raises: + AssertionError: If the `name` argument is not provided. + """ if name is None: name = wrapped.__name__ assert name is not None self.source[name] = wrapped return wrapped - def regist(self, name=None): + def regist(self, name: str = None): + """ + Returns a partial function of __call__ with the register's name. + + Args: + name: The name of the function in the register. If not provided, the function's name will be used. + + Returns: + A partial function of __call__ with the register's name. + + Notes: + This method is used to create a decorator that will add a function to the register. The `name` argument + is optional, but if it is provided it should be a string. + """ return partial(self, name=name) From d214190ab38cf67b5956e5705ec9bd398d0b94e9 Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 15:44:45 +0800 Subject: [PATCH 07/99] Add docstring powered by chatgpt --- src/lumo/data/collate.py | 99 ++++++++++++++++++++++++++++++++++--- src/lumo/data/datamodule.py | 60 +++++++++++++++++++++- src/lumo/data/loader.py | 44 ++++++++++++++++- 3 files changed, 194 insertions(+), 9 deletions(-) diff --git a/src/lumo/data/collate.py b/src/lumo/data/collate.py index e56fbcb..e786083 100644 --- a/src/lumo/data/collate.py +++ b/src/lumo/data/collate.py @@ -1,7 +1,12 @@ """ -collate.py provided some useful classes and functions to act as the collect_fn in dataloader +A module that provides classes and functions to act as the collect_fn in dataloader. + +Classes: + CollateBase: A base class for implementing collate functions. + IgnoreNoneCollate: A collate function that ignores None samples. + """ -from typing import Any, Mapping, Sequence +from typing import Mapping, Sequence from lumo.core.params import ParamsType import numpy as np import torch @@ -10,13 +15,36 @@ class CollateBase: """ - Base Collate Fn, a common abstract of collate_fn - - DataLoader(collate_fn=YourCollateClass(...)) + A base class for implementing collate functions. + + Args: + params (ParamsType, optional): Parameters for the collate function. Defaults to None. + + Attributes: + _collate_fn (callable): A function that combines a list of samples into a batch. + params (ParamsType): Parameters for the collate function. + + Methods: + from_collate: Creates an instance of the class from a collate function. + __call__: Applies the collate function to a list of samples. + before_collate: A hook function called before the collate function is applied. + raw_collate: Applies the collate function to a list of samples without calling the `before_collate` and `after_collate` hook functions. + collate: Alias for `raw_collate`. + after_collate: A hook function called after the collate function is applied. """ @classmethod def from_collate(cls, collate_fn, params: ParamsType = None): + """ + Creates an instance of the class from a collate function. + + Args: + collate_fn (callable): A function that combines a list of samples into a batch. + params (ParamsType, optional): Parameters for the collate function. Defaults to None. + + Returns: + CollateBase: An instance of the class. + """ instance = cls(params) instance._collate_fn = collate_fn return instance @@ -27,25 +55,72 @@ def __init__(self, params: ParamsType = None) -> None: self.params = params def __call__(self, *args, **kwargs): + """ + Applies the collate function to a list of samples. + + Args: + *args: A list of samples. + **kwargs: Additional keyword arguments. + + Returns: + The batch of samples. + """ res = self.before_collate(*args, **kwargs) res = self._collate_fn(res) res = self.after_collate(res) return res def before_collate(self, sample_list): + """ + A hook function called before the collate function is applied. + + Args: + sample_list (Sequence): A list of samples. + + Returns: + Sequence: The modified list of samples. + """ return sample_list def raw_collate(self, sample_list): + """ + Applies the collate function to a list of samples without calling the `before_collate` and `after_collate` hook functions. + + Args: + sample_list (Sequence): A list of samples. + + Returns: + The batch of samples. + """ return self._collate_fn(sample_list) def collate(self, sample_list): + """ + Alias for `raw_collate`. + + Args: + sample_list (Sequence): A list of samples. + + Returns: + The batch of samples. + """ return self._collate_fn(sample_list) def after_collate(self, batch): + """ + A hook function called after the collate function is applied. + + Args: + batch (Any): The batch of samples. + + Returns: + Any: The modified batch of samples. + """ return batch class IgnoreNoneCollate(CollateBase): + """A collate function that ignores `None` samples.""" def _filter_none(self, item): if item is None: @@ -61,7 +136,19 @@ def before_collate(self, sample_list): def numpy_collate(batch): - r"""Puts each data field into a tensor with outer dimension batch size""" + """Collate function for numpy arrays. + + Args: + batch (list): A list of numpy arrays or other python objects. + + Returns: + numpy.ndarray or dict or list: Returns a collated numpy array, a dictionary of collated numpy arrays, + or a list of collated numpy arrays depending on the type of input elements. + + Raises: + RuntimeError: If the elements in batch do not have consistent size. + + """ elem = batch[0] elem_type = type(elem) diff --git a/src/lumo/data/datamodule.py b/src/lumo/data/datamodule.py index 98bb5d2..b660296 100644 --- a/src/lumo/data/datamodule.py +++ b/src/lumo/data/datamodule.py @@ -1,3 +1,6 @@ +""" +A module to manage dataloaders for different stages of training. +""" from typing import NoReturn, Union, overload, Optional from torch.utils.data import DataLoader @@ -9,7 +12,7 @@ class DataModule: """ - Used in `Trainer` to easy access `DataLoader`s for different stage(train/test/eval/others). + Used in Trainer to easily access DataLoaders for different stage(train/test/eval/others). """ def __init__(self, params: ParamsType = None): @@ -18,10 +21,26 @@ def __init__(self, params: ParamsType = None): @property def prop(self): + """ + Get the dictionary of registered dataloaders. + + Returns: + dict: the dictionary of registered dataloaders. + """ return self._prop @staticmethod def _parse_dataset(loader): + """ + Parse a dataset from a dataloader. + + Args: + loader (Union[DataLoader, DataLoaderSide]): a dataloader or a dataloader side. + + Returns: + Dataset: a dataset if `loader` is a `DataLoader` or a `DataLoaderSide`, None otherwise. + + """ if isinstance(loader, DataLoader): return loader.dataset elif isinstance(loader, DataLoaderSide): @@ -30,29 +49,42 @@ def _parse_dataset(loader): @property def train_dataset(self): + """ + Get the train dataset. + + Returns: + Dataset: the train dataset. + + """ return self._parse_dataset(self.train_dataloader) @property def test_dataset(self): + """Get the test dataset.""" return self._parse_dataset(self.test_dataloader) @property def val_dataset(self): + """Get the validation dataset.""" return self._parse_dataset(self.val_dataloader) @property def train_dataloader(self) -> Optional[DataLoaderType]: + """Get the train dataloader.""" return self.get_loader_with_stage(TrainStage.train) @property def test_dataloader(self) -> Optional[DataLoaderType]: + """Get the test dataloader.""" return self.get_loader_with_stage(TrainStage.test) @property def val_dataloader(self) -> Union[NoReturn, DataLoaderType]: + """Get the val dataloader.""" return self.get_loader_with_stage(TrainStage.val) def get_loader_with_stage(self, stage: TrainStage) -> DataLoaderType: + """Get the dataloader for a given stage.""" res = self._prop.get(stage.value, None) if res is None: self.idataloader(self.params, stage) @@ -63,7 +95,17 @@ def __getitem__(self, key): return self.prop.get(key, None) @overload - def regist_dataloader(self, train=None, test=None, val=None): + def regist_dataloader(self, train=None, test=None, val=None, **kwargs): + + """ + Registers the given dataloaders under the given keys. + + Args: + train: A DataLoaderType object for the train set. + test: A DataLoaderType object for the test set. + val: A DataLoaderType object for the validation set. + **kwargs: A DataLoaderType object for other stage + """ ... def regist_dataloader(self, **kwargs: dict): @@ -71,7 +113,21 @@ def regist_dataloader(self, **kwargs: dict): self.prop[k] = v def regist_dataloader_with_stage(self, stage: TrainStage, dl: DataLoaderType): + """ + Registers the given dataloader under the given TrainStage. + + Args: + stage: A TrainStage object. + dl: A DataLoaderType object. + """ self.prop[stage.value] = dl def idataloader(self, params: ParamsType = None, stage: TrainStage = None): + """ + Interface function to implement in a child class to set up data loading. + + Args: + params: A ParamsType object containing data loading parameters. + stage: A TrainStage object indicating which stage to set up data loading for. + """ pass diff --git a/src/lumo/data/loader.py b/src/lumo/data/loader.py index e178b0f..f887200 100644 --- a/src/lumo/data/loader.py +++ b/src/lumo/data/loader.py @@ -10,6 +10,19 @@ class LumoDataLoader(DataLoader): def summarize_loader(loader: DataLoader): + """ + Summarize the DataLoader object and return a formatted string representation. + + Args: + loader: A DataLoader object. + + Returns: + A formatted string representation of the DataLoader object. + + Raises: + ValueError: If the input argument is not a DataLoader object. + + """ if isinstance(loader, DataLoaderSide): inner = pformat({f"{k}(cycle={loader._cycle[k]})": summarize_loader(v) for k, v in loader._loaders.items()}) return f"DataLoaderSide({inner})" @@ -40,7 +53,36 @@ def summarize_loader(loader: DataLoader): class DataLoaderSide: """ - `DataLoaderSide` is used when different DataLoader with different batch_size are feeded at the same time. + A utility class for loading data from different DataLoaders with different batch sizes at the same time. + + Example usage: + loader = DataLoaderSide() + loader.add('train', train_loader, cycle=True) + loader.add('val', val_loader) + loader.zip() + for batch in loader: + # process batch + + Attributes: + _loaders: An ordered dictionary that maps the name of the DataLoader to the corresponding DataLoader instance. + _cycle: An ordered dictionary that maps the name of the DataLoader to a boolean indicating whether the DataLoader should be cycled. + _state: A string that indicates the current state of the DataLoaderSide instance. The possible values are 'zip' and 'chain'. + + Methods: + dataset(): Returns a dictionary that maps the name of the DataLoader to its corresponding dataset. + source(): Returns the _loaders dictionary. + add(name, loader, cycle=False): Adds a DataLoader instance to the _loaders dictionary. + name is the name of the DataLoader. + loader is the DataLoader instance to be added. + cycle is a boolean indicating whether the DataLoader should be cycled. Defaults to False. + copy(): Returns a new DataLoaderSide instance with the same _loaders, _cycle, and _state attributes as the original. + zip(): Sets the _state attribute to 'zip', which means the batches are zipped together. + if _state is 'zip', the batches are returned as an ordered dictionary. + chain(): Sets the _state attribute to 'chain', which means the batches are concatenated. + if _state is 'chain', the batches are returned as a list. + len(): Returns the minimum length of all the DataLoaders that do not have the cycle flag set to True. + iter(): Returns an iterator that generates batches from the DataLoaders in the _loaders dictionary. + """ def __init__(self): From e272579aa3614ed61dfefdcd46e938936890a85c Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 16:17:45 +0800 Subject: [PATCH 08/99] Add docstring powered by chatgpt --- src/lumo/core/attr.py | 69 ++++++++++++++++ src/lumo/core/disk.py | 110 ++++++++++++++++++++++++-- src/lumo/core/enums.py | 3 + src/lumo/core/params.py | 169 ++++++++++++++++++++++++++++++---------- src/lumo/core/raises.py | 3 - src/lumo/core/record.py | 57 +++++++++++++- src/lumo/core/tree.py | 130 +++++++++++++++++++++++++++++-- 7 files changed, 485 insertions(+), 56 deletions(-) diff --git a/src/lumo/core/attr.py b/src/lumo/core/attr.py index 3c0bf61..e321fe7 100644 --- a/src/lumo/core/attr.py +++ b/src/lumo/core/attr.py @@ -6,6 +6,12 @@ class Attr(OrderedDict): + """ + A subclass of OrderedDict that allows you to access its elements via dot notation. + + This class overrides the __setattr__, __setitem__, __getattr__, and __getitem__ methods to provide the + dot notation functionality. + """ def __setattr__(self, key: str, value): set_item_iterative(self, key.split('.'), value) @@ -31,6 +37,22 @@ def __getitem__(self, key): def safe_update_dict(src: dict, kwargs: dict, assert_type=True): + """ + Updates the source dictionary with the key-value pairs from the kwargs dictionary in a safe manner. + + This function iterates over the items in the kwargs dictionary and updates the corresponding items in the + source dictionary, making sure that the types of the values being updated match the types of the values + already in the source dictionary. + + Args: + src (dict): The dictionary to update. + kwargs (dict): The dictionary containing the new key-value pairs to add to the source dictionary. + assert_type (bool): A flag indicating whether to check that the types of the values being updated match + the types of the values already in the source dictionary. Defaults to True. + + Returns: + dict: The updated source dictionary. + """ for ks, v in walk_dict(kwargs): try: old_v = get_item_iterative(src, ks) @@ -46,6 +68,27 @@ def safe_update_dict(src: dict, kwargs: dict, assert_type=True): def walk_dict(dic: dict, root=None): + """ + Recursively walks through a dictionary and yields keys and values in a flattened format. + + Args: + - dic (dict): The dictionary to be walked through. + - root (list): The root keys to be used in the resulting flattened format. Defaults to None. + + Yields: + - A tuple containing a list of keys and a value. The list of keys is composed of the root keys and the current keys in the dictionary, split by '.' if there are any. The value is the corresponding value in the dictionary. + + Example: + ```python + d = {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3} + for k, v in walk_dict(d): + print(k, v) + # Output: + # (['a', 'b'], 1) + # (['a', 'c', 'd'], 2) + # (['e'], 3) + ``` + """ if root is None: root = [] for k, v in dic.items(): @@ -56,6 +99,18 @@ def walk_dict(dic: dict, root=None): def set_item_iterative(dic: dict, keys: List[str], value): + """ + Sets the value of a nested key in a dictionary using an iterative approach. + + Args: + dic (dict): The dictionary to update. + keys (List[str]): A list of keys representing the path to the nested key in the dictionary. + value: The value to set for the nested key. + + Raises: + ValueError: If a key in the path exists in the dictionary but the corresponding value is not a dictionary. + + """ if len(keys) == 1: if isinstance(value, dict): for ks, v in walk_dict(value): @@ -76,6 +131,20 @@ def set_item_iterative(dic: dict, keys: List[str], value): def get_item_iterative(dic: dict, keys: List[str]): + """ + Gets the value of a nested key in a dictionary using an iterative approach. + + Args: + dic (dict): The dictionary to retrieve the value from. + keys (List[str]): A list of keys representing the path to the nested key in the dictionary. + + Raises: + KeyError: If the nested key does not exist in the dictionary. + + Returns: + The value of the nested key in the dictionary. + + """ if len(keys) == 1: return dict.__getitem__(dic, keys[0]) else: diff --git a/src/lumo/core/disk.py b/src/lumo/core/disk.py index 5f675d7..46c406a 100644 --- a/src/lumo/core/disk.py +++ b/src/lumo/core/disk.py @@ -1,7 +1,6 @@ import os.path from dbrecord import PList from lumo.proc import path -from lumo.utils.filelock2 import Lock from lumo.utils import safe_io as IO @@ -14,7 +13,6 @@ def __init__(self, test_path: str): os.makedirs(test_path, exist_ok=True) self.fpath = os.path.join(test_path, f'metric_board.sqlite') self.disk = PList(self.fpath) - self.lock = Lock(os.path.basename(test_path.rstrip('/'))) def append(self, metric: dict, step, stage='train'): self.disk.append({ @@ -30,10 +28,32 @@ def flush(self): class TableRow: """ - It can be regarded as a serialized dictionary, - or a certain row in the table, so the same key value will be overwritten. - - If you need to record records at different times, please use trainer.metrics + TableRow class is a serialized dictionary that can represents a single row in a table. + If the same key is updated, its value will be overwritten. + Please use trainer.metrics to record records at different times. + + Args: + - table (str): name of the table to which the row belongs. + - partition (str): partition of the row. + - rowkey (str): unique identifier of the row. + + Attributes: + - fpath (str): path of the file that stores the serialized row. + - key (str): unique identifier of the row. + - value (dict): dictionary representing the row. + + Methods: + - __enter__(self): context manager method. Does nothing. + - __exit__(self, exc_type, exc_val, exc_tb): context manager method. Calls flush method. + - flush(self): writes the value of the row to a file. + - update_metrics(self, dic: dict, compare=None, flush=False): updates multiple metrics in the row. + - update_metric(self, key, value, compare=None, flush=False): updates a single metric in the row. + - metric(self): returns the metric dictionary of the row. + - update_metric_pair(self, key, value, key2, value2, compare=None, flush=False): updates two metrics in the row. + - set_params(self, params: dict): sets the value of 'params' key in the row. + - update_dict(self, dic: dict, flush=False): updates multiple keys in the row. + - update(self, key, value, flush=True): updates a single key in the row. + - __getitem__(self, item): returns the value of a key in the row. """ def __init__(self, table, partition, rowkey): @@ -45,15 +65,33 @@ def __init__(self, table, partition, rowkey): # self.disk = PDict(self.fpath) def __enter__(self): + """ + Does nothing. + """ pass def __exit__(self, exc_type, exc_val, exc_tb): + """ + Calls flush method. Required for using the object as a context manager. + """ self.flush() def flush(self): + """Writes the value of the row to a file.""" IO.dump_pkl(self.value, self.fpath) def update_metrics(self, dic: dict, compare=None, flush=False): + """ + Updates multiple metrics in the row. + + Args: + - dic (dict): dictionary containing key-value pairs to be updated. + - compare (str): comparison operator to be used for updating metrics. Only 'max' and 'min' are supported. + - flush (bool): if True, writes the value of the row to a file after updating the metrics. + + Returns: + - res (dict): dictionary containing key-value pairs that were updated. + """ res = {} [res.update(self.update_metric(k, v, compare)) for k, v in dic.items()] @@ -62,6 +100,19 @@ def update_metrics(self, dic: dict, compare=None, flush=False): return res def update_metric(self, key, value, compare=None, flush=False): + """ + Updates a metric value in the row. + + Args: + key (str): The key of the metric. + value (float): The value of the metric. + compare (str, optional): The comparison operator used to compare the new value with the old one. + Either 'max' or 'min'. Default is None. + flush (bool, optional): Whether to flush the changes to disk. Default is False. + + Returns: + dict: A dictionary containing the updated metric key and value. + """ dic = self.metric older = dic.setdefault(key, None) @@ -89,9 +140,30 @@ def update_metric(self, key, value, compare=None, flush=False): @property def metric(self): + """ + A property that returns the metric values of the row. + + Returns: + dict: A dictionary containing the metric values of the row. + """ return self.value.setdefault('metric', {}) def update_metric_pair(self, key, value, key2, value2, compare=None, flush=False): + """ + Update a pair of key-value metrics in the metric dictionary. + + Args: + key (str): The key of the first metric. + value (float): The value of the first metric. + key2 (str): The key of the second metric. + value2 (float): The value of the second metric. + compare (str, optional): The method to compare values. Default is None. + Possible values are 'max', 'min'. + flush (bool, optional): Whether to flush to disk after updating. Default is False. + + Returns: + dict: A dictionary with the old values of the updated metrics. + """ dic = self.metric old = dic.setdefault(key, None) old2 = dic.setdefault(key2, None) @@ -120,20 +192,44 @@ def update_metric_pair(self, key, value, key2, value2, compare=None, flush=False return {key: old, key2: old2} def set_params(self, params: dict): + """ + Set the parameters dictionary of the row. + + Args: + params (dict): The parameters dictionary to set. + + Returns: + dict: The parameters dictionary set. + """ self.value['params'] = params self.flush() - return params def update_dict(self, dic: dict, flush=False): + """ + Update the row with a dictionary. + + Args: + dic (dict): The dictionary to update the row with. + flush (bool, optional): Whether to flush to disk after updating. Default is False. + """ for k, v in dic.items(): self.update(k, v) if flush: self.flush() def update(self, key, value, flush=True): + """ + Update a key-value pair in the row. + + Args: + key (str): The key of the metric to update. + value (float): The value to set the metric to. + flush (bool, optional): Whether to flush to disk after updating. Default is True. + """ self.value[key] = value if flush: self.flush() def __getitem__(self, item): + """Get the value of a key in the row.""" return self.value[item] diff --git a/src/lumo/core/enums.py b/src/lumo/core/enums.py index 3c931c9..1997270 100644 --- a/src/lumo/core/enums.py +++ b/src/lumo/core/enums.py @@ -2,6 +2,9 @@ class TrainStage(enum.Enum): + """ + Enumeration class representing different stages of training. + """ default = 'default' train = 'train' test = 'test' diff --git a/src/lumo/core/params.py b/src/lumo/core/params.py index 745877c..a14863a 100644 --- a/src/lumo/core/params.py +++ b/src/lumo/core/params.py @@ -12,7 +12,7 @@ from omegaconf._utils import _ensure_container from .attr import safe_update_dict, set_item_iterative -from .raises import BoundCheckError, NewParamWarning +from .raises import BoundCheckError # arange_param = namedtuple('arange_param', ['default', 'left', 'right'], defaults=[None, float('-inf'), float('inf')]) # choice_param = namedtuple('choice_param', ['default', 'choices'], defaults=[None, []]) @@ -21,6 +21,14 @@ class Arange: + """A class representing a range of numeric values with a default, left and right boundaries. + + Attributes: + default: The default value of the range. Defaults to None. + left: The left boundary of the range. Defaults to positive infinity. + right: The right boundary of the range. Defaults to positive infinity. + """ + def __init__(self, default=None, left=float('inf'), right=float('inf')): self.default = default self.left = left @@ -31,6 +39,13 @@ def __repr__(self): class Choices: + """A class representing a list of choices with a default value. + + Attributes: + default: The default value of the list. Defaults to None. + choices: A list of values representing the available choices. Defaults to an empty list. + """ + def __init__(self, default=None, choices=None): if choices is None: choices = [] @@ -41,31 +56,6 @@ def __repr__(self): return f"Choice: [{self.default}], {self.choices}" -def _get_item(dic, keys: List[str]): - if len(keys) == 1: - return DictConfig.__getitem__(dic, keys[0]) - else: - nex = DictConfig.__getitem__(dic, keys[0]) - if isinstance(nex, (dict, DictConfig)): - return _get_item(nex, keys[1:]) - else: - raise KeyError(keys) - - -def _set_item(dic, keys: List[str], value): - if len(keys) == 1: - if isinstance(value, dict): - value = dic(value) - DictConfig.__setitem__(dic, keys[0], value) - else: - try: - nex = _get_item(dic, keys[:1]) - except KeyError: - nex = DictConfig({}) - DictConfig.__setitem__(dic, keys[0], nex) - _set_item(nex, keys[1:], value) - - def _safe_repr(values: Any) -> str: return pformat(values) @@ -95,13 +85,15 @@ def _padding_mod(st: str, offset=7, mod=4): def safe_param_repr(values: List[tuple], level=1) -> str: """ + Returns a string representation of a list of tuples containing parameter names and their values. + The resulting string can be safely included in a function signature or call, as it correctly formats the parameters. Args: - values: - level: + values: A list of tuples containing parameter names and their values. + level: An integer representing the level of indentation. Returns: - + A string representation of the input parameters, formatted with correct indentation and line breaks. """ res = [(f"{k}={_safe_repr(v)},", anno) for k, v, anno in values] @@ -114,14 +106,10 @@ def safe_param_repr(values: List[tuple], level=1) -> str: class BaseParams(DictConfig): def __init__(self): super().__init__({}, flags={'no_deepcopy_set_nodes': True}) - # self._set_flag('no_deepcopy_set_nodes', True) self.__dict__["_prop"] = {} def __setattr__(self, key: str, value: Any) -> None: if key != '_prop': - # if isinstance(value, BaseParams): - # self._prop.setdefault('key_type', {})[key] = type(value) - if isinstance(value, (Arange, Choices)): res = self._prop.get('constrain', {}) res[key] = value @@ -135,9 +123,6 @@ def __setattr__(self, key: str, value: Any) -> None: def __setitem__(self, key: DictKeyType, value: Any) -> None: if key != '_prop': - # if isinstance(value, BaseParams): - # self._prop.setdefault('key_type', {})[key] = type(value) - if isinstance(value, (Arange, Choices)): self._prop.setdefault('constrain', {})[key] = value value = value.default @@ -149,9 +134,6 @@ def __setitem__(self, key: DictKeyType, value: Any) -> None: def __getattr__(self, key: str) -> Any: res = super().__getattr__(key) - # key_type = self._prop.setdefault('key_type', {}).get(key, None) - # if key_type is not None: - # res = key_type.from_kwargs(**res) return res def _check(self, name, value): @@ -237,26 +219,81 @@ def choice(self, *choices) -> Choices: return Choices(choices[0], choices) def safe_update(self, dic, assert_type=True): + """ + Merge `dict` object into the config object, safely updating the values. + + Args: + dic: `dict` object to update + assert_type: If True, enforce that the type of values in `dic` matches the current config. + + Returns: + None + """ self.update( safe_update_dict(self.to_dict(), dic, assert_type=assert_type) ) def from_dict(self, dic: dict): + """ + Update the config object from a dictionary. + + Args: + dic: `dict` object to update + + Returns: + updated `self` object + """ self.safe_update(dic) return self def from_kwargs(self, **kwargs): + """ + Update the config object from keyword arguments. + + Args: + **kwargs: key-value pairs to update in the config object + + Returns: + updated `self` object + """ return self.from_dict(kwargs) def from_json(self, file): + """ + Update the config object from a JSON file. + + Args: + file: path to the JSON file + + Returns: + updated `self` object + """ self.safe_update(json.loads(Path(file).read_text()), assert_type=True) return self def from_yaml(self, file): + """ + Update the config object from a YAML file. + + Args: + file: path to the YAML file + + Returns: + updated `self` object + """ self.safe_update(dict(OmegaConf.load(file)), assert_type=True) return self def from_args(self, argv: list = None): + """ + Update the config object from command line arguments. + + Args: + argv: list of command line arguments (default: None) + + Returns: + updated `self` object + """ if argv is None: argv = sys.argv @@ -292,17 +329,41 @@ def inner(cfg): return self def to_dict(self): + """ + Convert this configuration to a dictionary. + + Returns: + dict: The configuration as a dictionary. + """ cfg = _ensure_container(self) container = OmegaConf.to_container(cfg, resolve=False, enum_to_str=True) return container def to_json(self, file=None): + """ + Convert this configuration to a JSON string. + + Args: + file (str or Path, optional): If specified, the JSON string will be written to a file at the given path. + + Returns: + str or None: The JSON string, or None if file is specified. + """ info = json.dumps(self.to_dict(), ensure_ascii=False, indent=2) if file is None: return info return Path(file).write_text(info, encoding='utf-8') def to_yaml(self, file=None): + """ + Convert this configuration to a YAML string. + + Args: + file (str or Path, optional): If specified, the YAML string will be written to a file at the given path. + + Returns: + str or None: The YAML string, or None if file is specified. + """ info = OmegaConf.to_yaml(self) if file is None: return info @@ -310,12 +371,33 @@ def to_yaml(self, file=None): @classmethod def Space(cls, **kwargs): + """ + Create a configuration object from keyword arguments. + + Args: + **kwargs: The configuration values. + + Returns: + BaseParams: The new configuration object. + """ return cls().from_dict(kwargs) def __hash__(self): + """ + Calculate the hash value of this configuration. + + Returns: + int: The hash value. + """ return int(self.hash(), 16) def hash(self) -> str: + """ + Calculate the hash value of this configuration. + + Returns: + str: The hash value. + """ return hash(self.to_dict()) def iparams(self): @@ -323,6 +405,15 @@ def iparams(self): @classmethod def init_from_kwargs(cls, **kwargs): + """ + Create a configuration object from keyword arguments. + + Args: + **kwargs: The configuration values. + + Returns: + BaseParams: The new configuration object. + """ return cls().from_dict(kwargs) diff --git a/src/lumo/core/raises.py b/src/lumo/core/raises.py index 01945f6..c547063 100644 --- a/src/lumo/core/raises.py +++ b/src/lumo/core/raises.py @@ -1,4 +1 @@ class BoundCheckError(BaseException): pass - - -class NewParamWarning(Warning): pass diff --git a/src/lumo/core/record.py b/src/lumo/core/record.py index 7d1cda2..e8cd080 100644 --- a/src/lumo/core/record.py +++ b/src/lumo/core/record.py @@ -20,7 +20,6 @@ def wrap_result(metric: MetricType) -> Meter: """ Wrap any form of metric data into Meter form. - """ if isinstance(metric, (Meter,)): return metric @@ -34,17 +33,37 @@ def wrap_result(metric: MetricType) -> Meter: class Record: - def __init__(self, window_size=500, **kwargs): + """Record class for storing and computing metrics over a window of steps. + + Attributes: + **kwargs: Additional properties to add to the record object. + + Properties: + stage (str): The stage of the training process, such as 'train' or 'eval'. + + Methods: + avg(): Computes the average value of the recorded metrics. + agg(): Computes the aggregated value of the recorded metrics. + __str__(): Returns a string representation of the recorded metrics. + tostr(): Alias for __str__(). + record(metric, global_step=None): Records a metric for the current step. + clear(): Clears the recorded metrics. + flush(): Clears the cache of recorded metrics. + """ + + def __init__(self, **kwargs): self._prop = {} self._prop.update(kwargs) self._cache = [] self._agg = OrderedDict() # type:Dict[str,AggItem] def avg(self) -> Attr: + """DEPRECATED: Computes the average value of the recorded metrics.""" warnings.warn('avg will be deprecated in later version, use agg() instead.') return self.agg() def agg(self) -> Attr: + """Computes the aggregated value of the recorded metrics.""" res = Attr() for k, v in self._agg.items(): res[k] = v.res @@ -52,9 +71,11 @@ def agg(self) -> Attr: @property def stage(self): + """Gets the stage of the training process, such as 'train' or 'eval'.""" return self._prop['stage'] def __str__(self): + """Returns a string representation of the recorded metrics.""" res = self.agg() rep = [] for k, v in res.items(): @@ -66,9 +87,11 @@ def __str__(self): return ', '.join(rep) def tostr(self): + """Alias for __str__().""" return str(self) def record(self, metric, global_step=None): + """Records a metric for the current step.""" meter = wrap_result(metric) agg = meter._avg @@ -81,14 +104,26 @@ def record(self, metric, global_step=None): self._agg[k] = item def clear(self): + """Clears the recorded metrics.""" self._agg.clear() self._cache.clear() def flush(self): + """Clears the cache of recorded metrics.""" self._cache.clear() class AggItem: + """ + A class that aggregates a sequence of values according to a specified strategy. + + Attributes: + stg (str): A string that specifies the strategy to be used for aggregation. + _last (int): The last value added to the aggregation. + acc (int): The accumulated value after aggregation. + c (int): The count of values added to the aggregation. + """ + def __init__(self, stg): self.stg = stg self._last = 0 @@ -97,6 +132,12 @@ def __init__(self, stg): @property def res(self): + """ + Computes the result of the aggregation. + + Returns: + int: The result of the aggregation according to the specified strategy. + """ if self.stg == 'mean': return self.acc / self.c @@ -110,9 +151,21 @@ def res(self): @property def last(self): + """ + Returns the last value added to the aggregation by `update`. + + Returns: + int: The last value added to the aggregation. + """ return self._last def update(self, val): + """ + Updates the aggregation with a new value. + + Args: + val (int): The new value to add to the aggregation. + """ if self.stg == 'last': self.acc = val elif self.stg == 'min': diff --git a/src/lumo/core/tree.py b/src/lumo/core/tree.py index 93255d4..5131349 100644 --- a/src/lumo/core/tree.py +++ b/src/lumo/core/tree.py @@ -1,14 +1,24 @@ -""" - -""" -import queue from collections import defaultdict class tree(dict): - """Implementation of perl's autovivification feature.""" + """Implements Perl's autovivification feature. + + This class extends Python's built-in dict class to allow for automatic creation of nested dictionaries + on access of non-existent keys. It accomplishes this by overriding the __getitem__ method to recursively + create new nested trees on access of a non-existent key. Additionally, the walk method is provided to + allow for iterating over all keys and values in the tree. + """ def __getitem__(self, item): + """Override __getitem__ to automatically create new trees on access of non-existent keys. + + Args: + item: The key being accessed. + + Returns: + The value of the key. + """ try: return dict.__getitem__(self, item) except KeyError: @@ -16,6 +26,14 @@ def __getitem__(self, item): return value def walk(self): + """Iterate over all keys and values in the tree. + + This method yields a tuple containing the current key and value, and recursively calls itself + on any nested tree values. + + Yields: + A tuple containing the current key and value. + """ for k, v in self.items(): yield k, v if isinstance(v, tree): @@ -24,6 +42,16 @@ def walk(self): class Node: + """Represents a node in the Forest. + + Attributes: + HEAD (int): A class variable indicating the head node. + MID (int): A class variable indicating a mid node. + TAIL (int): A class variable indicating the tail node. + value (any): The value held by the node. + link (list): A list of the node's adjacent nodes. + stage (int): The position of the node in the linked list. + """ HEAD = 0 MID = 1 TAIL = 2 @@ -35,25 +63,49 @@ def __init__(self): @property def is_head(self): + """Returns True if the node is a head node, else False.""" return self.stage == self.HEAD @property def is_mid(self): + """Returns True if the node is a mid node, else False.""" return self.stage == self.MID @property def is_tail(self): + """Returns True if the node is a tail node, else False.""" return self.stage == self.TAIL def set_stage(self, stage): + """Sets the stage attribute of the node to the specified value. + + Args: + stage (int): An integer indicating the position of the node in the linked list. + + Returns: + Node: The node with the updated stage attribute. + """ self.stage = stage return self def set_value(self, val): + """Sets the value attribute of the node to the specified value. + + Args: + val (any): The value to set the node's value attribute to. + + Returns: + Node: The node with the updated value attribute. + """ self.value = val return self def add_link(self, y): + """Adds the specified node to the list of adjacent nodes. + + Args: + y (Node): The node to add to the list of adjacent nodes. + """ self.link.append(y) def __repr__(self): @@ -61,20 +113,66 @@ def __repr__(self): class Forest: + """ + Represents a directed acyclic graph (DAG) where nodes are categorized as head, mid or tail. + + Attributes: + dic (defaultdict): A dictionary to store the nodes of the graph. + order (list): A list to maintain the order of the nodes. + tail (set): A set to store the tail nodes of the graph. + + """ + def __init__(self): self.dic = defaultdict(Node) self.order = [] self.tail = set() def add_head(self, x, val=None): + """ + Adds a new head node to the graph with the given value. + + Args: + x: The node to be added. + val: The value associated with the node. Defaults to None. + + Returns: + The updated Forest object. + + """ self.dic[x].set_value(val).set_stage(Node.HEAD) self.order.append(x) return self def check_node_type(self, x): + """ + Checks if a node is already present in the graph. + + Args: + x: The node to be checked. + + Returns: + True if the node is already present, False otherwise. + + """ return x in self.dic def add_link(self, x, y, y_val=None): + """ + Adds a new mid node to the graph and links it with an existing head node. + + Args: + x: The head node to be linked with the new mid node. + y: The new mid node to be added. + y_val: The value associated with the new mid node. Defaults to None. + + Returns: + The updated Forest object. + + Raises: + AssertionError: If the head node is not already present in the graph or the mid node is already present. + + """ assert x in self.dic, f'x must already existed in graph, has {self.order}, got {x}' assert y not in self.dic, f'y must be a new node in graph, has {self.order}, got {y}' self.dic[x].add_link(y) @@ -83,6 +181,21 @@ def add_link(self, x, y, y_val=None): return self def add_tail(self, x, y, y_val=None): + """ + Adds a new tail node to the graph and links it with an existing head or mid node. + + Args: + x: The node to be linked with the new tail node. + y: The new tail node to be added. + y_val: The value associated with the new tail node. Defaults to None. + + Returns: + The updated Forest object. + + Raises: + AssertionError: If the head or mid node is not already present in the graph or the tail node is already present. + + """ assert x in self.dic, f'x must already existed in graph, has {self.order}, got {x}' assert y not in self.dic, f'y must be a new node in graph, has {self.order}, got {y}' self.dic[x].add_link(y) @@ -92,6 +205,13 @@ def add_tail(self, x, y, y_val=None): return self def __iter__(self): + """ + Returns an iterator that iterates over the nodes in the graph in a Breadth-First-Search (BFS) order. + + Returns: + An iterator that yields a tuple with the node and its corresponding value. + + """ stack = [] mem = set() From f9af62e515f7e81f1034847c855b796972c6e4ec Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 16:17:52 +0800 Subject: [PATCH 09/99] Remove unused file --- src/lumo/core/README.md | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 src/lumo/core/README.md diff --git a/src/lumo/core/README.md b/src/lumo/core/README.md deleted file mode 100644 index 9e42711..0000000 --- a/src/lumo/core/README.md +++ /dev/null @@ -1,10 +0,0 @@ - - - 数据结构 - - Meter,指标记录 - - Params,参数管理 - - Scheduler - - Logger,日志管理 - - 全局变量、proc - - 线程管理 - - 装饰器管理 - - device 管理 \ No newline at end of file From 6296ff628d30bda779c1f16b0394b5a89fb03e43 Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 16:18:10 +0800 Subject: [PATCH 10/99] Fix dreprecated usage --- src/lumo/trainer/callbacks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lumo/trainer/callbacks.py b/src/lumo/trainer/callbacks.py index e6c7a0b..42d45e8 100644 --- a/src/lumo/trainer/callbacks.py +++ b/src/lumo/trainer/callbacks.py @@ -519,7 +519,7 @@ def __init__(self, per_epoch=50): def on_train_epoch_end(self, trainer: Trainer, func, params: ParamsType, record: Optional[Record], *args, **kwargs): - meter = record.avg() + meter = record.agg() if trainer.eidx % self.per_epoch == 0 and trainer.eidx > 0: trainer.save_checkpoint(meta_info=Meter.wrap_result(meter)) @@ -664,15 +664,15 @@ def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: def on_train_epoch_end(self, trainer: Trainer, func, params: ParamsType, record: Record, *args, **kwargs): super().on_train_epoch_end(trainer, func, params, record, *args, **kwargs) - self.log(record.avg(), step=trainer.global_steps, namespace='train.epoch') + self.log(record.agg(), step=trainer.global_steps, namespace='train.epoch') def on_test_end(self, trainer: Trainer, func, params: ParamsType, record: Record, *args, **kwargs): super().on_test_end(trainer, func, params, record, *args, **kwargs) - self.log(record.avg(), step=trainer.global_steps, namespace='test') + self.log(record.agg(), step=trainer.global_steps, namespace='test') def on_eval_end(self, trainer: Trainer, func, params: ParamsType, record: Record, *args, **kwargs): super().on_eval_end(trainer, func, params, record, *args, **kwargs) - self.log(record.avg(), step=trainer.global_steps, namespace='evaluate') + self.log(record.agg(), step=trainer.global_steps, namespace='evaluate') class WandbCallback(RecordCallback): From 8ce9203ad4d30aeb9cc98b5ff61373dd07f376df Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 16:18:29 +0800 Subject: [PATCH 11/99] Move for useless --- src/lumo/{decorators => sketch}/map_extract.py | 0 src/lumo/{utils => sketch}/timer.py | 17 +---------------- 2 files changed, 1 insertion(+), 16 deletions(-) rename src/lumo/{decorators => sketch}/map_extract.py (100%) rename src/lumo/{utils => sketch}/timer.py (82%) diff --git a/src/lumo/decorators/map_extract.py b/src/lumo/sketch/map_extract.py similarity index 100% rename from src/lumo/decorators/map_extract.py rename to src/lumo/sketch/map_extract.py diff --git a/src/lumo/utils/timer.py b/src/lumo/sketch/timer.py similarity index 82% rename from src/lumo/utils/timer.py rename to src/lumo/sketch/timer.py index e509aa5..ffe0a1b 100644 --- a/src/lumo/utils/timer.py +++ b/src/lumo/sketch/timer.py @@ -6,22 +6,7 @@ import warnings from collections import OrderedDict -from .fmt import strftime - - -def format_second(sec: int) -> str: - """convert seconds from int to string""" - sec, ms = divmod(sec, 1) - if sec > 60: - min, sec = divmod(sec, 60) - if min > 60: - hour, min = divmod(min, 60) - fmt = "{}h{}m{}s".format(hour, min, int(sec)) - else: - fmt = "{}m{}s".format(min, int(sec)) - else: - fmt = "{}s".format(int(sec)) - return fmt +from lumo.utils.fmt import strftime class Timer: From 65d7e5b56275d9d51b9fd4d7f5a50826733bd51a Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 16:18:50 +0800 Subject: [PATCH 12/99] Replace Filelock by stable library --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2329b93..d7286f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ dbrecord packaging pandas hydra-core -tensorboard \ No newline at end of file +tensorboard +filelock \ No newline at end of file From 3277a17ad46aa211c03cfe8d940bcd2680dc47e2 Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 16:19:23 +0800 Subject: [PATCH 13/99] Fix Deprecated warning --- examples/quick_start.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/quick_start.py b/examples/quick_start.py index c77e66d..3a751cd 100644 --- a/examples/quick_start.py +++ b/examples/quick_start.py @@ -18,10 +18,10 @@ def add(x): db = ( DatasetBuilder() - .add_input("xs", torch.arange(-2500, 2500, dtype=torch.float).unsqueeze(1)) - .add_input("ys", torch.arange(-2500, 2500, dtype=torch.float), transform=add) - .add_output("xs", "xs") - .add_output("ys", "ys") + .add_input("xs", torch.arange(-2500, 2500, dtype=torch.float).unsqueeze(1)) + .add_input("ys", torch.arange(-2500, 2500, dtype=torch.float), transform=add) + .add_output("xs", "xs") + .add_output("ys", "ys") ) loader = db.DataLoader(batch_size=params.batch_size, shuffle=True) @@ -50,12 +50,11 @@ def add(x): meter.mean.loss = loss meter.sum.c = 1 record.record(meter) - logger.inline(record.avg()) + logger.inline(record.agg()) logger.newline() test_data = -torch.arange(1000, dtype=torch.float).unsqueeze(1) res = model(test_data) - test_ys = test_data * 2 print((res.long() - test_ys.long()).sum()) From e4ce4b435d12e9eb6565cb59bc3f270da29c7479 Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 17:14:42 +0800 Subject: [PATCH 14/99] Append tests (coverage rate up to 73%) --- src/lumo/core/attr.py | 2 - src/lumo/core/interp.py | 6 --- src/lumo/core/record.py | 71 ++--------------------------------- src/lumo/exp/exphook.py | 13 ++++--- src/lumo/proc/config.py | 4 +- src/lumo/trainer/saver.py | 7 ++-- tests/core/test_meter.py | 2 + tests/trainer/test_saver.py | 5 ++- tests/trainer/test_trainer.py | 2 + 9 files changed, 24 insertions(+), 88 deletions(-) diff --git a/src/lumo/core/attr.py b/src/lumo/core/attr.py index e321fe7..bf6c8c5 100644 --- a/src/lumo/core/attr.py +++ b/src/lumo/core/attr.py @@ -58,12 +58,10 @@ def safe_update_dict(src: dict, kwargs: dict, assert_type=True): old_v = get_item_iterative(src, ks) if old_v is None or isinstance(old_v, type(v)): set_item_iterative(src, ks, v) - # print(ks, v) else: raise TypeError(ks, type(old_v), type(v)) except KeyError: set_item_iterative(src, ks, v) - # print(ks, v) return src diff --git a/src/lumo/core/interp.py b/src/lumo/core/interp.py index 100f445..11e4a4f 100644 --- a/src/lumo/core/interp.py +++ b/src/lumo/core/interp.py @@ -44,11 +44,6 @@ class Interpolate(BaseParams): ratio 变化为从 1 - 0 """ - def toggle_constant(self, toggle=True): - """fix the schedule as the first value""" - self.constant = toggle - return self - @classmethod def interp(self, *args, **kwargs): raise NotImplementedError() @@ -234,7 +229,6 @@ def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs): return start * cos_ratio + end * (1 - cos_ratio) - class Linear(ABCContinuous): """linear schedule""" diff --git a/src/lumo/core/record.py b/src/lumo/core/record.py index e8cd080..c29f2e3 100644 --- a/src/lumo/core/record.py +++ b/src/lumo/core/record.py @@ -2,7 +2,7 @@ from numbers import Number from . import Attr -from .meter import Meter +from .meter import Meter, ReduceItem import torch import numpy as np from typing import NewType, Mapping, Union, Sequence, Dict @@ -55,7 +55,7 @@ def __init__(self, **kwargs): self._prop = {} self._prop.update(kwargs) self._cache = [] - self._agg = OrderedDict() # type:Dict[str,AggItem] + self._agg = OrderedDict() # type:Dict[str,ReduceItem] def avg(self) -> Attr: """DEPRECATED: Computes the average value of the recorded metrics.""" @@ -99,7 +99,7 @@ def record(self, metric, global_step=None): stg = agg.get(k, 'last') item = self._agg.get(k, None) if item is None: - item = AggItem(stg) + item = ReduceItem(stg) item.update(v) self._agg[k] = item @@ -111,68 +111,3 @@ def clear(self): def flush(self): """Clears the cache of recorded metrics.""" self._cache.clear() - - -class AggItem: - """ - A class that aggregates a sequence of values according to a specified strategy. - - Attributes: - stg (str): A string that specifies the strategy to be used for aggregation. - _last (int): The last value added to the aggregation. - acc (int): The accumulated value after aggregation. - c (int): The count of values added to the aggregation. - """ - - def __init__(self, stg): - self.stg = stg - self._last = 0 - self.acc = 0 - self.c = 0 - - @property - def res(self): - """ - Computes the result of the aggregation. - - Returns: - int: The result of the aggregation according to the specified strategy. - """ - if self.stg == 'mean': - return self.acc / self.c - - if self.stg in {'min', 'max', 'last'}: - return self.acc - - if self.stg == 'sum': - return self.acc - - return self.acc - - @property - def last(self): - """ - Returns the last value added to the aggregation by `update`. - - Returns: - int: The last value added to the aggregation. - """ - return self._last - - def update(self, val): - """ - Updates the aggregation with a new value. - - Args: - val (int): The new value to add to the aggregation. - """ - if self.stg == 'last': - self.acc = val - elif self.stg == 'min': - self.acc = min(self.acc, val) - elif self.stg == 'max': - self.acc = max(self.acc, val) - elif self.stg in {'mean', 'sum'}: - self.acc += val - self.c += 1 - self._last = val diff --git a/src/lumo/exp/exphook.py b/src/lumo/exp/exphook.py index 669385c..beff061 100644 --- a/src/lumo/exp/exphook.py +++ b/src/lumo/exp/exphook.py @@ -55,10 +55,10 @@ def on_start(self, exp: Experiment, *args, **kwargs): os.chmod(fn, st.st_mode | stat.S_IEXEC) -class PathRecord(ExpHook): - - def on_newpath(self, exp: Experiment, *args, **kwargs): - super().on_newpath(exp, *args, **kwargs) +# class PathRecord(ExpHook): +# +# def on_newpath(self, exp: Experiment, *args, **kwargs): +# super().on_newpath(exp, *args, **kwargs) class Diary(ExpHook): @@ -98,7 +98,7 @@ def _create_agent(self, exp: Experiment): ] subprocess.Popen(' '.join(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, - start_new_session=True) + start_new_session=True, cwd=os.getcwd(), env=os.environ) def on_start(self, exp: Experiment, *args, **kwargs): super().on_start(exp) @@ -178,7 +178,8 @@ def on_start(self, exp: Experiment, *args, **kwargs): class LockFile(ExpHook): def on_start(self, exp: Experiment, *args, **kwargs): - exp.dump_info('lock', get_lock('torch', 'numpy', + exp.dump_info('lock', get_lock('lumo', + 'numpy', 'joblib', 'psutil', 'decorator', diff --git a/src/lumo/proc/config.py b/src/lumo/proc/config.py index e3b7439..e5b9164 100644 --- a/src/lumo/proc/config.py +++ b/src/lumo/proc/config.py @@ -59,9 +59,9 @@ def debug_mode(base_dir=None, disable_git=True): glob['cache_dir'] = tempfile.mkdtemp(dir=base_dir) glob['blob_root'] = tempfile.mkdtemp(dir=base_dir) glob['metric_root'] = tempfile.mkdtemp(dir=base_dir) - glob['HOOK_LOCKFILE'] = False + # glob['HOOK_LOCKFILE'] = False glob['HOOK_LASTCMD_DIR'] = tempfile.mkdtemp(dir=base_dir) - glob['HOOK_RECORDABORT'] = False + # glob['HOOK_RECORDABORT'] = False glob['HOOK_TIMEMONITOR'] = False if disable_git: diff --git a/src/lumo/trainer/saver.py b/src/lumo/trainer/saver.py index e1215e9..3a62cff 100644 --- a/src/lumo/trainer/saver.py +++ b/src/lumo/trainer/saver.py @@ -20,6 +20,7 @@ def __getitem__(self, item): return self.meta_info raise IndexError(item) + class Saver: """ Write state_dict into test dirs, record save log into /.lumo/save..log @@ -86,13 +87,13 @@ def dump_state_dict(self, obj, fn, meta_info: Union[str, dict] = None): Returns: saved filepath, None if something went wrong. """ - res = io.dump_state_dict(obj, fn) - if res and meta_info is not None: + io.dump_state_dict(obj, fn) + if meta_info is not None: if isinstance(meta_info, str): meta_info = {'msg': meta_info} json_fn = f"{fn}.json" io.dump_json(meta_info, json_fn) - return res + return fn def load_state_dict(self, fn: str, with_meta=False, map_location='cpu') -> state_dict_tuple: """ diff --git a/tests/core/test_meter.py b/tests/core/test_meter.py index be5f8c1..9f42f72 100644 --- a/tests/core/test_meter.py +++ b/tests/core/test_meter.py @@ -1,3 +1,5 @@ +import pytest + from lumo.core.meter import ReduceItem diff --git a/tests/trainer/test_saver.py b/tests/trainer/test_saver.py index f7cd220..766abcc 100644 --- a/tests/trainer/test_saver.py +++ b/tests/trainer/test_saver.py @@ -1,3 +1,5 @@ +import os + from lumo.trainer.saver import Saver import time import shutil @@ -17,10 +19,11 @@ def test_save_load(): else: saver.save_checkpoint(i, {'step': i}, {'meta_step': i}, max_keep=max_keep, is_best=(i % 5) == 0) + print(os.listdir(save_root)) assert len(saver.list_models()) == (epoch) state = saver.load_model(best_if_exist=True, with_meta=True) - print(state) + # print(state) assert state[0]['step'] == 5 and state[1]['meta_step'] == 5 state = saver.load_model(2) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 35415bb..0af44cd 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -97,6 +97,8 @@ def train_epoch(self, loader: DataLoaderType, params: ParamsType = None, limit_s limit_global_steps=None) -> Record: assert self.context == 'train_epoch' assert self.contexts[-2] == 'train' + if params.get('raise_exp', False): + raise ValueError('raised by test') return super().train_epoch(loader, params, limit_step, limit_global_steps) def remove_callback(self, cur): From c7ce67259780642bad515ba784765f836922bb1d Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 5 Mar 2023 17:14:47 +0800 Subject: [PATCH 15/99] Append tests (coverage rate up to 73%) --- pyproject.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 0322e2f..26c377c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ omit = [ 'src/lumo/cli/*', 'src/lumo/vis/*', 'src/lumo/decorators/*', + 'src/lumo/exp/agent.py', 'src/lumo/analyse/*', 'src/lumo/sketch/*', 'src/lumo/core/record_backend/*', @@ -73,6 +74,12 @@ exclude_lines = [ "if 0:", "if __name__ == .__main__.:", "if TYPE_CHECKING:", + # some tricky object that can not well tested. + "except ImportError", + "class RecordAbort", + "pass", + "def summary_experiment", + "def plot", ] [tool.coverage.html] From cc270f122d8a765a73a28204bb4a5dfdb3a852ff Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 09:30:20 +0800 Subject: [PATCH 16/99] Remove timemonitor exphook because of using progress recording in Experiment --- src/lumo/exp/experiment.py | 12 +++++----- src/lumo/exp/exphook.py | 46 +++++++++++++++++++------------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 6738f69..1a90388 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -4,7 +4,7 @@ import time import traceback from pathlib import Path -from typing import Union +from typing import Union, Any from lumo.decorators.process import call_on_main_process_wrap from lumo.proc import glob @@ -169,12 +169,12 @@ def test_branch(self): return checkdir(val) def dump_progress(self, ratio: float, update_from=None): - res = {'ratio': ratio} + res = {'ratio': max(min(ratio, 1), 0)} if update_from is None: res['update_from'] = update_from self.dump_info('progress', res, append=True) - def dump_info(self, key: str, info: dict, append=False, info_dir='info', set_prop=True): + def dump_info(self, key: str, info: Any, append=False, info_dir='info', set_prop=True): fn = self.test_file(f'{key}.json', info_dir) if append: old_info = self.load_info(key, info_dir=info_dir) @@ -336,7 +336,7 @@ def start(self): if self.get_prop('start', False): return self.initial() - self.set_prop('start', True) + self.dump_info('start', True) for hook in self._hooks.values(): # type: ExpHook hook.on_start(self) return self @@ -348,7 +348,7 @@ def end(self, end_code=0, *args, **extra): if self.get_prop('end', False): return self.dump_progress(1) - self.set_prop('end', True) + self.dump_info('end', True) for hook in self._hooks.values(): # type: ExpHook hook.on_end(self, end_code=end_code, *args, **extra) return self @@ -475,5 +475,5 @@ def __init__(self, exp_name: str, root=None): self.set_hook(exphook.GitCommit()) self.set_hook(exphook.RecordAbort()) self.set_hook(exphook.Diary()) - self.set_hook(exphook.TimeMonitor()) + # self.set_hook(exphook.TimeMonitor()) self.set_hook(exphook.FinalReport()) diff --git a/src/lumo/exp/exphook.py b/src/lumo/exp/exphook.py index beff061..41969e3 100644 --- a/src/lumo/exp/exphook.py +++ b/src/lumo/exp/exphook.py @@ -84,29 +84,29 @@ def exc_end(self, exc_type, exc_val, exc_tb): ) -class TimeMonitor(ExpHook): - def _create_agent(self, exp: Experiment): - from lumo.exp import agent - cmd = [ - sys.executable, '-m', agent.__spec__.name, - f"--state_key=state", - f"--pid={os.getpid()}", - f"--exp_name={exp.exp_name}", - f"--test_name={exp.test_name}", - f"--test_root={exp.test_root}", - # f"--params={sys.argv}" # TODO add sys.argv - ] - subprocess.Popen(' '.join(cmd), - stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, - start_new_session=True, cwd=os.getcwd(), env=os.environ) - - def on_start(self, exp: Experiment, *args, **kwargs): - super().on_start(exp) - self._create_agent(exp) - exp.dump_info('state', { - 'start': strftime(), - 'end': strftime() - }) +# class TimeMonitor(ExpHook): +# def _create_agent(self, exp: Experiment): +# from lumo.exp import agent +# cmd = [ +# sys.executable, '-m', agent.__spec__.name, +# f"--state_key=state", +# f"--pid={os.getpid()}", +# f"--exp_name={exp.exp_name}", +# f"--test_name={exp.test_name}", +# f"--test_root={exp.test_root}", +# # f"--params={sys.argv}" +# ] +# subprocess.Popen(' '.join(cmd), +# stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, +# start_new_session=True, cwd=os.getcwd(), env=os.environ) +# +# def on_start(self, exp: Experiment, *args, **kwargs): +# super().on_start(exp) +# self._create_agent(exp) +# exp.dump_info('state', { +# 'start': strftime(), +# 'end': strftime() +# }) class GitCommit(ExpHook): From d6d4fa00f7d869967e09739e191391134e92e5cc Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 09:30:32 +0800 Subject: [PATCH 17/99] Add PowerDecay2 to __all__ --- src/lumo/core/interp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lumo/core/interp.py b/src/lumo/core/interp.py index 11e4a4f..ad686bc 100644 --- a/src/lumo/core/interp.py +++ b/src/lumo/core/interp.py @@ -36,6 +36,7 @@ 'PeriodTriangle', 'PeriodLinear', 'PowerDecay', + 'PowerDecay2', 'InterpolateList', ] @@ -401,7 +402,7 @@ def interp(cls, cur, start=0., gammas=None, schedules=None, *args, **kwargs): return res def __call__(self, cur): - self.interp(cur, start=self.start, gammas=self.gammas, schedules=self.schedules) + return self.interp(cur, start=self.start, gammas=self.gammas, schedules=self.schedules) class InterpolateList(Interpolate): From c098780eb5a41da6e2c93e9d8b2acaacb1aaaae4 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 09:30:47 +0800 Subject: [PATCH 18/99] Default Ordered Meter --- src/lumo/core/meter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lumo/core/meter.py b/src/lumo/core/meter.py index a8ef163..87c63a8 100644 --- a/src/lumo/core/meter.py +++ b/src/lumo/core/meter.py @@ -15,16 +15,16 @@ class Meter: def __init__(self): self._prop = {} - self._rec = {} + self._rec = OrderedDict() self._avg = {} def sorted(self) -> 'Meter': m = Meter() - m._rec = self._rec - m._avg = OrderedDict() + m._prop = self._prop for k in sorted(self._avg.keys()): m._avg[k] = self._avg[k] + m._rec[k] = self._rec[k] return m def todict(self): From e0a5eea465aaa147990efd0901dc1b45aeedf722 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 09:32:05 +0800 Subject: [PATCH 19/99] Add MemoryBank and StorageBank --- src/lumo/contrib/module/memoty_bank.py | 98 ++++++++++++++++++++++++-- 1 file changed, 91 insertions(+), 7 deletions(-) diff --git a/src/lumo/contrib/module/memoty_bank.py b/src/lumo/contrib/module/memoty_bank.py index f3da0c5..d080c87 100644 --- a/src/lumo/contrib/module/memoty_bank.py +++ b/src/lumo/contrib/module/memoty_bank.py @@ -1,25 +1,56 @@ -import torch from accelerate.utils import gather +from lumo.proc.dist import is_dist +import torch.distributed from torch import nn +import torch + + +class StorageBank(nn.Module): + def __init__(self): + super().__init__() + self.offset = 0 + self.sizes = {} + + def register(self, name, dim, k, dtype=None): + if dim <= 0: + bank = (torch.ones(k) + float('inf')).to(dtype=dtype) + else: + bank = (torch.ones(k, dim) + float('inf')).to(dtype=dtype) + + self.register_buffer(name, bank) + self.sizes[name] = bank.shape[0] + + def __getitem__(self, item): + return self.__getattr__(item) + + @torch.no_grad() + def scatter(self, name, value, index): + value = value.detach() + value = gather(value) + if isinstance(index, torch.Tensor): + index = gather(index) + self[name][index] = value class MemoryBank(nn.Module): - def __init__(self, k): + def __init__(self): super().__init__() self.offset = 0 - self.k = k self.offsets = {} self.sizes = {} - def register(self, name, dim): + def register(self, name, dim, k, dtype=None): if dim <= 0: - bank = torch.rand(self.k) + bank = torch.rand(k, dtype=dtype) else: - bank = torch.rand(self.k, dim) + bank = torch.rand(k, dim, dtype=dtype) self.register_buffer(name, bank) self.offsets[name] = 0 - self.sizes[name] = bank.shape[0] + self.sizes[name] = k + + def __setitem__(self, key, value): + self.__setattr__(key, value) def __getitem__(self, item): return self.__getattr__(item) @@ -39,3 +70,56 @@ def push(self, name, value): value = value[:batch_size] self[name][ptr:ptr + batch_size] = value self.offsets[name] = (ptr + batch_size) % k + + +@torch.no_grad() +def batch_shuffle_ddp(x): + """ + Batch shuffle, for making use of BatchNorm. + *** Only support DistributedDataParallel (DDP) model. *** + """ + if not is_dist(): + return x, torch.arange(len(x)) + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = torch.randperm(batch_size_all).cuda() + + # broadcast to all gpus + torch.distributed.broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this], idx_unshuffle + + +@torch.no_grad() +def batch_unshuffle_ddp(x, idx_unshuffle): + """ + Undo batch shuffle. + *** Only support DistributedDataParallel (DDP) model. *** + """ + if not is_dist(): + return x + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # restored index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this] From cb9005f1a880da9a9c31f7118fc37268bc08a6d3 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 10:43:44 +0800 Subject: [PATCH 20/99] Support iter DatasetBuilder --- src/lumo/data/builder.py | 43 +++++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/src/lumo/data/builder.py b/src/lumo/data/builder.py index 32f5722..7893a07 100644 --- a/src/lumo/data/builder.py +++ b/src/lumo/data/builder.py @@ -1,6 +1,7 @@ import copy import warnings from functools import partial +from itertools import cycle from pprint import pformat from copy import copy as builtin_copy from typing import Callable, NewType, Dict, Any, Iterable, Sequence @@ -129,27 +130,49 @@ def __iter__(self): for key, outkeys in self._outs.items(): if key == '::idx::': - warnings.warn(f'iter does not support idx, will skip {outkeys}') - continue + source = enumerate(cycle(range(1))) + else: + source = self._data[key] + # for outkey in outkeys: + ipt_transform = self._transforms.get(key, None) - source = self._data[key] - for outkey in outkeys: - self._iter_cache[outkey] = iter(source) + self._iter_cache[key] = iter(source), ipt_transform return self def __next__(self): if len(self._iter_cache) == 0: raise StopIteration() try: - outputs = {k: next(v) for k, v in self._iter_cache.items()} - if self.mode != 'zip': + inputs_sample = {} + for (key, (ipt_iter, ipt_transform)) in (self._iter_cache.items()): + if key == '::idx::': + ipt = next(ipt_iter)[0] + else: + ipt = next(ipt_iter) + if ipt_transform is not None: + ipt = ipt_transform(ipt) + inputs_sample[key] = ipt + + outputs = {} + for key, outkeys in self._outs.items(): + for outkey in outkeys: + opt = inputs_sample[key] + opt_transform = self._transforms.get(f'::{outkey}', None) + if opt_transform is not None: + opt = opt_transform(opt) + outputs[outkey] = opt + + if self.mode == 'chain': outputs = [outputs[outkey] for outkey in self._outkeys] + return outputs except StopIteration as e: self._iter_cache.clear() raise e def __getitem__(self, index): + if not self.sized: + raise TypeError('Source is not sizable. Use iterative method.') index = self.map_index(index) @@ -373,7 +396,8 @@ def add_input(self, name: str, source, transform: SingleValueTransform = None): assert name not in self._data, f'Source name {name} duplicated.' self._check_source(name, source) self._data[name] = source - self._transforms[name] = transform + self.set_input_transform(name, transform) + # self._transforms[name] = transform return self def add_input_transform(self, name: str, transform: SingleValueTransform = None): @@ -399,7 +423,8 @@ def add_output(self, name: str, outkey: str, transform: SingleValueTransform = N outkeys.append(outkey) self._outkeys.append(outkey) - self._transforms[f'::{outkey}'] = transform + self.set_output_transform(outkey, transform) + # self._transforms[f'::{outkey}'] = transform return self def add_output_transform(self, outkey: str, transform: SingleValueTransform = None): From 34f29829dbb9d033be5c4bddc0afb9548c05e0f4 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 10:44:10 +0800 Subject: [PATCH 21/99] Add loader type check when get dataset of dataloader --- src/lumo/data/datamodule.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/lumo/data/datamodule.py b/src/lumo/data/datamodule.py index b660296..eae4d83 100644 --- a/src/lumo/data/datamodule.py +++ b/src/lumo/data/datamodule.py @@ -45,7 +45,7 @@ def _parse_dataset(loader): return loader.dataset elif isinstance(loader, DataLoaderSide): return loader.dataset - return None + raise NotImplementedError(type(loader)) @property def train_dataset(self): @@ -96,7 +96,6 @@ def __getitem__(self, key): @overload def regist_dataloader(self, train=None, test=None, val=None, **kwargs): - """ Registers the given dataloaders under the given keys. @@ -106,7 +105,6 @@ def regist_dataloader(self, train=None, test=None, val=None, **kwargs): val: A DataLoaderType object for the validation set. **kwargs: A DataLoaderType object for other stage """ - ... def regist_dataloader(self, **kwargs: dict): for k, v in kwargs.items(): @@ -120,7 +118,7 @@ def regist_dataloader_with_stage(self, stage: TrainStage, dl: DataLoaderType): stage: A TrainStage object. dl: A DataLoaderType object. """ - self.prop[stage.value] = dl + self.regist_dataloader(**{stage.value: dl}) def idataloader(self, params: ParamsType = None, stage: TrainStage = None): """ From b689a8b97c8b8588b1721d66b65b9107a4a615c5 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 10:44:31 +0800 Subject: [PATCH 22/99] Reformat code --- src/lumo/trainer/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lumo/trainer/base.py b/src/lumo/trainer/base.py index cac7662..02687ce 100644 --- a/src/lumo/trainer/base.py +++ b/src/lumo/trainer/base.py @@ -214,7 +214,7 @@ def __setattr__(self, name, value): elif callable(getattr(value, "state_dict", None)) and callable(getattr(value, "load_state_dict", None)): type_name = 'others' else: - super().__setattr__(name, value) + # super().__setattr__(name, value) return # if name in self.__dict__: TODO workaround multi-gpu error: Expected to mark a variable ready only once From 98c10b8b3b461dcb8338ecb6f7e15908592c452d Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 10:44:44 +0800 Subject: [PATCH 23/99] Fix devices retrieval bug --- src/lumo/trainer/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lumo/trainer/trainer.py b/src/lumo/trainer/trainer.py index c465ad1..3f7c6f8 100644 --- a/src/lumo/trainer/trainer.py +++ b/src/lumo/trainer/trainer.py @@ -266,7 +266,8 @@ def trainer_state(self): @property def devices(self) -> Dict[str, torch.device]: - return self._state_dicts['devices'] + # return self._state_dicts['devices'] + return {key: self[key] for key in self._state_dicts['devices']} @property def model_dict(self) -> Dict[str, nn.Module]: @@ -705,6 +706,7 @@ def state_dict(self): 'others': self.other_state_dict(wrap=False), 'thtensor': self.torch_tensor, 'nptensor': self.numpy_tensor, + 'devices': self.devices, } return res From 47b3180f95c0df9d22e1a6e5b19e0d621553119e Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 10:54:54 +0800 Subject: [PATCH 24/99] Fix load_function --- src/lumo/trainer/trainer.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/lumo/trainer/trainer.py b/src/lumo/trainer/trainer.py index 3f7c6f8..4fae384 100644 --- a/src/lumo/trainer/trainer.py +++ b/src/lumo/trainer/trainer.py @@ -310,10 +310,12 @@ def val_dataloader(self) -> Optional[DataLoaderType]: def device(self): return self.accelerate.device - def _load_fun_state_dict(self, src: dict, tgt: dict): - for k, v in tgt.items(): - if k in src: - v.load_state_dict(src[k]) + def _load_fun_state_dict(self, src: dict): + for k, v in src.items(): + if self._rev_index.get(k, None) is not None: + self[k].load_state_dict(v) + # if k in src: + # v.load_state_dict(src[k]) def regist_dataloader(self, dataloader: DataLoader, stage: TrainStage): self.datamodule.regist_dataloader_with_stage(stage, dataloader) @@ -352,8 +354,7 @@ def save_state_dict(self, name='latest.pth', dirpath=None, only_main=True): pre, ext = os.path.splitext(name) name = f'{pre}-{self.local_rank}{ext}' if dirpath is None: - - fn = self.exp.state_dict_dir + fn = os.path.join(self.exp.state_dict_dir, name) else: fn = os.path.join(dirpath, name) torch.save(self.state_dict(), fn) @@ -366,9 +367,10 @@ def load_state_dict(self, state_dict: dict): for k, v in state_dict.items(): if k in _sub: - self._load_fun_state_dict(v, self._state_dicts[k]) + self._load_fun_state_dict(v) else: - self._state_dicts[k] = v + for kk, vv in v.items(): + self[kk] = vv return def to_device(self, item: Optional[Union[nn.Module, torch.Tensor, Sequence, Mapping]] = None, @@ -706,7 +708,7 @@ def state_dict(self): 'others': self.other_state_dict(wrap=False), 'thtensor': self.torch_tensor, 'nptensor': self.numpy_tensor, - 'devices': self.devices, + # 'devices': self.devices, } return res From fe4c3bc613fbc6fa907a14e6bf15e460843fc3f1 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 10:55:19 +0800 Subject: [PATCH 25/99] Coverage rate up to 81% powered by chatGPT --- tests/contrib/test_functional.py | 2 + tests/contrib/test_module.py | 18 ++++++ tests/core/test_interp.py | 72 ++++++++++++++++++++++ tests/core/test_meter.py | 49 ++++++++++++++- tests/proc/test_dist.py | 24 ++++++++ tests/proc/test_proc.py | 47 ++++++++++++++ tests/trainer/test_builder.py | 94 +++++++++++++++++++++++++--- tests/trainer/test_trainer.py | 40 +++++++++++- tests/utils/test_io.py | 101 +++++++++++++++++++++++++++++++ 9 files changed, 437 insertions(+), 10 deletions(-) create mode 100644 tests/proc/test_dist.py create mode 100644 tests/proc/test_proc.py create mode 100644 tests/utils/test_io.py diff --git a/tests/contrib/test_functional.py b/tests/contrib/test_functional.py index 0c79233..d40ae84 100644 --- a/tests/contrib/test_functional.py +++ b/tests/contrib/test_functional.py @@ -219,3 +219,5 @@ def test_contrastive_loss(): cs = InstanceLoss(4, 0.7, 'cpu') assert (cs.forward(a, b) - contrastive_loss(a, b, temperature=0.7, inbatch_neg=True)) ** 2 < 1e-10 + + \ No newline at end of file diff --git a/tests/contrib/test_module.py b/tests/contrib/test_module.py index e69de29..92e8c66 100644 --- a/tests/contrib/test_module.py +++ b/tests/contrib/test_module.py @@ -0,0 +1,18 @@ +from lumo.contrib.module.memoty_bank import MemoryBank +from accelerate import Accelerator +import torch +import pytest + + +def test_memory_bank(): + bank = MemoryBank() + bank.register('test', 32, 512) + res = [torch.rand(128, 32) for i in range(8)] + acce = Accelerator() + for r in res: + bank.push('test', r) + + assert (bank['test'] == torch.cat(res[4:], dim=0)).all() + bank.requires_grad_(False) + # assert (bank['test'][:128] == res[-1]).all() + # assert (bank['test'][128:] == torch.cat(res[1:4], dim=0)).all() diff --git a/tests/core/test_interp.py b/tests/core/test_interp.py index a1535c6..671338f 100644 --- a/tests/core/test_interp.py +++ b/tests/core/test_interp.py @@ -1,4 +1,5 @@ from lumo.core.interp import * +import numpy as np from torch.optim.sgd import SGD from torch import nn from lumo import BaseParams @@ -58,3 +59,74 @@ def test_scheduler_is_attr(): assert abs(phcos.get(9.99) - 1) < 1e-5 assert phcos.get(10) == 0 assert phcos.get(15) == 0.5 + + +def test_period_linear(): + start = 0 + end = 10 + period = 5 + left = 0 + constant = False + cur = 2 + + expected = 4.0 + result = PeriodLinear.interp(cur, start, end, left, period, constant) + assert result == expected + + +def test_power_decay(): + start = 1 + decay_steps = 5 + decay_rate = 0.5 + end = None + cur = 10 + + expected = 0.25 + power_decay = PowerDecay(start, decay_steps, decay_rate, end) + result = power_decay(cur) + assert result == expected + + +def test_power_decay2(): + start = 1 + schedules = [5, 10] + gammas = [0.5, 0.2] + cur = 12 + + expected = 0.1 + power_decay2 = PowerDecay2(start, schedules, gammas) + result = power_decay2(cur) + assert result == expected + + +# def test_ABCContinuous(): +# # Test ABCContinuous class +# abc = ABCContinuous(start=1, end=2, left=0, right=10) +# assert abc(0) == 1 +# assert abc(10) == 2 +# assert abc(5) == abc.interp(5, start=1, end=2, left=0, right=10) + + +def test_Exp(): + # Test Exp class + exp = Exp(start=1, end=2, left=0, right=10) + assert np.isclose(exp(0), 1) + assert np.isclose(exp(10), 2) + assert np.isclose(exp(5), 1.078716025, rtol=1e-5) + assert np.isclose(exp(8), 1.366531851) + assert np.isclose(exp(9.5), 1.7784638857) + + +def test_Log(): + # Test Log class + log = Log(start=1, end=2, left=0, right=10) + assert np.isclose(log(0), 1) + assert np.isclose(log(10), 2) + assert np.isclose(log(5), 1.921283974) + + +def test_Constant(): + # Test Constant class + const = Constant(value=0.5) + assert const(0) == 0.5 + assert const(10) == 0.5 diff --git a/tests/core/test_meter.py b/tests/core/test_meter.py index 9f42f72..aec47f0 100644 --- a/tests/core/test_meter.py +++ b/tests/core/test_meter.py @@ -1,6 +1,53 @@ +from collections import OrderedDict + import pytest -from lumo.core.meter import ReduceItem +from lumo.core.meter import ReduceItem, Meter + + +def test_meter(): + m = Meter() + m['loss'] = 0.5 + m['accuracy'] = 0.8 + + # test __getitem__ method + assert m['loss'] == 0.5 + + # test __setitem__ method + m['loss'] = 0.2 + assert m['loss'] == 0.2 + + # test __repr__ method + assert repr(m) == 'loss: 0.2 | accuracy: 0.8' + + # test keys method + assert set(m.keys()) == {'loss', 'accuracy'} + + # test todict method + assert m.todict() == {'loss': 0.2, 'accuracy': 0.8} + + # test sorted method + sorted_m = m.sorted() + assert isinstance(sorted_m, Meter) + assert set(sorted_m.keys()) == {'accuracy', 'loss'} + assert repr(sorted_m) == 'accuracy: 0.8 | loss: 0.2' + + # test update method + m.update({'loss': 0.1, 'precision': 0.9}) + assert set(m.keys()) == {'loss', 'accuracy', 'precision'} + assert m.todict() == {'loss': 0.1, 'accuracy': 0.8, 'precision': 0.9} + + # test from_dict method + m2 = Meter.from_dict(OrderedDict([('loss', 0.1), ('accuracy', 0.8), ('precision', 0.9)])) + assert set(m2.keys()) == {'loss', 'accuracy', 'precision'} + assert m2.todict() == {'loss': 0.1, 'accuracy': 0.8, 'precision': 0.9} + + # test scalar_items method + m3 = Meter() + m3['loss'] = 0.5 + m3['accuracy'] = '80%' + m3['precision'] = [0.9, 0.8] + assert set(m3.scalar_items()) == {('loss', 0.5), ('accuracy', '80%')} def test_avg_item_mean(): diff --git a/tests/proc/test_dist.py b/tests/proc/test_dist.py new file mode 100644 index 0000000..1cda926 --- /dev/null +++ b/tests/proc/test_dist.py @@ -0,0 +1,24 @@ +from lumo.proc.dist import * + + +def test_local_rank(monkeypatch): + monkeypatch.setenv('LOCAL_RANK', '0') + assert local_rank() == 0 + + monkeypatch.setenv('LOCAL_RANK', '1') + assert local_rank() == 1 + + +def test_world_size(monkeypatch): + monkeypatch.setenv('WORLD_SIZE', '4') + assert world_size() == 4 + + +def test_is_dist(monkeypatch): + monkeypatch.setenv('LOCAL_RANK', '0') + assert is_dist() == True + + +def test_is_main(monkeypatch): + monkeypatch.setenv('LOCAL_RANK', '0') + assert is_main() == True diff --git a/tests/proc/test_proc.py b/tests/proc/test_proc.py new file mode 100644 index 0000000..a620ea3 --- /dev/null +++ b/tests/proc/test_proc.py @@ -0,0 +1,47 @@ +from lumo.proc.path import * + + +def test_home(): + assert home() == os.path.expanduser("~") + + +# def test_cache_dir(): +# CACHE_ROOT = glob.get('cache_dir', None) +# expected = CACHE_ROOT or os.path.join(home(), '.lumo/cache') +# assert cache_dir() == expected + + +def test_libhome(): + LIBHOME = glob.get('home', None) + expected = LIBHOME or os.path.join(home(), '.lumo') + assert libhome() == expected + + +def test_exproot(): + EXP_ROOT = glob.get('exp_root', None) + expected = EXP_ROOT or os.path.join(libhome(), 'experiments') + assert exproot() == expected + + +def test_progressroot(): + PROGRESS_ROOT = glob.get('progress_root', None) + expected = PROGRESS_ROOT or os.path.join(libhome(), 'progress') + assert progressroot() == expected + + +def test_blobroot(): + BLOB_ROOT = glob.get('blob_root', None) + expected = BLOB_ROOT or os.path.join(libhome(), 'blob') + assert blobroot() == expected + + +def test_metricroot(): + METRIC_ROOT = glob.get('metric_root', None) + expected = METRIC_ROOT or os.path.join(libhome(), 'metrics') + assert metricroot() == expected + + +def test_local_dir(): + # it's difficult to test this function without a specific context + # since it depends on git_dir(), which is not included in the provided code + pass diff --git a/tests/trainer/test_builder.py b/tests/trainer/test_builder.py index 5e01ff1..071fd1f 100644 --- a/tests/trainer/test_builder.py +++ b/tests/trainer/test_builder.py @@ -1,4 +1,4 @@ -from lumo import DatasetBuilder, DataLoaderSide +from lumo import DatasetBuilder, DataLoaderSide, DataModule, ParamsType, TrainStage def global_check(dic): @@ -10,18 +10,56 @@ def global_check(dic): def create_dataset_builder(): builder = ( DatasetBuilder() + .add_idx('id') .add_input(name='xs', source=range(1000)) + .add_input(name='axs', source=range(1000), transform=lambda x: x - 1) .add_input(name='ys', source=range(1, 1001)) - .add_output(name='xs', outkey='xs1') + .add_output(name='xs', outkey='xs1', transform=lambda x: x + 1) .add_output(name='xs', outkey='xs2') + .add_output(name='axs', outkey='xs3') .add_output(name='ys', outkey='ys1') - .set_output_transform('xs1', lambda x: x + 1) .set_output_transform('ys1', lambda x: x - 1) .add_global_transform(global_check) ) return builder +def test_iter_data(): + builder = ( + DatasetBuilder() + .add_idx('id') + .add_input('xs', iter(range(20))) + .add_input('ys', iter(range(20))) + .add_output('xs', 'xs1') + .add_output('xs', 'xs2', transform=lambda x: x + 1) + .add_output('ys', 'ys') + ) + try: + builder[0] + assert False + except TypeError: + assert True + + try: + len(builder) + assert False + except TypeError: + assert True + + for i, sample in enumerate(builder): + assert isinstance(sample, dict) + assert sample['xs1'] == i + assert sample['xs2'] == i + 1 + assert sample['ys'] == i + assert sample['id'] == i + + builder.chain() + for i, (xs1, xs2, ys) in enumerate(builder): + assert xs1 == i + assert xs2 == i + 1 + assert ys == i + + def test_builder_base(): builder = create_dataset_builder() @@ -39,12 +77,21 @@ def test_builder_base(): assert len(builder) == 1000 - sub_builder = builder.subset(range(500), copy=True) + sub_builder = builder.subset(range(20), copy=True) assert len(builder) == 1000 - assert len(sub_builder) == 500 - assert sub_builder[499]['xs1'] == 500 - assert sub_builder[499]['ys1'] == 499 - assert sub_builder[499]['xs2'] == 499 + assert builder != sub_builder + new_builder = builder.subset(range(500)) + assert len(builder) == 500 + assert builder == new_builder + + assert len(sub_builder) == 20 + + for i, sample in enumerate(sub_builder): + assert sample['id'] == i + assert sample['xs1'] == i + 1 + assert sample['ys1'] == i + assert sample['xs2'] == i + assert sample['xs3'] == i - 1 dic = sub_builder.inputs assert 'xs' in dic @@ -57,6 +104,14 @@ def test_builder_base(): str(sub_builder) + sub_builder.zip() + assert isinstance(sub_builder[0], dict) + sub_builder.chain() + assert isinstance(sub_builder[0], list) + sub_builder.item() + print(sub_builder[0]) + # assert isinstance(sub_builder[0], dict) + def test_side(): sup = create_dataset_builder() @@ -81,3 +136,26 @@ def test_side(): assert 'xs2' in sup assert 'ys1' in sup assert un['xs1'].shape[0] == 32 + + +class MyDataModule(DataModule): + + def idataloader(self, params: ParamsType = None, stage: TrainStage = None): + super().idataloader(params, stage) + sup = create_dataset_builder() + un = create_dataset_builder() + + dl = ( + DataLoaderSide() + .add('sup', sup.DataLoader(batch_size=128, drop_last=True), cycle=True) + .add('un', un.DataLoader(batch_size=32, drop_last=True)) + .zip() + ) + self.regist_dataloader_with_stage(stage, dl) + + +def test_dm_dataloader(): + dm = MyDataModule() + loader = dm.train_dataloader + assert dm.train_dataset == loader.dataset + assert isinstance(dm.train_dataloader, DataLoaderSide) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0af44cd..bcb7b4e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1,5 +1,7 @@ from typing import Union, Optional, Sequence, Mapping, Any +import numpy as np + from lumo.proc.config import debug_mode from lumo.utils.repository import git_dir import os @@ -143,14 +145,19 @@ def test_trainer(): params.epoch = 2 debug_mode() + dm = MyDataModule() # glob['HOOK_FINALREPORT'] = False - trainer = CBTrainer(params, dm=MyDataModule()) + trainer = CBTrainer(params, dm=dm) trainer.train() trainer.test() trainer.evaluate() trainer.logger.info(trainer.lf.functions) trainer.exp.end() + assert dm.train_dataset == dm._parse_dataset(dm.train_dataloader) + assert dm.val_dataset == dm._parse_dataset(dm.val_dataloader) + assert dm.test_dataset == dm._parse_dataset(dm.test_dataloader) + # test trainer experiment exp = trainer.exp assert exp.exp_root == os.path.join(glob['exp_root'], trainer.generate_exp_name()) @@ -164,5 +171,36 @@ def test_trainer(): raise AssertionError(str(trainer.callback_function - trainer.lf.functions)) +def test_trainer_state_dict(): + trainer = Trainer(TrainerParams()) + device_a = trainer.device_a = torch.device('cpu') + ndarray_a = trainer.ndarray_a = np.array([1, 2, 3]) + tensor_a = trainer.tensor_a = torch.tensor([1, 2, 3]) + module = trainer.module = nn.Linear(10, 10) + optim_a = trainer.optim_a = TrainerParams.OPTIM.create_optim('SGD', lr=0.9).build(trainer.module.parameters()) + + state_dict = trainer.state_dict() + # assert state_dict['devices']['device_a'] == trainer.device_a + assert state_dict['optims']['optim_a'] == trainer.optim_a.state_dict() + assert all([(i == j).all() + for i, j in zip(state_dict['models']['module'].values(), trainer.module.state_dict().values())]) + assert (state_dict['thtensor']['tensor_a'] == trainer.tensor_a).all() + assert (state_dict['nptensor']['ndarray_a'] == trainer.ndarray_a).all() + + fn = trainer.save_state_dict() + trainer.ndarray_a = np.array([3, 2, 1]) + trainer.tensor_a = torch.tensor([3, 2, 1]) + trainer.module = nn.Linear(10, 10) + trainer.optim_a = TrainerParams.OPTIM.create_optim('SGD', lr=0.9).build(trainer.module.parameters()) + + + trainer.load_state_dict(torch.load(fn, map_location='cpu')) + assert state_dict['optims']['optim_a'] == optim_a.state_dict() + assert all([(i == j).all() + for i, j in zip(state_dict['models']['module'].values(), module.state_dict().values())]) + assert (state_dict['thtensor']['tensor_a'] == tensor_a).all() + assert (state_dict['nptensor']['ndarray_a'] == ndarray_a).all() + + if __name__ == '__main__': test_trainer() diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py new file mode 100644 index 0000000..fc52be6 --- /dev/null +++ b/tests/utils/test_io.py @@ -0,0 +1,101 @@ +import yaml +from lumo.utils.safe_io import * + + +def test_dump_json(tmpdir): + # Test that dump_json creates a valid JSON file + obj = {"a": 1, "b": 2} + fn = os.path.join(str(tmpdir), "test.json") + dump_json(obj, fn) + with open(fn, "r") as f: + data = json.load(f) + assert data == obj + + +def test_dump_yaml(tmpdir): + # Test that dump_yaml creates a valid YAML file + obj = {"a": 1, "b": 2} + fn = os.path.join(str(tmpdir), "test.yaml") + dump_yaml(obj, fn) + with open(fn, "r") as f: + data = yaml.safe_load(f) + assert data == obj + + +def test_dump_state_dict(tmpdir): + # Test that dump_state_dict creates a valid state dict file + obj = {"a": torch.randn(3, 3)} + fn = os.path.join(str(tmpdir), "test.pt") + dump_state_dict(obj, fn) + data = torch.load(fn) + assert (data['a'] == obj['a']).all() + + +def test_load_json(tmpdir): + # Test that load_json reads a valid JSON file + obj = {"a": 1, "b": 2} + fn = os.path.join(str(tmpdir), "test.json") + with open(fn, "w") as f: + json.dump(obj, f) + data = load_json(fn) + assert data == obj + + +def test_load_yaml(tmpdir): + # Test that load_yaml reads a valid YAML file + obj = {"a": 1, "b": 2} + fn = os.path.join(str(tmpdir), "test.yaml") + with open(fn, "w") as f: + yaml.safe_dump(obj, f) + data = load_yaml(fn) + assert data == obj + + +def test_load_state_dict(tmpdir): + # Test that load_state_dict reads a valid state dict file + obj = {"a": torch.randn(3, 3)} + fn = os.path.join(str(tmpdir), "test.pt") + torch.save(obj, fn) + data = load_state_dict(fn) + assert (data['a'] == obj['a']).all() + + +def test_dump_text(tmpdir): + # Test that dump_text creates a valid text file + string = "hello\nworld" + fn = os.path.join(str(tmpdir), "test.txt") + dump_text(string, fn) + with open(fn, "r") as f: + data = f.read() + assert data == string + + +def test_load_text(tmpdir): + # Test that load_text reads a valid text file + string = "hello\nworld" + fn = os.path.join(str(tmpdir), "test.txt") + with open(fn, "w") as f: + f.write(string) + data = load_text(fn) + assert data == string + + +def test_dump_pkl(tmpdir): + # Test that dump_pkl creates a valid pickle file + obj = {"a": 1, "b": 2} + fn = os.path.join(str(tmpdir), "test.pkl") + dump_pkl(obj, fn) + data = load_pkl(fn) + assert data == obj + + +def test_cached(tmpdir): + # Test that cached context manager works correctly + with cached(os.path.join(str(tmpdir), "test.txt")) as cache_fn: + # Write some data to the cache file + with open(cache_fn, "w") as f: + f.write("hello") + # Check that the cache file exists + assert os.path.isfile(cache_fn) + # Check that the cache file is deleted after exiting the context manager + assert not os.path.isfile(cache_fn) From e4644b2fd47251c15f33366dbb23a09394405683 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 10:55:29 +0800 Subject: [PATCH 26/99] Add some ignore items --- pyproject.toml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 26c377c..e013350 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,10 +76,22 @@ exclude_lines = [ "if TYPE_CHECKING:", # some tricky object that can not well tested. "except ImportError", - "class RecordAbort", "pass", + "return None", + "break", # + + # Hard to test + "class RecordAbort", + "class GitCommit", "def summary_experiment", "def plot", + "if torch.cuda.is_available()", # ignore cuda + "if is_dist():", # ignore distribution + + # Deprecated method: + "def add_input_transform", + "def add_output_transform", + "raise StopIteration" ] [tool.coverage.html] From fb2ccc3d4d54a95a2282f64610f384527d14919c Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 11:10:00 +0800 Subject: [PATCH 27/99] remove agent.py for no usage --- src/lumo/{exp/agent.py => sketch/wait_pid_stop.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/lumo/{exp/agent.py => sketch/wait_pid_stop.py} (100%) diff --git a/src/lumo/exp/agent.py b/src/lumo/sketch/wait_pid_stop.py similarity index 100% rename from src/lumo/exp/agent.py rename to src/lumo/sketch/wait_pid_stop.py From cd8e186955ab9f5bca495367e57de00760d0d266 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 11:21:05 +0800 Subject: [PATCH 28/99] Add docstring powered by chatGPT --- src/lumo/exp/base.py | 7 +++++- src/lumo/exp/experiment.py | 8 +++---- src/lumo/exp/exphook.py | 48 +++++++++++++++++++++++++++++--------- 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/src/lumo/exp/base.py b/src/lumo/exp/base.py index 257d925..2c90810 100644 --- a/src/lumo/exp/base.py +++ b/src/lumo/exp/base.py @@ -1,7 +1,12 @@ from lumo.proc import glob -class ExpHook: +class BaseExpHook: + """A base class of hook for experiments that can be registered with an experiment. + + Please Use Exphook in exp.exphook for better typehint. + + """ name = None # type: str configs = {} diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 1a90388..3aeaad8 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -14,7 +14,7 @@ from lumo.utils import safe_io as io from lumo.utils.fmt import can_be_filename from lumo.utils.logger import Logger -from .base import ExpHook +from .base import BaseExpHook from ..proc.pid import pid_hash, runtime_pid_obj @@ -337,7 +337,7 @@ def start(self): return self.initial() self.dump_info('start', True) - for hook in self._hooks.values(): # type: ExpHook + for hook in self._hooks.values(): # type: BaseExpHook hook.on_start(self) return self @@ -349,7 +349,7 @@ def end(self, end_code=0, *args, **extra): return self.dump_progress(1) self.dump_info('end', True) - for hook in self._hooks.values(): # type: ExpHook + for hook in self._hooks.values(): # type: BaseExpHook hook.on_end(self, end_code=end_code, *args, **extra) return self @@ -415,7 +415,7 @@ def enable_properties(self) -> set: return set(self._prop.keys()) @call_on_main_process_wrap - def set_hook(self, hook: ExpHook): + def set_hook(self, hook: BaseExpHook): hook.regist(self) if not glob.get(hook.config_name, True): self.dump_info(hook.name, { diff --git a/src/lumo/exp/exphook.py b/src/lumo/exp/exphook.py index 41969e3..ae6a378 100644 --- a/src/lumo/exp/exphook.py +++ b/src/lumo/exp/exphook.py @@ -15,10 +15,12 @@ from lumo.utils.exithook import wrap_before from lumo.utils.fmt import strftime, indent_print from . import Experiment -from .base import ExpHook as BaseExpHook +from .base import BaseExpHook as BaseExpHook class ExpHook(BaseExpHook): + """A base class of hook for experiments that can be registered with an experiment.""" + def regist(self, exp: Experiment): self.exp = exp @@ -32,6 +34,11 @@ def on_newpath(self, exp: Experiment, *args, **kwargs): pass class LastCmd(ExpHook): + """A hook to save the last command executed in an experiment. + + This hook saves the last command executed in an experiment to a shell script file in a specified directory. The saved + file can be used to re-run the experiment with the same command. + """ configs = {'HOOK_LASTCMD_DIR': os.getcwd()} def on_start(self, exp: Experiment, *args, **kwargs): @@ -62,6 +69,8 @@ def on_start(self, exp: Experiment, *args, **kwargs): class Diary(ExpHook): + """A hook for logging experiment information to a diary file.""" + def on_start(self, exp: Experiment, *args, **kwargs): super().on_start(exp, *args, **kwargs) with open(exp.root_file(f'{strftime("%y%m%d")}.log', 'diary'), 'a') as w: @@ -69,6 +78,9 @@ def on_start(self, exp: Experiment, *args, **kwargs): class RecordAbort(ExpHook): + """A hook to record and handle experiment aborts. + """ + def regist(self, exp: Experiment): super().regist(exp) wrap_before(self.exc_end) @@ -176,21 +188,35 @@ def on_start(self, exp: Experiment, *args, **kwargs): class LockFile(ExpHook): + """A class for locking dependencies for an experiment. + Locks the specified dependencies for the experiment and saves them to a file. + """ def on_start(self, exp: Experiment, *args, **kwargs): - exp.dump_info('lock', get_lock('lumo', - 'numpy', - 'joblib', - 'psutil', - 'decorator', - 'torch', - 'numpy', - 'accelerate', - 'hydra', - 'omegaconf', )) + basic = get_lock('lumo', + 'numpy', + 'joblib', + 'psutil', + 'decorator', + 'torch', + 'numpy', + 'accelerate', + 'hydra', + 'omegaconf', ) + if basic['torch'] is not None: + import torch + if torch.cuda.is_available(): + basic['torch.version.cuda'] = torch.version.cuda + + exp.dump_info('lock', basic) class FinalReport(ExpHook): + """A class for generating a final report for an experiment. + + Prints the experiment's properties, tags, paths, and execute command. + """ + def on_end(self, exp: Experiment, end_code=0, *args, **kwargs): # if end_code == 0: print('-----------------------------------') From c6052f977af94243b43603ecfaffc86df44f0c5f Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 11:21:13 +0800 Subject: [PATCH 29/99] Add docstring --- src/lumo/analyse/condition.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lumo/analyse/condition.py b/src/lumo/analyse/condition.py index 10d21d4..08ce2b2 100644 --- a/src/lumo/analyse/condition.py +++ b/src/lumo/analyse/condition.py @@ -121,7 +121,8 @@ def first(self, value): def filter_by_condition(df: DataFrame, *condition: Compare) -> DataFrame: """ - 简化版的 pipeline,仅用于等式和不等式筛选以及额外支持的 in/not_in 等功能 + fdf = filter_by_condition(df,C['contition'] == True, C['c2'] > 1, ... ) + Args: df: padnas.DataFrame instance *condition: list of `~Compare` instance From 06c8eaf2757d6ca9bfa16db5e871d1540db531ac Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 11:21:20 +0800 Subject: [PATCH 30/99] Remove unused file --- src/lumo/exp/README.md | 14 -------------- 1 file changed, 14 deletions(-) delete mode 100644 src/lumo/exp/README.md diff --git a/src/lumo/exp/README.md b/src/lumo/exp/README.md deleted file mode 100644 index e57d6fe..0000000 --- a/src/lumo/exp/README.md +++ /dev/null @@ -1,14 +0,0 @@ -- 记录实验,负责实验的组织、目录管理、元数据记录 - -``` -from lumo_experiment import Experiment - -Experiment('name') -``` - -Experiment 的意义: - -- 为每个实验分配唯一存储空间(通过 exp_name/test_name ) -- 为每个实验保留回溯可能(通过 git) -- 为每个实验进行标注,方便(可能存在的)检索 -- 其他状态动态记录 \ No newline at end of file From 9f989cf6465a0beda17dbaf29f379a590047d897 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 11:21:59 +0800 Subject: [PATCH 31/99] Modified get_lock for common usage --- src/lumo/proc/dependency.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lumo/proc/dependency.py b/src/lumo/proc/dependency.py index a67bcc6..4f85fda 100644 --- a/src/lumo/proc/dependency.py +++ b/src/lumo/proc/dependency.py @@ -30,7 +30,6 @@ def get_lock(*others): """ res = {} - res['lumo'] = lumo_version res.update({k: v for k, v in Version.__dict__.items() if not k.startswith('__')}) for lib in others: From 9e04a6ed2aeb78bd943601cef591c7b839f9cd89 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 11:29:55 +0800 Subject: [PATCH 32/99] Add docstring powered by chatGPT --- src/lumo/proc/config.py | 44 ++++++++++++++++++++++++++++++++++++++++- src/lumo/proc/path.py | 22 ++++++++++++++++++++- src/lumo/proc/pid.py | 24 +++++++++++++++++++++- 3 files changed, 87 insertions(+), 3 deletions(-) diff --git a/src/lumo/proc/config.py b/src/lumo/proc/config.py index e5b9164..9733c8a 100644 --- a/src/lumo/proc/config.py +++ b/src/lumo/proc/config.py @@ -4,7 +4,6 @@ __all__ = ['debug_mode', 'glob', 'global_config_path', 'local_config_path'] import tempfile -from typing import overload GLOBAL_DEFAULT = { 'home': os.path.expanduser("~/.lumo/"), @@ -14,10 +13,25 @@ def global_config_path(): + """ + Returns the path to the global configuration file. + + Returns: + str: The path to the global configuration file. + + Notes: + The relative path of global config path should never change (~/.lumorc.json) + """ return os.path.expanduser("~/.lumorc.json") def local_config_path(): + """ + Returns the path to the local configuration file. + + Returns: + str: The path to the local configuration file, if found. Otherwise, None. + """ from lumo.utils.repository import git_dir res = git_dir() if res: @@ -26,6 +40,19 @@ def local_config_path(): def get_config(path, default): + """ + Reads the configuration file at the given path or creates it if it doesn't exist. + + Args: + path (str): The path to the configuration file. + default (dict): The default configuration to use if the file doesn't exist. + + Returns: + dict: The configuration read from the file or the default configuration if the file doesn't exist. + + Raises: + Exception: If there was an error reading the configuration file. + """ if path is None: return default @@ -44,6 +71,12 @@ def get_config(path, default): def get_runtime_config(): + """ + Returns the runtime configuration by merging the global and local configurations. + + Returns: + dict: The merged runtime configuration. + """ glob_cfg = get_config(global_config_path(), GLOBAL_DEFAULT) local_cfg = get_config(local_config_path(), {}) cfg = GLOBAL_DEFAULT @@ -53,6 +86,15 @@ def get_runtime_config(): def debug_mode(base_dir=None, disable_git=True): + """Sets up global variables for debugging mode. + + Args: + base_dir (str, optional): The directory to create temporary directories in. Defaults to None. + disable_git (bool, optional): Whether to disable git hooks. Defaults to True. + + Returns: + None + """ glob['exp_root'] = tempfile.mkdtemp(dir=base_dir) glob['progress_root'] = tempfile.mkdtemp(dir=base_dir) glob['home'] = tempfile.mkdtemp(dir=base_dir) diff --git a/src/lumo/proc/path.py b/src/lumo/proc/path.py index acb45f3..e90618b 100644 --- a/src/lumo/proc/path.py +++ b/src/lumo/proc/path.py @@ -5,6 +5,16 @@ def home(): + """ + Returns the home directory of the current user. + + This function returns the home directory of the current user, which is + different from the home directory of the library. If you need to access + the library home directory, use `libhome()` instead. + + Returns: + str: The home directory of the current user. + """ return os.path.expanduser("~") @@ -31,7 +41,17 @@ def cache_dir(): def libhome(): - """Library home to store configs. Default is `~/.lumo`""" + """ + Returns the library home directory. + + This function returns the library home directory, which is used to store + configuration files. By default, the library home directory is `~/.lumo`. + If a custom home directory is set in the `home` global variable, that + directory will be returned instead. + + Returns: + str: The library home directory. + """ LIBHOME = glob.get('home', None) if LIBHOME: return LIBHOME diff --git a/src/lumo/proc/pid.py b/src/lumo/proc/pid.py index 174be98..23ed726 100644 --- a/src/lumo/proc/pid.py +++ b/src/lumo/proc/pid.py @@ -1,10 +1,24 @@ +""" +Returns information about the specified process or the current process, and computes its hash value. +""" from psutil import Process -import sys from joblib import hash import os def runtime_pid_obj(pid=None): + """Returns a dictionary containing information about a process identified by the given PID. + + Args: + pid (int, optional): The PID of the process to get information about. If None, uses the PID of the current process. Defaults to None. + + Returns: + dict: A dictionary containing the following keys: + - pid (int): The process ID. + - pname (str): The name of the process. + - pstart (float): The process creation time, in seconds since the epoch. + - argv (list): A list of command-line arguments passed to the process. + """ if pid is None: pid = os.getpid() p = Process(pid) @@ -15,6 +29,14 @@ def runtime_pid_obj(pid=None): def pid_hash(pid_obj=None): + """Computes the hash of a process object. + + Args: + pid_obj (dict, optional): A dictionary containing information about a process. If None, uses the information about the current process. Defaults to None. + + Returns: + str: The hash of the process object. + """ if pid_obj is None: pid_obj = runtime_pid_obj() return hash(pid_obj) From f19182c9f1bb98d18aa1db5ae9ffc65e59b653ef Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 11:50:53 +0800 Subject: [PATCH 33/99] Add docstring powered by chatGPT; Reformat code --- src/lumo/trainer/base.py | 128 ++++++++++++++++++++------------- src/lumo/trainer/callbacks.py | 4 +- src/lumo/trainer/components.py | 18 ++--- src/lumo/trainer/factory.py | 34 +++++++++ src/lumo/trainer/rnd.py | 28 +++----- src/lumo/trainer/trainer.py | 1 - 6 files changed, 131 insertions(+), 82 deletions(-) diff --git a/src/lumo/trainer/base.py b/src/lumo/trainer/base.py index 02687ce..4e2e37e 100644 --- a/src/lumo/trainer/base.py +++ b/src/lumo/trainer/base.py @@ -1,15 +1,4 @@ """ - - 提供主要的训练流 - - 加载数据集 DataLoader - - train、test、eval - - 提供对训练流的控制、回调 - - callbacks - - 提供对训练中间的状态保存 - - metric: Meter - - checkpoints: Saver - - -trainer = Trainer() """ import inspect import os @@ -26,6 +15,25 @@ def _exit_hook_v0(exc_type, exc, tb, *args): + """Prints the traceback information when an unhandled exception occurs. + + Args: + exc_type (type): The type of the exception. + exc (Exception): The instance of the exception. + tb (traceback): The traceback object containing the call stack. + *args: Additional arguments to be passed to the function. + + Returns: + None + + Raises: + None + + This function is designed to be used as an exit hook with the `sys.excepthook` function. + It formats the traceback information and removes any lines related to the `_newfunc` function. + The resulting traceback is printed to the `sys.stderr` stream. + + """ import traceback res = traceback.format_exception(exc_type, exc, tb) res = [i for i in res if 'in _newfunc' not in i] @@ -33,6 +41,25 @@ def _exit_hook_v0(exc_type, exc, tb, *args): def _exit_hook(exc_type, exc, tb, *args): + """Prints an error traceback and displays it using the rich library. + + Args: + exc_type: Type of the exception that was raised. + exc: The exception instance that was raised. + tb: Traceback object that contains information about the exception. + *args: Optional additional arguments to be passed to _exit_hook_v0. + + Returns: + None. + + Raises: + Any exceptions that were not caught by _exit_hook_v0. + + Examples: + # Call _exit_hook with an exception + >>> _exit_hook(TypeError, 'test error', traceback, arg1, arg2) + + """ from rich.console import Console console = Console() _exit_hook_v0(exc_type, exc, tb, *args) @@ -54,40 +81,6 @@ def _exit_hook(exc_type, exc, tb, *args): console.print(traceback) -def wrapper(self, func, _call_set: list): - """ - 对每个 Trainer 的 _call_backs 类变量中定义的函数尝试绑定回调 - Args: - func: - _call_set: - - Returns: - - """ - - @wraps(func) - def _newfunc(*aargs, **kkwargs): - """执行前回调 on_begin() 、执行后回调 on_end()、执行异常则回调 on_exception() """ - for callback in _call_set: - callback.on_begin(self, func, self.params, *aargs, **kkwargs) - try: - _meter = func(*aargs, **kkwargs) - except BaseException as e: - _handles = [callback.on_exception(self, func, self.params, e, *aargs, **kkwargs) - for callback in _call_set] - - if any(_handles): - return None - else: - raise e - - for callback in _call_set: - callback.on_end(self, func, self.params, _meter, *aargs, **kkwargs) - return _meter - - return _newfunc - - init_function = ['icallbacks', 'imodels'] call_dependency = { 'train': init_function, @@ -95,6 +88,8 @@ def _newfunc(*aargs, **kkwargs): class _BaseTrainer: + """Base class for training neural network models. + """ __exp_name__ = None callback_function = {} @@ -122,8 +117,17 @@ def __new__(cls, *args, **kwargs): replace(_exit_hook) def init_wrapper(func): + """ + Wraps the train/test/eval functions to initialize in silence. + + Notes: + Before calling the train/test/eval functions, the `trainer.initialize` method is called, + and then the corresponding DataLoader for the stage is initialized through the `process_loader` method. + """ + @wraps(func) def inner(dm=None, params=None, *args, **kwargs): + """The inner function that wraps the train/test/eval function.""" init_fn = getattr(self, 'initialize', None) if init_fn is not None: init_fn() @@ -136,18 +140,21 @@ def inner(dm=None, params=None, *args, **kwargs): def cb_wrapper(func, call_set: list): """ - 对每个 Trainer 的 _call_backs 类变量中定义的函数尝试绑定回调 + Wraps the given function with callback functions. + Args: - func: - call_set: + func (function): The function to wrap. + call_set (list): A list of callback functions. Returns: - + A wrapped function. """ @wraps(func) def _newfunc(*aargs, **kkwargs): - """执行前回调 on_begin() 、执行后回调 on_end()、执行异常则回调 on_exception() """ + """ + Executes the callback functions before and after the given function and on exception. + """ # on_begin self._contexts.append(func.__name__) for callback in call_set: @@ -225,10 +232,14 @@ def __setattr__(self, name, value): @property def contexts(self): + """Get the name stack of function call contexts. + The first is the name of the Trainer class + """ return self._contexts @property def context(self): + """Get the name of the most recent function call context.""" return self._contexts[-1] # def __getattr__(self, name): @@ -245,11 +256,22 @@ def context(self): @classmethod def dirname(cls): + """Get the directory name of the file where the class is defined. + + Returns: + A string representing the directory name of the file where the class is defined. + """ file = inspect.getfile(cls) return os.path.basename(os.path.dirname(file)) @classmethod def filebasename(cls): + """Get the basename of the file where the class is defined. + + Returns: + A string representing the basename of the file where the class is defined. + If an exception occurs, returns 'builtin'. + """ try: file = inspect.getfile(cls) pre = os.path.splitext(os.path.basename(file))[0] @@ -259,6 +281,12 @@ def filebasename(cls): @classmethod def generate_exp_name(cls) -> str: + """Generate an experiment name based on the file basename and the class name. + + Returns: + A string representing the experiment name, formatted as '.'. + If '__exp_name__' is defined, it is used instead of the default class name with 'trainer' replaced by 'exp'. + """ pre = cls.filebasename() exp_name = cls.__exp_name__ diff --git a/src/lumo/trainer/callbacks.py b/src/lumo/trainer/callbacks.py index 42d45e8..4125513 100644 --- a/src/lumo/trainer/callbacks.py +++ b/src/lumo/trainer/callbacks.py @@ -1,7 +1,6 @@ """ """ import inspect -import json import os import tempfile import time @@ -10,13 +9,14 @@ from datetime import datetime from functools import wraps from typing import NewType, Any, Optional, Dict, Union -from lumo.utils.memory_grab import DeviceMem + from torch.utils.data import DataLoader from lumo.core import Meter, MetricType, Record, TrainStage, wrap_result, ParamsType from lumo.data import DataModule from lumo.data.loader import summarize_loader, DataLoaderType from lumo.utils import fmt +from lumo.utils.memory_grab import DeviceMem from lumo.utils.screen import inlinetqdm from .trainer import Trainer from ..proc.dist import world_size diff --git a/src/lumo/trainer/components.py b/src/lumo/trainer/components.py index 17ff9e9..8422035 100644 --- a/src/lumo/trainer/components.py +++ b/src/lumo/trainer/components.py @@ -1,5 +1,3 @@ -from typing import NewType - import torch from lumo.core import Params @@ -8,6 +6,7 @@ class TrainerExperiment(SimpleExperiment): + """A class for helping manage an experiment by Trainer.""" @property def log_dir(self): @@ -40,18 +39,19 @@ def state_dict_dir(self): def dump_train_eidx(self, eidx, epoch: int): """ - Args: - eidx: start from 0, end at `epoch-1` - epoch: + Dumps the progress of the trainer. + + Args: + eidx (int): The index of the current epoch (starting from 0). + epoch (int): The total number of epochs to train for. """ self.dump_progress((eidx + 1) / epoch, update_from='trainer') -class ReimplementExperiment(TrainerExperiment): - pass - - class TrainerParams(Params): + """ + A class to hold parameters for trainer. + """ OPTIM = OptimFactory SCHE = INTERP = InterpFactory diff --git a/src/lumo/trainer/factory.py b/src/lumo/trainer/factory.py index e32cc6d..e64c481 100644 --- a/src/lumo/trainer/factory.py +++ b/src/lumo/trainer/factory.py @@ -8,6 +8,24 @@ class InterpFactory: + """A factory class for creating instances of various interpolation classes. + + This class provides convenient access to various interpolation classes that are defined in the `interp` module. + + Attributes: + Cos (class): An interpolation class for cosine interpolation. + Linear (class): An interpolation class for linear interpolation. + Exp (class): An interpolation class for exponential interpolation. + Log (class): An interpolation class for logarithmic interpolation. + Constant (class): An interpolation class for constant interpolation. + PeriodCos (class): An interpolation class for periodic cosine interpolation. + PeriodHalfCos (class): An interpolation class for periodic half-cosine interpolation. + PeriodTriangle (class): An interpolation class for periodic triangle interpolation. + PeriodLinear (class): An interpolation class for periodic linear interpolation. + PowerDecay (class): An interpolation class for power-decay interpolation. + List (class): An interpolation class for list interpolation. + + """ Cos = interp.Cos Linear = interp.Linear Exp = interp.Exp @@ -47,6 +65,22 @@ def build(self, parameters, optim_cls=None) -> Optimizer: class _OptimFactory: + """ + A factory class that provides different optimization algorithms to be used during training. + + Methods: + create_optim(name=None, **kwargs) -> OptimBuilder: + Creates an instance of OptimBuilder for a specified optimization algorithm. + + Examples: + To create an instance of OptimBuilder for Adam optimizer with default values: + >>> optim_builder = OptimFactory.create_optim(name='Adam') + + To create an instance of OptimBuilder for SGD optimizer with specific values: + >>> optim_builder = OptimFactory.create_optim(name='SGD', lr=0.01, momentum=0.9) + + """ + @overload def create_optim(self, name='SGD', lr=None, momentum=0, dampening=0, weight_decay=0, nesterov=False) -> OptimBuilder: diff --git a/src/lumo/trainer/rnd.py b/src/lumo/trainer/rnd.py index b7509c2..6c685ac 100644 --- a/src/lumo/trainer/rnd.py +++ b/src/lumo/trainer/rnd.py @@ -1,41 +1,29 @@ -import os -import time from typing import Union -from joblib import hash -from lumo.proc.path import cache_dir from lumo.utils import random class RndManager: """ - A seed manager for trainer. Provide interface for `~lumo.utils.random` + A seed manager for the trainer. Provides an interface for `~lumo.utils.random`. """ - def __init__(self): - self.save_dir = os.path.join(cache_dir(), 'rnd') - def mark(self, seed: Union[int, str]): """ - 用于数据集读取一类的,需要特定步骤每一次试验完全相同 - Args: - seed: 该次标记固定种子的名字,第一次调用该方法会在特定目录存放当前状态, - 第二次调用会在该位置读取当前随机种子状态 - - Returns: + Fixes the random seed to a specific state for reproducibility. + Args: + seed (Union[int, str]): The name of the fixed seed state. """ random.fix_seed(random.hashseed(seed)) def shuffle(self, seed=None): """ - 打乱,一般用于复现试验的时候随机一个种子 - Args: - name: - seed: - - Returns: + Shuffles the random seed for reproducibility. + Args: + seed (int, optional): The random seed to use. If None, a random seed based on the current + time will be used. """ if seed is None: random.fix_seed(random.int_time()) diff --git a/src/lumo/trainer/trainer.py b/src/lumo/trainer/trainer.py index 4fae384..97d2eea 100644 --- a/src/lumo/trainer/trainer.py +++ b/src/lumo/trainer/trainer.py @@ -24,7 +24,6 @@ from lumo.proc import glob from lumo.trainer.rnd import RndManager from lumo.utils.logger import Logger -from lumo.utils.fmt import strftime from .base import _BaseTrainer from .components import TrainerExperiment, TrainerParams from .saver import Saver From 6d0194298800f59c1a14f1fb85b7357a13c19f16 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 13:51:37 +0800 Subject: [PATCH 34/99] Add docstring powered by chatGPT --- src/lumo/analyse/condition.py | 6 + src/lumo/core/disk.py | 38 +++- src/lumo/core/enums.py | 28 ++- src/lumo/core/interp.py | 123 ++++++++++--- src/lumo/core/meter.py | 227 ++++++++++++++++++++++-- src/lumo/core/params.py | 107 +++++++++++- src/lumo/trainer/trainer.py | 320 ++++++++++++++++++++++++++++++++-- 7 files changed, 788 insertions(+), 61 deletions(-) diff --git a/src/lumo/analyse/condition.py b/src/lumo/analyse/condition.py index 08ce2b2..3f9a229 100644 --- a/src/lumo/analyse/condition.py +++ b/src/lumo/analyse/condition.py @@ -7,14 +7,17 @@ def in_(ser, value): + """pandas operation""" return ser.apply(lambda x: x in value) def not_in_(ser, value): + """pandas operation""" return ser.apply(lambda x: x not in value) def first(ser, value): + """pandas operation""" return ser.duplicated(value) == False @@ -101,16 +104,19 @@ def __repr__(self): return f'{self.name} {self.op} {self.value}' def in_(self, lis): + """condition of `in` operation""" self.op = 'in' self.value = set(lis) return self def not_in_(self, lis): + """condition of `.duplicated(value) == False` operation""" self.op = 'notin' self.value = set(lis) return self def first(self, value): + """condition of `not in` operation""" self.op = 'first' self.value = value return self diff --git a/src/lumo/core/disk.py b/src/lumo/core/disk.py index 46c406a..15b2535 100644 --- a/src/lumo/core/disk.py +++ b/src/lumo/core/disk.py @@ -1,12 +1,31 @@ import os.path + from dbrecord import PList + from lumo.proc import path from lumo.utils import safe_io as IO class Metrics: """ - Record metrics at multiple steps. Supported by dbrecord. + Records metrics at multiple steps and stages. The metrics are supported by dbrecord. + + Args: + test_path (str): The path to the test directory. + + Attributes: + fpath (str): The path to the metric board SQLite file. + disk (PList): The PList instance for accessing the metric board SQLite file. + + Methods: + append(metric: dict, step: int, stage: str = 'train') -> None: + Adds the specified metric, step, and stage to the metric board SQLite file. + The metric is a dictionary object that contains the metric name as the key and the metric value as the value. + The stage is either 'train' or 'test', and it is set to 'train' by default. + This method calls the `flush` method to write the changes to disk. + + flush() -> None: + Writes any changes to the metric board SQLite file to disk. """ def __init__(self, test_path: str): @@ -15,6 +34,17 @@ def __init__(self, test_path: str): self.disk = PList(self.fpath) def append(self, metric: dict, step, stage='train'): + """ + Adds the specified metric, step, and stage to the metric board SQLite file. + + Args: + metric (dict): A dictionary object that contains the metric name as the key and the metric value as the value. + step (int): The step number. + stage (str, optional): The stage of the metric, either 'train' or 'test'. Defaults to 'train'. + + Returns: + None + """ self.disk.append({ 'metric': metric, 'step': step, @@ -23,6 +53,12 @@ def append(self, metric: dict, step, stage='train'): self.disk.flush() def flush(self): + """ + Writes any changes to the metric board SQLite file to disk. + + Returns: + None + """ self.disk.flush() diff --git a/src/lumo/core/enums.py b/src/lumo/core/enums.py index 1997270..aa6dabb 100644 --- a/src/lumo/core/enums.py +++ b/src/lumo/core/enums.py @@ -2,8 +2,7 @@ class TrainStage(enum.Enum): - """ - Enumeration class representing different stages of training. + """An enumeration class representing the different stages of training. """ default = 'default' train = 'train' @@ -11,16 +10,41 @@ class TrainStage(enum.Enum): val = 'val' def is_train(self): + """Check if the current stage is the training stage. + + Returns: + bool: True if the current stage is the training stage, False otherwise. + """ return self.value == 'train' def is_test(self): + """Check if the current stage is the testing stage. + + Returns: + bool: True if the current stage is the testing stage, False otherwise. + """ return self.value == 'test' def is_val(self): + """Check if the current stage is the validation stage. + + Returns: + bool: True if the current stage is the validation stage, False otherwise. + """ return self.value == 'val' @staticmethod def create_from_str(value): + """Create a TrainStage instance from a string. + + If the value is 'eval' or 'evaluate', it will be converted to 'val'. + + Args: + value (str): A string representing the stage of training. + + Returns: + TrainStage: A TrainStage instance representing the stage of training. + """ if value in {'eval', 'evaluate'}: value = 'val' return TrainStage(value) diff --git a/src/lumo/core/interp.py b/src/lumo/core/interp.py index ad686bc..c9f34c8 100644 --- a/src/lumo/core/interp.py +++ b/src/lumo/core/interp.py @@ -41,37 +41,50 @@ class Interpolate(BaseParams): - """ - ratio 变化为从 1 - 0 - """ + """A class for implementing interpolation schedule of a learning rate.""" @classmethod def interp(self, *args, **kwargs): + """Interpolation method for the schedule. Must be implemented in a subclass.""" raise NotImplementedError() def __repr__(self): + """Return a string representation of the schedule.""" content = ', '.join(["{}={}".format(k, v) for k, v in self.items()]) return "{}({})".format(self.__class__.__name__, content) def __call__(self, cur): + """Return the learning rate at the given step 'cur'.""" raise NotImplementedError() def get(self, key: DictKeyType, default_value: Any = None) -> Any: + """ + Return the value of the key from the schedule's dictionary, or default_value if the key is not present. + + Args: + key: The key of the dictionary. + default_value: The default value to return if the key is not found. + + Returns: + The value of the key, or default_value if the key is not found. + """ return self(key) def plot(self, num=1000, left=0, right=1000, show=True): """ Plot a curve of the schedule. + Args: - num: - left: - right: + num: The number of points to plot the curve. + left: The starting point of the curve. + right: The ending point of the curve. + show: Whether to display the plot or not. Returns: - plt.plot + The plot object. Notes: - You may need to call `plt.show`() to show it. + You may need to call `plt.show()` to show the plot. """ from matplotlib import pyplot as plt @@ -86,21 +99,22 @@ def plot(self, num=1000, left=0, right=1000, show=True): def scale(self, optimizer, cur): """ - Scale the learning rate by current value. 'scale' means not apply the current schedule value directly to - the learning rate, but multiple the initial learning rate. You can use `schedule.apply()` to apply the schedule - value directly. + Scale the learning rate by the current value. + + 'Scale' means that the current schedule value will not be applied directly to the learning rate, but will be + multiplied by the initial learning rate. You can use `schedule.apply()` to apply the schedule value directly. Notes: - ------- - When first apply scale function, a `_raw_lr` represent initial lr will be set in each param_group, then, the - learning rate(store in param_groups with key 'lr') will be calculated by `_raw_lr * schedule(cur)`. + When `scale()` is first called, an initial learning rate `_raw_lr` is stored in each `param_group`. Then, + the learning rate (stored in `param_groups` with the key 'lr') will be calculated as `_raw_lr * + schedule(cur)`. Args: - optimizer: A pytorch optimizer instance. - cur: current step of this schedule. + optimizer: A PyTorch optimizer instance. + cur: The current step of the schedule. Returns: - Current schedule value. + The current schedule value. """ ratio = self(cur) for param_group in optimizer.param_groups: # type:dict @@ -111,14 +125,14 @@ def scale(self, optimizer, cur): def apply(self, optimizer, cur): """ - Apply the learning rate with current schedule value. + Apply the learning rate with the current schedule value. Args: - optimizer: A pytorch optimizer instance. - cur: current step of this schedule. + optimizer: A PyTorch optimizer instance. + cur: The current step of the schedule. Returns: - + The new learning rate. """ new_lr = self(cur) for param_group in optimizer.param_groups: # type:dict @@ -128,14 +142,44 @@ def apply(self, optimizer, cur): class ABCContinuous(Interpolate): + """ + Interpolates a continuous schedule for a value between a start and end point. + + Args: + start (float): The starting value of the schedule. + end (float): The ending value of the schedule. + left (float): The left boundary of the range of values to interpolate over. + right (float): The right boundary of the range of values to interpolate over. + *args: Additional arguments to pass to the superclass constructor. + **kwargs: Additional keyword arguments to pass to the superclass constructor. + + Attributes: + left (float): The left boundary of the range of values to interpolate over. + right (float): The right boundary of the range of values to interpolate over. + start (float): The starting value of the schedule. + end (float): The ending value of the schedule. + constant (bool): A flag indicating whether the schedule is constant. + """ def __call__(self, cur): + """Returns the interpolated value of the schedule at a given point.""" if self.constant: return self.start return self.interp(cur, start=self.start, end=self.end, left=self.left, right=self.right, constant=self.constant) def __init__(self, start=1e-3, end=1e-6, left=0, right=80, *args, **kwargs): + """ + Initializes an instance of ABCContinuous. + + Args: + start (float): The starting value of the schedule. + end (float): The ending value of the schedule. + left (float): The left boundary of the range of values to interpolate over. + right (float): The right boundary of the range of values to interpolate over. + *args: Additional arguments to pass to the superclass constructor. + **kwargs: Additional keyword arguments to pass to the superclass constructor. + """ super().__init__() self.left = left self.right = right @@ -145,6 +189,7 @@ def __init__(self, start=1e-3, end=1e-6, left=0, right=80, *args, **kwargs): @classmethod def ratio(cls, cur, left, right, constant=False): + """Returns the ratio of a given point between the left and right boundaries.""" if constant: return 0 return (cur - left) / (right - left) @@ -155,6 +200,7 @@ def get_val(cls, cur, start=1e-3, end=1e-6, left=0, right=80, *args, **kwargs): return cls.interp(cur=cur, start=start, end=end, left=left, right=right, *args, **kwargs) def plot(self, num=1000, left=None, right=None, show=True): + """Plots the interpolated schedule.""" if left is None: left = self.left @@ -169,10 +215,30 @@ def plot(self, num=1000, left=None, right=None, show=True): class ABCPeriod(Interpolate): """ - period + A class for generating schedules with a repeating period. + + Attributes: + left (float): The left boundary of the schedule. + period (float): The period of the schedule. + start (float): The start value of the schedule. + end (float): The end value of the schedule. + constant (bool): A flag indicating if the schedule is constant. + """ def __init__(self, start=0, end=1, left=0, period=1, *args, **kwargs): + """ + Initializes an instance of the `ABCPeriod` class. + + Args: + start (float): The start value of the schedule. + end (float): The end value of the schedule. + left (float): The left boundary of the schedule. + period (float): The period of the schedule. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + """ super().__init__() self.left = left self.period = period @@ -182,6 +248,7 @@ def __init__(self, start=0, end=1, left=0, period=1, *args, **kwargs): @classmethod def ratio(cls, cur, left, period, constant=False): + """Returns the ratio of time elapsed in the current period.""" if constant: return 0 if cur < left: @@ -197,6 +264,7 @@ def get_val(cls, cur, start=0, end=1, left=0, right=1, *args, **kwargs): left=left, right=right, *args, **kwargs) def plot(self, num=1000, left=None, n_period=5, show=True): + """Plots the schedule between the specified boundaries.""" if left is None: left = self.left @@ -207,6 +275,7 @@ def plot(self, num=1000, left=None, n_period=5, show=True): return super().plot(num, left, right, show) def __call__(self, cur): + """Returns the current schedule value at the given time.""" return self.interp(cur, start=self.start, end=self.end, left=self.left, period=self.period, constant=self.constant) @@ -216,6 +285,7 @@ class Cos(ABCContinuous): @classmethod def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs): + """Interpolation method for the schedule.""" constant = kwargs.get('constant', False) if constant: return start @@ -235,6 +305,7 @@ class Linear(ABCContinuous): @classmethod def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs): + """Interpolation method for the schedule.""" constant = kwargs.get('constant', False) if constant: return start @@ -253,6 +324,7 @@ class Exp(ABCContinuous): @classmethod def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs): + """Interpolation method for the schedule.""" constant = kwargs.get('constant', False) if constant: return start @@ -274,6 +346,7 @@ class Log(ABCContinuous): @classmethod def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs): + """Interpolation method for the schedule.""" constant = kwargs.get('constant', False) if constant: return start @@ -292,6 +365,8 @@ def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs): class Constant(ABCContinuous): + """A scheduler representing a constant value""" + def __init__(self, value=0.5, *args, **kwargs): super().__init__(start=value, end=value, left=0, right=1, *args, **kwargs) self.constant = True @@ -304,6 +379,7 @@ class PeriodCos(ABCPeriod): @classmethod def interp(cls, cur, start=0., end=1., left=0., period=1., *args, **kwargs): + """Interpolation method for the schedule.""" constant = kwargs.get('constant', False) ratio = cls.ratio(cur, left=left, period=period, constant=constant) @@ -335,6 +411,7 @@ def __init__(self, start=0, end=1, left=0, left_period=1, right_period=1, *args, @classmethod def interp(cls, cur, start=0., end=1., left=0., left_period=0., right_period=1., *args, **kwargs): + """Interpolation method for the schedule.""" constant = kwargs.get('constant', False) ratio = cls.ratio(cur, left=left, period=(left_period + right_period), constant=constant) @@ -355,6 +432,7 @@ class PeriodLinear(ABCPeriod): @classmethod def interp(cls, cur, start=0., end=1., left=0., period=1., *args, **kwargs): + """Interpolation method for the schedule.""" constant = kwargs.get('constant', False) ratio = cls.ratio(cur, left=left, period=period, constant=constant) return start * (1 - ratio) + end * ratio @@ -388,6 +466,7 @@ def __init__(self, start, schedules, gammas): @classmethod def interp(cls, cur, start=0., gammas=None, schedules=None, *args, **kwargs): + """Interpolation method for the schedule.""" if schedules is None: schedules = [] if gammas is None: diff --git a/src/lumo/core/meter.py b/src/lumo/core/meter.py index 87c63a8..44bc21b 100644 --- a/src/lumo/core/meter.py +++ b/src/lumo/core/meter.py @@ -4,7 +4,7 @@ from collections import OrderedDict from collections.abc import ItemsView from numbers import Number -from typing import Union, Iterator, Tuple, Mapping, Sequence +from typing import Iterator, Tuple, Mapping import numpy as np import torch @@ -13,12 +13,44 @@ class Meter: + """ + A class for recording and managing metrics. + + Attributes: + _prop (dict): A dictionary to store properties of the meter. + _rec (OrderedDict): An ordered dictionary to record the metrics and their values. + _avg (dict): A dictionary to store the aggregation method for each metric. + + Methods: + sorted() -> 'Meter': Returns a new meter with the metrics sorted by their names. + todict() -> OrderedDict: Returns the recorded metrics as an ordered dictionary. + update(dic: Mapping) -> 'Meter': Updates the meter with the given dictionary of metrics. + serialize() -> OrderedDict: Returns a dictionary representation of the meter. + items() -> ItemsView: Returns a view object containing the (metric, value) pairs. + keys() -> KeysView: Returns a view object containing the metric names. + scalar_items() -> Iterator[Tuple[str, Number]]: Returns an iterator over the (metric, value) pairs with scalar values. + + Properties: + sum: Sets the aggregation method to 'sum'. + mean: Sets the aggregation method to 'mean'. + last: Sets the aggregation method to 'last'. + max: Sets the aggregation method to 'max'. + min: Sets the aggregation method to 'min'. + smean: Sets the aggregation method to 'smean'. + """ + def __init__(self): self._prop = {} self._rec = OrderedDict() self._avg = {} def sorted(self) -> 'Meter': + """ + Returns a new meter with the metrics sorted by their names. + + Returns: + A new meter with the metrics sorted by their names. + """ m = Meter() m._prop = self._prop @@ -28,6 +60,12 @@ def sorted(self) -> 'Meter': return m def todict(self): + """ + Returns the recorded metrics as an ordered dictionary. + + Returns: + An ordered dictionary containing the recorded metrics and their values. + """ return self._rec @property @@ -39,18 +77,50 @@ def _stage(self, value): self._prop['stage'] = value def __setattr__(self, key: str, value): + """ + Sets the value of an attribute. + + Args: + key (str): The name of the attribute. + value: The value to set the attribute to. + """ if key.startswith('_'): super(Meter, self).__setattr__(key, value) else: self[key] = value def __getattr__(self, item): + """ + Returns the value of a metric. + + Args: + item: The name of the metric. + + Returns: + The value of the metric. + """ return self[item] def __getitem__(self, item): + """ + Returns the value of a metric. + + Args: + item: The name of the metric. + + Returns: + The value of the metric. + """ return self._rec[item] def __setitem__(self, key, value): + """ + Sets the value of a metric. + + Args: + key: The name of the metric. + value: The value to set the metric to. + """ value = to_ndarray(value) stg = self._avg.get(key, None) @@ -83,66 +153,156 @@ def __setitem__(self, key, value): self._stage = 'default' def __repr__(self): + """ + Returns a string representation of the meter. + + Returns: + A string representation of the meter. + """ return ' | '.join([f'{k}: {v}' for k, v in self._rec.items()]) def __iter__(self): + """ + Returns an iterator over the metric names. + + Returns: + An iterator over the metric names. + """ yield from self.keys() @property def sum(self): + """ + Sets the aggregation method to 'sum'. + + Returns: + The meter itself. + """ self._stage = 'sum' return self @property def mean(self): + """ + Sets the aggregation method to 'mean'. + + Returns: + The meter itself. + """ self._stage = 'mean' return self @property def last(self): + """ + Sets the aggregation method to 'last'. + + Returns: + The meter itself. + """ self._stage = 'last' return self @property def max(self): + """ + Sets the aggregation method to 'max'. + + Returns: + The meter itself. + """ self._stage = 'max' return self @property def min(self): + """ + Sets the aggregation method to 'min'. + + Returns: + The meter itself. + """ self._stage = 'min' return self @property def smean(self): + """ + Sets the aggregation method to 'smean'. + + Returns: + The meter itself. + """ self._stage = 'smean' return self def update(self, dic: Mapping) -> 'Meter': + """ + Updates the meter with the given dictionary of metrics. + + Args: + dic (Mapping): A dictionary containing the metrics and their values. + + Returns: + The meter itself. + """ for k, v in dic.items(): self[str(k)] = v return self def serialize(self) -> OrderedDict: + """ + Returns a dictionary representation of the meter. + + Returns: + An ordered dictionary containing the metrics and their string values. + """ res = OrderedDict() for k, v in self.items(): res[k] = f'{v}' return res def items(self) -> ItemsView: + """ + Returns a view object containing the (metric, value) pairs. + + Returns: + A view object containing the (metric, value) pairs. + """ return self._rec.items() def keys(self): + """ + Returns a view object containing the metric names. + + Returns: + A view object containing the metric names. + """ return self._rec.keys() @staticmethod def from_dict(dic: Mapping): + """ + Returns a new meter with the given dictionary of metrics. + + Args: + dic (Mapping): A dictionary containing the metrics and their values. + + Returns: + A new meter with the given metrics and values. + """ m = Meter() for k, v in dic.items(): m[k] = v return m def scalar_items(self) -> Iterator[Tuple[str, Number]]: + """ + Returns an iterator over the (metric, value) pairs with scalar values. + + Returns: + An iterator over the (metric, value) pairs with scalar values. + """ for k, v in self.items(): nd = to_ndarray(v) if is_scalar(nd): @@ -150,10 +310,46 @@ def scalar_items(self) -> Iterator[Tuple[str, Number]]: class ReduceItem: + """Class that reduces a sequence of values to a single value according to a given method. + + Attributes: + SLIDE_WINDOW_SIZE (int): The size of the sliding window used for averaging (default: 100). + EXP_WEIGHT (float): The exponential weight used for computing the sliding window offset (default: 0.75). + + Args: + item (optional): The initial value (default: None). + gb_method (optional): The reduction method (default: None). Can be one of {'slide', 'mean', 'sum', 'max', 'min', 'last'}. + + Raises: + AssertionError: If the reduction method is 'min' or 'max' and the input is not a scalar. + + Methods: + __repr__(): Returns a string representation of the current reduction value. + __str__(): Returns the same string representation as __repr__(). + update(item): Updates the reduction value with a new item. + res: Returns the current reduction value. + + Examples: + >>> r = ReduceItem(gb_method='mean') + >>> r.update(2) + >>> r.update(3) + >>> r.res + 2.5 + """ SLIDE_WINDOW_SIZE = 100 EXP_WEIGHT = 0.75 def __init__(self, item=None, gb_method=None): + """ + Initializes a new ReduceItem instance. + + Args: + item (optional): The initial value (default: None). + gb_method (optional): The reduction method (default: None). Can be one of {'slide', 'mean', 'sum', 'max', 'min', 'last'}. + + Raises: + AssertionError: If the reduction method is 'min' or 'max' and the input is not a scalar. + """ self.gb_method = gb_method # groupby method self.acc = [] if item is not None: @@ -161,11 +357,11 @@ def __init__(self, item=None, gb_method=None): self.c = len(self.acc) self.cur = item if gb_method == 'max': - self.last = -1e12 + self._last = -1e12 elif gb_method == 'min': - self.last = 1e12 + self._last = 1e12 else: - self.last = 0 + self._last = 0 self._res = self.last @@ -180,10 +376,10 @@ def __init__(self, item=None, gb_method=None): def __repr__(self): """ - simpler but more time-comsuming method could be some math function, not in if-else branch, like - prec = max(min(8, int(np.ceil(np.log10((1 / (self.offset + 1e-10)))))), 1) - fmt_str = f'{{:.{prec}f}}' - return fmt_str.format(res) + Returns a string representation of the current reduction value. + + Returns: + str: The string representation of the current reduction value. """ res = self.res if self.isscalar: @@ -207,6 +403,7 @@ def __repr__(self): __str__ = __repr__ def update(self, item): + """Updates the reduction value with a new item.""" self.cur = item item = detach(item) @@ -218,7 +415,7 @@ def update(self, item): self.acc.append(item) if len(self.acc) > ReduceItem.SLIDE_WINDOW_SIZE: self.acc.pop(0) - self.last = self.cur + self._last = self.cur elif avg in {'mean', 'sum'}: if len(self.acc) == 0: self.acc.append(0) @@ -229,10 +426,20 @@ def update(self, item): elif avg == 'min': self._res = min(self.cur, self._res) - self.last = item + self._last = item + + @property + def last(self): + return self._last @property def res(self): + """ + Returns the current reduction value. + + Returns: + float: The current reduction value. + """ avg = self.gb_method if avg == 'slide': return np.mean(self.acc) diff --git a/src/lumo/core/params.py b/src/lumo/core/params.py index a14863a..c4fd4df 100644 --- a/src/lumo/core/params.py +++ b/src/lumo/core/params.py @@ -2,12 +2,12 @@ import os.path import sys import textwrap +from pathlib import Path from pprint import pformat from typing import Any, List, NewType import fire from joblib import hash -from pathlib import Path from omegaconf import DictConfig, OmegaConf, DictKeyType from omegaconf._utils import _ensure_container @@ -57,21 +57,34 @@ def __repr__(self): def _safe_repr(values: Any) -> str: + """Return a formatted string representation of the input values. + + Args: + values: Any type of input values to be formatted. + + Returns: + A string representation of the input values, formatted using `pprint`. + + Raises: + None. + """ return pformat(values) def _padding_mod(st: str, offset=7, mod=4): - """ - 123 \\ - 1 \\ - 12312341 \\ - 1231 + """Pads a string with spaces to a length that is a multiple of a given modulus. + Args: - strs: - mod: + st: The input string to pad. + offset: An integer specifying the minimum length of the output string. If the length of the input string is + less than this value, spaces will be added to the end of the string to make it the desired length. + mod: An integer specifying the modulus. The length of the output string will be a multiple of this value. Returns: - + A string that is a multiple of the given modulus and has a length of at least `offset`. If the length of the + input string is less than `offset`, the output string will be padded with spaces to achieve the minimum length. + If the length of the input string is already a multiple of the given modulus, the output string will have the + same length as the input string. """ size = len(st) if size < offset: @@ -104,11 +117,29 @@ def safe_param_repr(values: List[tuple], level=1) -> str: class BaseParams(DictConfig): + """ + A dictionary-like configuration object that supports parameter constraint validation. + """ + def __init__(self): + """ + Initializes a new instance of the BaseParams class. + """ super().__init__({}, flags={'no_deepcopy_set_nodes': True}) self.__dict__["_prop"] = {} def __setattr__(self, key: str, value: Any) -> None: + """ + Sets an attribute value for the specified key. + + Args: + key (str): The key of the attribute. + value (Any): The value of the attribute. + + Raises: + BoundCheckError: If the specified value is not within the specified bounds or choices. + + """ if key != '_prop': if isinstance(value, (Arange, Choices)): res = self._prop.get('constrain', {}) @@ -122,6 +153,17 @@ def __setattr__(self, key: str, value: Any) -> None: super().__setattr__(key, value) def __setitem__(self, key: DictKeyType, value: Any) -> None: + """ + Sets a dictionary item value for the specified key. + + Args: + key (DictKeyType): The key of the item. + value (Any): The value of the item. + + Raises: + BoundCheckError: If the specified value is not within the specified bounds or choices. + + """ if key != '_prop': if isinstance(value, (Arange, Choices)): self._prop.setdefault('constrain', {})[key] = value @@ -133,10 +175,31 @@ def __setitem__(self, key: DictKeyType, value: Any) -> None: super().__setitem__(key, value) def __getattr__(self, key: str) -> Any: + """ + Gets an attribute value for the specified key. + + Args: + key (str): The key of the attribute. + + Returns: + Any: The value of the attribute. + + """ res = super().__getattr__(key) return res def _check(self, name, value): + """ + Checks if the specified parameter value is within the specified bounds or choices. + + Args: + name (str): The name of the parameter. + value (Any): The value of the parameter. + + Raises: + BoundCheckError: If the specified value is not within the specified bounds or choices. + + """ bound = self._prop['constrain'][name] if isinstance(bound, Arange) and not (bound.left <= value and value <= bound.right): raise BoundCheckError( @@ -145,10 +208,29 @@ def _check(self, name, value): raise BoundCheckError(f"value of param '{name}' should in values {bound.choices}, but got {value}") def __getitem__(self, key: DictKeyType) -> Any: + """ + Gets a dictionary item value for the specified key. + + Args: + key (DictKeyType): The key of the item. + + Returns: + Any: The value of the item. + + """ return super().__getitem__(key) def __repr__(self): + """ + Returns a string representation of the BaseParams object. + + Returns: + str: A string representation of the BaseParams object. + + """ + def _arg_to_str(k, v): + """to str""" res = self._prop.get('constrain', {}).get(k, None) if res is not None: return f'{res}, {type(v).__name__}' @@ -167,6 +249,13 @@ def _arg_to_str(k, v): return "{}.Space".format(self.__class__.__name__) + '(\n' + args_str + '\n)' def copy(self): + """ + Returns a copy of the BaseParams object. + + Returns: + BaseParams: A copy of the BaseParams object. + + """ copied = self.__class__() copied.from_dict(super(BaseParams, self).copy()) return copied diff --git a/src/lumo/trainer/trainer.py b/src/lumo/trainer/trainer.py index 97d2eea..5663a3e 100644 --- a/src/lumo/trainer/trainer.py +++ b/src/lumo/trainer/trainer.py @@ -256,76 +256,167 @@ def eidx(self): @property def global_steps(self) -> int: - # started from 0 + """started from 0""" return self._prop['global_steps'] @property - def trainer_state(self): + def trainer_state(self) -> Any: + """ + Get the state of the Trainer object. + + Returns: + Any: The state of the Trainer object. + """ return self._prop @property def devices(self) -> Dict[str, torch.device]: - # return self._state_dicts['devices'] + """ + Get the dictionary of devices used in the training session. + + Returns: + Dict[str, torch.device]: A dictionary containing the devices used in the training session. + """ return {key: self[key] for key in self._state_dicts['devices']} @property def model_dict(self) -> Dict[str, nn.Module]: - return {key: self[key] - for key in self._state_dicts['models']} + """ + Get the dictionary of model objects used in the training session. + + Returns: + Dict[str, nn.Module]: A dictionary containing the model objects used in the training session. + """ + return {key: self[key] for key in self._state_dicts['models']} @property def optim_dict(self) -> Dict[str, Optimizer]: + """ + Get the dictionary of optimizer objects used in the training session. + + Returns: + Dict[str, Optimizer]: A dictionary containing the optimizer objects used in the training session. + """ return {key: self[key] for key in self._state_dicts['optims']} @property def torch_tensor(self) -> Dict[str, torch.Tensor]: + """ + Get the dictionary of PyTorch tensor objects used in the training session. + + Returns: + Dict[str, torch.Tensor]: A dictionary containing the PyTorch tensor objects used in the training session. + """ return {key: self[key] for key in self._state_dicts['tensor.th']} @property def numpy_tensor(self) -> Dict[str, np.ndarray]: + """ + Get the dictionary of NumPy array objects used in the training session. + + Returns: + Dict[str, np.ndarray]: A dictionary containing the NumPy array objects used in the training session. + """ return {key: self[key] for key in self._state_dicts['tensor.np']} @property def others(self) -> Dict[str, Any]: + """ + A dictionary of additional attributes stored in the Trainer. + + Returns: + Dict[str, Any]: The dictionary of additional attributes. + """ return {key: self[key] for key in self._state_dicts['others']} @property def datamodule(self) -> DataModule: + """ + Returns the DataModule associated with this Trainer. + + Returns: + DataModule: The DataModule associated with this Trainer. + """ return self.dm @property def train_dataloader(self) -> Optional[DataLoaderType]: + """ + Returns the DataLoader for the training data. + + Returns: + Optional[DataLoaderType]: The DataLoader for the training data, or None if it is not available. + """ return self.datamodule['train'] @property def test_dataloader(self) -> Optional[DataLoaderType]: + """ + Returns the DataLoader for the test data. + + Returns: + Optional[DataLoaderType]: The DataLoader for the test data, or None if it is not available. + """ return self.datamodule['test'] @property def val_dataloader(self) -> Optional[DataLoaderType]: + """ + Returns the DataLoader for the validation data. + + Returns: + Optional[DataLoaderType]: The DataLoader for the validation data, or None if it is not available. + """ return self.datamodule['val'] @property def device(self): + """ + Returns the device used for training. + + Returns: + The device used for training. + """ return self.accelerate.device def _load_fun_state_dict(self, src: dict): + """ + Loads state dicts into the Trainer's attributes. + + Args: + src (dict): A dictionary of state dicts to be loaded. + """ for k, v in src.items(): if self._rev_index.get(k, None) is not None: self[k].load_state_dict(v) - # if k in src: - # v.load_state_dict(src[k]) def regist_dataloader(self, dataloader: DataLoader, stage: TrainStage): + """ + Registers a dataloader with a given training stage to the current datamodule. + + Args: + dataloader (DataLoader): The dataloader to be registered. + stage (TrainStage): The training stage to which the dataloader will be associated. + + Returns: + None + """ self.datamodule.regist_dataloader_with_stage(stage, dataloader) def process_loader(self, dm: Union[DataModule, DataLoader] = None, stage: TrainStage = TrainStage.train): """ - automatically called before train()/test()/evaluate(), see __new__ function of Trainer - :param dm: - :param stage: - :return: + Prepares and registers a dataloader with the given training stage to the current datamodule. + + Args: + dm (Union[DataModule, DataLoader], optional): The datamodule or dataloader to be processed. If not provided, + the current datamodule will be used if it exists. + stage (TrainStage, optional): The training stage to which the dataloader will be associated. Defaults to TrainStage.train. + + Returns: + DataLoader: The prepared and registered dataloader. + None: If the dataloader cannot be prepared or registered. """ + assert stage is not None, '`stage` cannot be None' if dm is None and self.dm is not None: dm = self.dm @@ -349,6 +440,19 @@ def process_loader(self, dm: Union[DataModule, DataLoader] = None, stage: TrainS return loader def save_state_dict(self, name='latest.pth', dirpath=None, only_main=True): + """ + Saves the current state dictionary to a file. + + Args: + name: The name of the file to save the state dictionary to. Defaults to 'latest.pth'. + dirpath: The directory path to save the state dictionary file to. If None, defaults to the state dictionary + directory of the Trainer's experiment. + only_main: If True, saves the state dictionary to a single file. If False and the Trainer is distributed, + saves the state dictionary to multiple files, one for each process. + + Returns: + The path to the saved state dictionary file. + """ if not only_main and self.is_dist: pre, ext = os.path.splitext(name) name = f'{pre}-{self.local_rank}{ext}' @@ -361,6 +465,13 @@ def save_state_dict(self, name='latest.pth', dirpath=None, only_main=True): return fn def load_state_dict(self, state_dict: dict): + """Load state dictionary from a given dictionary. + Args: + state_dict (dict): A dictionary containing the state dictionary to be loaded. + + Returns: + None + """ _sub = {'models', 'optims', 'other'} _missing = [] @@ -374,7 +485,7 @@ def load_state_dict(self, state_dict: dict): def to_device(self, item: Optional[Union[nn.Module, torch.Tensor, Sequence, Mapping]] = None, device: torch.device = None): - + """Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.""" if item is None: for k, v in list(self.model_dict.items()): self[k] = self.accelerate.prepare(v) @@ -395,11 +506,20 @@ def on_trainer_exception(self, func: Callable, exception: BaseException): @property def is_initialized(self): + """Whether this Trainer is initialized.""" if self._prop.get('initial', False): return True return False def initialize(self): + """ + Initializes the Trainer object, update meta information in Experiment and TableRow. + + If the Trainer object is already initialized, this method does nothing. + + This function is auto called when start train()/test()/evaluate() + """ + if self.is_initialized: return self.exp.start() @@ -427,10 +547,12 @@ def initialize(self): self.set_property('initial', True) def stop_train(self): + """Toggle to stop train.""" self.train_toggle = True self.train_epoch_toggle = True def stop_train_epoch(self): + """Toggle to skip current train epoch.""" self.train_epoch_toggle = True def prepare_dataloader(self, loader: DataLoaderType, stage: TrainStage = None): @@ -458,6 +580,21 @@ def prepare_dataloader(self, loader: DataLoaderType, stage: TrainStage = None): return loader def train(self, dm: Union[DataModule, DataLoaderType] = None, params: ParamsType = None, limit_global_steps=None): + """Trains the model using the specified data loader and parameters. + + Args: + dm (Union[DataModule, DataLoaderType], optional): The data loader or data module to use for training. + Defaults to self.train_dataloader. + params (ParamsType, optional): The training parameters to use. Defaults to None. + limit_global_steps (int, optional): The maximum number of global steps to train for. Defaults to None. + + Returns: + Dict[str, Any]: A dictionary of training results. + + Raises: + ValueError: If no data loader is available for training. + + """ loader = self.select_loader(dm) if not loader: loader = self.train_dataloader @@ -499,6 +636,18 @@ def train(self, dm: Union[DataModule, DataLoaderType] = None, params: ParamsType def train_epoch(self, loader: DataLoaderType, params: ParamsType = None, limit_step=None, limit_global_steps=None) -> Record: + """Trains the model for one epoch using the specified data loader and parameters. + + Args: + loader (DataLoaderType): The data loader to use for training. + params (ParamsType, optional): The training parameters to use. Defaults to None. + limit_step (int, optional): The maximum number of steps to train for. Defaults to None. + limit_global_steps (int, optional): The maximum number of global steps to train for. Defaults to None. + + Returns: + Record: A record of training results for the epoch. + + """ stage = TrainStage.train self.change_stage(stage) record = self.create_record(stage=stage) @@ -528,16 +677,53 @@ def train_epoch(self, loader: DataLoaderType, params: ParamsType = None, self.database.update_dict(dict(eidx=self.eidx, end=datetime.now())) return record - def set_property(self, key, value): + def set_property(self, key: str, value: any) -> None: + """ + Sets a property with the given key to the given value. + + Args: + key: A string representing the name of the property. + value: The value to assign to the property. + + Returns: + None + """ self._prop[key] = value - def set_global_steps(self, val): + def set_global_steps(self, val: int) -> None: + """ + Sets the global step count to the given value. + + Args: + val: An integer representing the global step count. + + Returns: + None + """ self.set_property('global_steps', val) - def set_epoch_idx(self, val): + def set_epoch_idx(self, val: int) -> None: + """ + Sets the current epoch index to the given value. + + Args: + val: An integer representing the current epoch index. + + Returns: + None + """ self.set_property('eidx', val) - def set_idx(self, val): + def set_idx(self, val: int) -> None: + """ + Sets the current index to the given value. + + Args: + val: An integer representing the current index. + + Returns: + None + """ self.set_property('idx', val) @property @@ -610,6 +796,16 @@ def select_loader(cls, dm=None): return loader def test(self, dm: Union[DataModule, DataLoader] = None, params: ParamsType = None, limit_step=None): + """ + Tests the model on a given dataset and returns a `Record` object containing the evaluation results. + + Args: + dm (Union[DataModule, DataLoader], optional): A `DataModule` or `DataLoader` object for the dataset to test on. + params (ParamsType, optional): A dictionary containing hyperparameters for the test. + limit_step (int, optional): An integer specifying the maximum number of batches to test on. + Returns: + A `Record` object containing the evaluation results. + """ stage = TrainStage.test self.change_stage(stage) @@ -637,6 +833,15 @@ def test(self, dm: Union[DataModule, DataLoader] = None, params: ParamsType = No return record def evaluate(self, dm: Union[DataModule, DataLoader] = None, params: ParamsType = None, limit_step: int = None): + """ + Evaluates the model on a given dataset and returns a `Record` object containing the evaluation results. + Args: + dm (Union[DataModule, DataLoader], optional): A `DataModule` or `DataLoader` object for the dataset to evaluate on. + params (ParamsType, optional): A dictionary containing hyperparameters for the evaluation. + limit_step (int, optional): An integer specifying the maximum number of batches to evaluate on. + Returns: + A `Record` object containing the evaluation results. + """ stage = TrainStage.val self.change_stage(stage) @@ -662,45 +867,77 @@ def evaluate(self, dm: Union[DataModule, DataLoader] = None, params: ParamsType return record def train_step(self, batch, params: ParamsType = None) -> MetricType: + """ + Runs a single training step on a batch of data and returns a dictionary of training metrics. + Args: + batch: A batch of data to train on. + params (ParamsType, optional): A dictionary containing hyperparameters for the training step. + Returns: + A dictionary of training metrics. + """ pass def test_step(self, batch, params: ParamsType = None) -> MetricType: + """ + Runs a single testing step on a batch of data and returns a dictionary of evaluation metrics. + Args: + batch: A batch of data to test on. + params (ParamsType, optional): A dictionary containing hyperparameters for the testing step. + Returns: + A dictionary of evaluation metrics. + """ pass def evaluate_step(self, batch, params: ParamsType = None) -> MetricType: + """ + Runs a single evaluation step on a batch of data and returns a dictionary of evaluation metrics. + Args: + batch: A batch of data to evaluate on. + params (ParamsType, optional): A dictionary containing hyperparameters for the evaluation step. + Returns: + A dictionary of evaluation metrics. + """ pass def imodels(self, params: ParamsType): + """Initialize model in here""" pass def icallbacks(self, params: ParamsType): + """Initialize callbacks in here""" pass def inference(self, batch): + """Perform inference on a batch of data.""" raise NotImplementedError() def predict(self, batch): + """Make a prediction on a batch of data.""" raise NotImplementedError() def optim_state_dict(self, wrap=True): + """Get a dictionary of the state of the optimizers.""" res = {k: v.state_dict() for k, v in self.optim_dict.items()} if wrap: res = {'optim': res} return res def model_state_dict(self, wrap=True): + """Get a dictionary of the state of the models.""" res = {k: self.accelerate.unwrap_model(v).state_dict() for k, v in self.model_dict.items()} if wrap: res = {'model': res} return res def other_state_dict(self, wrap=True): + """Get a dictionary of the state of the other objects.""" res = {k: v.state_dict() for k, v in self.others.items()} if wrap: res = {'other': res} return res def state_dict(self): + """Get a dictionary of the state of the object.""" res = { 'optims': self.optim_state_dict(wrap=False), 'models': self.model_state_dict(wrap=False), @@ -713,9 +950,25 @@ def state_dict(self): return res def Meter(self): + """ + Returns a new instance of the Meter class. + + Returns: + Meter: A new instance of the Meter class. + """ return Meter() def create_record(self, stage: TrainStage = None): + """ + Creates a new Record object with the specified TrainStage. + + Args: + stage (TrainStage, optional): The TrainStage to use for the new Record object. If not provided, the TrainStage + from the Trainer object will be used. + + Returns: + Record: A new Record object with the specified TrainStage. + """ if stage is None: stage = self.trainstage record = Record(stage=stage) @@ -723,11 +976,22 @@ def create_record(self, stage: TrainStage = None): def wait_for_everyone(self): """ - making sure all processes have reached this point before continuing. + Will stop the execution of the current process until every other process has reached that point """ self.accelerate.wait_for_everyone() def save_model(self, is_best=False, meta_info: Union[str, dict] = None): + """ + Saves the current model. + + Args: + is_best (bool, optional): Indicates whether the current model is the best one so far. Defaults to False. + meta_info (Union[str, dict], optional): Additional information to include in the saved model file. Can be a + string, a dictionary, or a Meter object. Defaults to None. + + Returns: + str: The path to the saved model file. + """ info = self._build_trainer_meta_info(meta_info) val = self.saver.save_model(self.eidx, self.model_state_dict(), meta_info=info, @@ -736,6 +1000,16 @@ def save_model(self, is_best=False, meta_info: Union[str, dict] = None): return val def _build_trainer_meta_info(self, meta_info: Union[str, dict] = None): + """ + Builds a dictionary containing metadata about the Trainer object. + + Args: + meta_info (Union[str, dict], optional): Additional metadata to include in the dictionary. Can be a string, a + dictionary, or a Meter object. Defaults to None. + + Returns: + dict: A dictionary containing metadata about the Trainer object. + """ info = dict() info['eidx'] = self.eidx if meta_info is not None: @@ -748,6 +1022,18 @@ def _build_trainer_meta_info(self, meta_info: Union[str, dict] = None): return info def save_checkpoint(self, max_keep=10, is_best=False, meta_info: Union[str, dict, Meter] = None): + """ + Saves a checkpoint of the current state of the Trainer object. + + Args: + max_keep (int, optional): The maximum number of checkpoints to keep. Defaults to 10. + is_best (bool, optional): Indicates whether the current checkpoint is the best one so far. Defaults to False. + meta_info (Union[str, dict, Meter], optional): Additional information to include in the saved checkpoint file. + Can be a string, a dictionary, or a Meter object. Defaults to None. + + Returns: + str: The path to the saved checkpoint file. + """ info = self._build_trainer_meta_info(meta_info) val = self.saver.save_checkpoint(self.eidx, self.state_dict(), meta_info=info, From fbda3575cf7bdd5132dd36c916336be0245dd3c6 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 15:49:34 +0800 Subject: [PATCH 35/99] Add doc coverage test --- .docstr.yaml | 23 +++++++++++++++++++++++ src/lumo/vis/__init__.py | 0 2 files changed, 23 insertions(+) create mode 100644 .docstr.yaml delete mode 100644 src/lumo/vis/__init__.py diff --git a/.docstr.yaml b/.docstr.yaml new file mode 100644 index 0000000..1108d91 --- /dev/null +++ b/.docstr.yaml @@ -0,0 +1,23 @@ +paths: ./src/lumo/ +badge: ./images # Path +exclude: + +verbose: 3 # int (0-4) +skip_magic: True # Boolean +skip_file_doc: True # Boolean +skip_property: True +skip_init: True # Boolean +#skip_class_def: True # Boolean +skip_private: True # Boolean +follow_links: True # Boolean +accept_empty: True # Boolean +ignore_names_file: .*/test # regex +#fail_under: 90 # int +#percentage_only: True # Boolean +ignore_patterns: + .*: + - "on.*end" + - "on.*begin" + - "on.*begin" + - "on_first_exception" + - "on_hook.*" \ No newline at end of file diff --git a/src/lumo/vis/__init__.py b/src/lumo/vis/__init__.py deleted file mode 100644 index e69de29..0000000 From 7ca1a93e5316283dcbab3a0960c5b3bdf3ddb3f3 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 15:49:49 +0800 Subject: [PATCH 36/99] Doc coverage rate up to 65.1% --- src/lumo/sketch/vis/__init__.py | 0 src/lumo/{ => sketch}/vis/__main__.py | 0 src/lumo/{ => sketch}/vis/parser.py | 0 src/lumo/{ => sketch}/vis/parser_tb.py | 0 src/lumo/trainer/factory.py | 55 +++++-- src/lumo/trainer/saver.py | 105 ++++++++++++++ src/lumo/trainer/trainer.py | 39 ++++- src/lumo/utils/exithook.py | 4 + src/lumo/utils/logger.py | 51 +++++++ src/lumo/utils/memory_grab.py | 191 +++++++++++++++++++++---- src/lumo/utils/repository.py | 59 +++++--- src/lumo/utils/safe_io.py | 56 +++++++- src/lumo/utils/screen.py | 46 +++++- 13 files changed, 541 insertions(+), 65 deletions(-) create mode 100644 src/lumo/sketch/vis/__init__.py rename src/lumo/{ => sketch}/vis/__main__.py (100%) rename src/lumo/{ => sketch}/vis/parser.py (100%) rename src/lumo/{ => sketch}/vis/parser_tb.py (100%) diff --git a/src/lumo/sketch/vis/__init__.py b/src/lumo/sketch/vis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/lumo/vis/__main__.py b/src/lumo/sketch/vis/__main__.py similarity index 100% rename from src/lumo/vis/__main__.py rename to src/lumo/sketch/vis/__main__.py diff --git a/src/lumo/vis/parser.py b/src/lumo/sketch/vis/parser.py similarity index 100% rename from src/lumo/vis/parser.py rename to src/lumo/sketch/vis/parser.py diff --git a/src/lumo/vis/parser_tb.py b/src/lumo/sketch/vis/parser_tb.py similarity index 100% rename from src/lumo/vis/parser_tb.py rename to src/lumo/sketch/vis/parser_tb.py diff --git a/src/lumo/trainer/factory.py b/src/lumo/trainer/factory.py index e64c481..939f057 100644 --- a/src/lumo/trainer/factory.py +++ b/src/lumo/trainer/factory.py @@ -40,14 +40,44 @@ class InterpFactory: class OptimBuilder(BaseParams): + """A class for building an optimizer with specified parameters. + + Attributes: + None + + Methods: + from_kwargs(cls, **kwargs): Creates a new instance of OptimBuilder class and updates its attributes with the given keyword arguments. + build(self, parameters, optim_cls=None) -> Optimizer: Builds and returns an optimizer with the specified parameters. + + """ @classmethod def from_kwargs(cls, **kwargs): + """Creates a new instance of OptimBuilder class and updates its attributes with the given keyword arguments. + + Args: + **kwargs: A dictionary containing the optimizer parameters. + + Returns: + self: A new instance of the OptimBuilder class with updated attributes. + """ self = cls() self.update(kwargs) return self def build(self, parameters, optim_cls=None) -> Optimizer: + """Builds and returns an optimizer with the specified parameters. + + Args: + parameters: The parameters for the optimizer. + optim_cls: The class of the optimizer to be built. + + Returns: + optim_cls: The built optimizer. + + Raises: + ModuleNotFoundError: If the specified optimizer class cannot be found in the corresponding module. + """ res = self.copy() name = res['name'] lname = name.lower() @@ -60,6 +90,8 @@ def build(self, parameters, optim_cls=None) -> Optimizer: else: optim_lib = importlib.import_module("torch.optim.{}".format(lname)) optim_cls = getattr(optim_lib, name, None) + if optim_cls is None: + raise ModuleNotFoundError("Cannot find {} in {}".format(name, optim_lib)) return optim_cls(parameters, **args) @@ -84,56 +116,57 @@ class _OptimFactory: @overload def create_optim(self, name='SGD', lr=None, momentum=0, dampening=0, weight_decay=0, nesterov=False) -> OptimBuilder: - pass + """Creates an instance of OptimBuilder for the SGD optimizer.""" @overload def create_optim(self, name='Adam', lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False) -> OptimBuilder: - pass + """Creates an instance of OptimBuilder for the Adam optimizer.""" @overload def create_optim(self, name='Adadelta', lr=1.0, rho=0.9, eps=1e-6, weight_decay=0) -> OptimBuilder: - pass + """Creates an instance of OptimBuilder for the Adadelta optimizer.""" @overload def create_optim(self, name='Adagrad', lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10) -> OptimBuilder: - pass + """Creates an instance of OptimBuilder for the Adagrad optimizer.""" @overload def create_optim(self, name='AdamW', lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0) -> OptimBuilder: - pass + """Creates an instance of OptimBuilder for the AdamW optimizer.""" @overload def create_optim(self, name='AdamW', lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False) -> OptimBuilder: - pass + """Creates an instance of OptimBuilder for the AdamW optimizer.""" @overload def create_optim(self, name='ASGD', lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0) -> OptimBuilder: - pass + """Creates an instance of OptimBuilder for the ASGD optimizer.""" @overload def create_optim(self, name='LBFGS', lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-7, tolerance_change=1e-9, history_size=100, line_search_fn=None) -> OptimBuilder: - pass + """Creates an instance of OptimBuilder for the LBFGS optimizer.""" @overload def create_optim(self, name='RMSprop', lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False) -> OptimBuilder: - pass + """Creates an instance of OptimBuilder for the RMSprop optimizer.""" @overload def create_optim(self, name='Rprop', lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)) -> OptimBuilder: - pass + """Creates an instance of OptimBuilder for the Rprop optimizer.""" @overload def create_optim(self, name='SparseAdam', lr=1e-3, betas=(0.9, 0.999), eps=1e-8) -> OptimBuilder: - pass + """Creates an instance of OptimBuilder for the SparseAdam optimizer.""" def create_optim(self, name=None, **kwargs) -> OptimBuilder: + """Create.""" return OptimBuilder.from_kwargs(name=name, **kwargs) diff --git a/src/lumo/trainer/saver.py b/src/lumo/trainer/saver.py index 3a62cff..1088e03 100644 --- a/src/lumo/trainer/saver.py +++ b/src/lumo/trainer/saver.py @@ -9,11 +9,61 @@ # state_dict_tuple = namedtuple('state_dict_tuplt', ['state_dict', 'meta_info'], defaults=[None]) class state_dict_tuple: + """ + A class that stores a state dictionary and corresponding meta information. + + Args: + state_dict (dict, optional): The state dictionary to be stored. Defaults to None. + meta_info (Any, optional): Any meta information to be stored. Defaults to None. + + Attributes: + state_dict (dict): The stored state dictionary. + meta_info (Any): The stored meta information. + + Returns: + A state_dict_tuple instance. + + Raises: + IndexError: If the index is not 0 or 1. + + Examples: + # Create an instance of state_dict_tuple with state_dict and meta_info + >>> sd = state_dict_tuple({'a': 1, 'b': 2}, 'meta') + + # Access the state_dict and meta_info using the [] operator + >>> sd[0] + {'a': 1, 'b': 2} + >>> sd[1] + 'meta' + """ + def __init__(self, state_dict=None, meta_info=None): self.state_dict = state_dict self.meta_info = meta_info def __getitem__(self, item): + """ + Get the stored state dictionary or meta information. + + Args: + item (int): Index of the desired item. + + Returns: + The state dictionary (if item is 0) or the meta information (if item is 1). + + Raises: + IndexError: If the index is not 0 or 1. + + Examples: + # Access the state_dict and meta_info using the [] operator + >>> sd = state_dict_tuple({'a': 1, 'b': 2}, 'meta') + >>> sd[0] + {'a': 1, 'b': 2} + >>> sd[1] + 'meta' + >>> sd[2] # Raises IndexError + IndexError: 2 + """ if item == 0: return self.state_dict elif item == 1: @@ -226,6 +276,20 @@ def save_model(self, step: int, state_dict, meta_info: Union[str, dict] = None, return None def load_checkpoint(self, index=-1, best_if_exist=False, fn=None, with_meta=False, map_location='cpu'): + """ + Loads a checkpoint file. + + Args: + index (int, optional): Index of the checkpoint file in the list of checkpoints. Defaults to -1. + best_if_exist (bool, optional): If True, load the best checkpoint file. Defaults to False. + fn (str, optional): The filename of the checkpoint file. Defaults to None. + with_meta (bool, optional): If True, the checkpoint file is expected to contain metadata. Defaults to False. + map_location (str, optional): Where to load the checkpoint file. Defaults to 'cpu'. + + Returns: + Union[None, Any]: None if the checkpoint file could not be loaded, otherwise the loaded checkpoint. + + """ if fn is None and best_if_exist: fn = self.best_checkpoint() if fn is None: @@ -238,6 +302,19 @@ def load_checkpoint(self, index=-1, best_if_exist=False, fn=None, with_meta=Fals return None def load_keypoint(self, index=-1, fn=None, with_meta=False, map_location='cpu'): + """ + Loads a checkpoint file that is key. + + Args: + index (int, optional): Index of the keypoint file in the list of keypoints. Defaults to -1. + fn (str, optional): The filename of the keypoint file. Defaults to None. + with_meta (bool, optional): If True, the keypoint file is expected to contain metadata. Defaults to False. + map_location (str, optional): Where to load the keypoint file. Defaults to 'cpu'. + + Returns: + Union[None, Any]: None if the keypoint file could not be loaded, otherwise the loaded keypoint. + + """ if fn is None: try: fn = self.list_keypoints()[index] @@ -248,6 +325,20 @@ def load_keypoint(self, index=-1, fn=None, with_meta=False, map_location='cpu'): return None def load_model(self, index=-1, best_if_exist=False, fn=None, with_meta=False, map_location='cpu'): + """ + Loads a model file. + + Args: + index (int, optional): Index of the model file in the list of models. Defaults to -1. + best_if_exist (bool, optional): If True, load the best model file. Defaults to False. + fn (str, optional): The filename of the model file. Defaults to None. + with_meta (bool, optional): If True, the model file is expected to contain metadata. Defaults to False. + map_location (str, optional): Where to load the model file. Defaults to 'cpu'. + + Returns: + Union[None, Any]: None if the model file could not be loaded, otherwise the loaded model. + + """ if fn is None and best_if_exist: fn = self.best_model() if fn is None: @@ -260,38 +351,52 @@ def load_model(self, index=-1, best_if_exist=False, fn=None, with_meta=False, ma return None def best_checkpoint(self): + """ + Returns the filename of the best checkpoint file if it exists, otherwise None. + + Returns: + Union[None, str]: The filename of the best checkpoint file if it exists, otherwise None. + + """ fn = os.path.join(self.save_dir, 'best.checkpoint.pt') if os.path.exists(fn): return fn return None def best_model(self): + """Returns the filename of the best model file if it exists, otherwise None.""" fn = os.path.join(self.save_dir, 'best.model.pt') if os.path.exists(fn): return fn return None def _is_pkl(self, x: str, start, end): + """Helper function that returns True if the string x starts with start and ends with end.""" return x.startswith(start) and x.endswith(end) def list_checkpoints(self) -> List[str]: + """Returns a sorted list of filenames of all checkpoint files in the save directory.""" return sorted(list(filter(lambda x: self._is_pkl(x, 'checkpoints', 'pt'), os.listdir(self.save_dir))), key=lambda x: os.stat(os.path.join(self.save_dir, x)).st_ctime) def list_keypoints(self) -> List[str]: + """Returns a sorted list of filenames of all keypoint files in the save directory.""" return sorted(list(filter(lambda x: self._is_pkl(x, 'key', 'pt'), os.listdir(self.save_dir))), key=lambda x: os.stat(os.path.join(self.save_dir, x)).st_ctime) def list_models(self) -> List[str]: + """Returns a sorted list of filenames of all model files in the save directory.""" return sorted(list(filter(lambda x: self._is_pkl(x, 'model', 'pt'), os.listdir(self.save_dir))), key=lambda x: os.stat(os.path.join(self.save_dir, x)).st_ctime) def list(self): + """Returns a sorted list of all filenames in the save directory.""" return sorted(os.listdir(self.save_dir)) def summary(self): + """Prints a summary of the contents of the save directory.""" print(f"saved in : {self.save_dir}") print(textwrap.indent('\n'.join(self.list()), ' - ')) diff --git a/src/lumo/trainer/trainer.py b/src/lumo/trainer/trainer.py index 5663a3e..63c0662 100644 --- a/src/lumo/trainer/trainer.py +++ b/src/lumo/trainer/trainer.py @@ -220,6 +220,7 @@ def writer(self): res = SummaryWriter(**kwargs) def close(*args): + """close writer""" res.flush() res.close() @@ -498,6 +499,7 @@ def to_device(self, item: Optional[Union[nn.Module, torch.Tensor, Sequence, Mapp return item def on_trainer_exception(self, func: Callable, exception: BaseException): + """Updates database with error information when an exception occurs during training.""" self.database.update_dict(dict(end=datetime.now(), finished=False, error=str(exception), @@ -731,13 +733,23 @@ def trainstage(self) -> TrainStage: return self._prop.get('stage', TrainStage.default) def set_stage(self, val: TrainStage): + """ + Sets the training stage to the given value. + + Args: + val (TrainStage): The value to set the training stage to. + """ self.set_property('stage', val) def add_callback(self, callback): """ - 添加一个回调函数,注意,不能添加重复的 callback,这不推荐,也没有必要。 - :param callback: - :return: + Adds a callback function. Note that duplicate callbacks are not recommended and not necessary. + + Args: + callback: The callback function to add. + + Returns: + bool: True if the callback was added successfully, False otherwise. """ msg = None cb_name = callback.__class__.__name__ @@ -767,10 +779,22 @@ def add_callback(self, callback): return True def remove_callback(self, cur): + """ + Removes the given callback from the list of callbacks. + + Args: + cur: The callback to remove. + """ self.callbacks.remove(cur) pass def change_stage(self, stage: TrainStage): + """ + Changes the training stage to the given value. + + Args: + stage (TrainStage): The value to change the training stage to. + """ if self.trainstage == stage: return @@ -785,6 +809,15 @@ def change_stage(self, stage: TrainStage): @classmethod def select_loader(cls, dm=None): + """ + Selects the appropriate loader based on the given data module. + + Args: + dm (DataModule or DataLoader or DataLoaderSide, optional): The data module to use. Defaults to None. + + Returns: + DataLoader or None: The appropriate loader based on the given data module, or None if dm is None. + """ loader = None if dm: if isinstance(dm, DataModule): diff --git a/src/lumo/utils/exithook.py b/src/lumo/utils/exithook.py index 2c906df..4f6f29c 100644 --- a/src/lumo/utils/exithook.py +++ b/src/lumo/utils/exithook.py @@ -15,8 +15,10 @@ def wrap_after(func): old = sys.excepthook def outer(fun): + """wrap function""" @wraps(fun) def inner(*args, **kwargs): + """wrap function""" old(*args, **kwargs) fun(*args, **kwargs) @@ -30,8 +32,10 @@ def wrap_before(func): old = sys.excepthook def outer(fun): + """wrap function""" @wraps(fun) def inner(*args, **kwargs): + """wrap function""" fun(*args, **kwargs) old(*args, **kwargs) diff --git a/src/lumo/utils/logger.py b/src/lumo/utils/logger.py index f590c35..04c3e4f 100644 --- a/src/lumo/utils/logger.py +++ b/src/lumo/utils/logger.py @@ -29,6 +29,12 @@ def _get_print_func(): + """ + Returns the `rich.print` function if available, or the built-in `print` function if not. + + Returns: + callable: The `rich.print` function if available, or the built-in `print` function if not. + """ try: from rich import print except ImportError: @@ -37,6 +43,12 @@ def _get_print_func(): def get_global_logger(): + """ + Returns the global logger object, creating it if it does not exist. + + Returns: + logging.Logger: The global logger object. + """ global logger if logger is None: logger = Logger() @@ -44,18 +56,39 @@ def get_global_logger(): def set_global_logger(logger_): + """ + Sets the global logger object to the specified logger instance. + + Args: + logger_ (logging.Logger, optional): The logger instance to set as the global logger. If `None`, + a new logger instance will be created. Defaults to `None`. + + Returns: + logging.Logger: The global logger object. + """ global logger logger = logger_ return logger def process_str(): + """ + Returns a string representing the current process, suitable for use in log messages or other output. + + Returns: + str: A string representing the current process. If running in a distributed context (as determined + by `is_dist()`), the string will be of the form '[local_rank]', where 'local_rank' is the local + rank of the current process. Otherwise, an empty string will be returned. + """ if is_dist(): return f'[{local_rank()}]' return '' class Logger: + """A logger adapted for deep learning training experiments that can print output without a newline and + adds process index in distributed training. + """ VERBOSE = 20 VVV_DEBUG = VVV_DEBUG VV_DEBUG = VV_DEBUG @@ -230,6 +263,24 @@ def print(self, *args, end='\n', file=sys.stdout): self._print_func(*args, end=end, flush=True, file=file) def print_rich(self, *args, end='\n', file=sys.stdout): + """ + Prints the specified arguments to the given file object, using rich formatting if available. + + Args: + *args: The arguments to be printed. + end: The string to append at the end of the printed output (default: '\n'). + file: The file object to which the output should be directed (default: sys.stdout). + + Returns: + None + + Raises: + N/A + + Notes: + - If self.use_stdout is True and rich formatting is available, the output will be formatted using rich. + - If self.use_stdout is True and rich formatting is not available, the output will be printed using the default print function. + """ if self.use_stdout: if self._try_rich: print = _get_print_func() diff --git a/src/lumo/utils/memory_grab.py b/src/lumo/utils/memory_grab.py index 8457229..a49c31d 100644 --- a/src/lumo/utils/memory_grab.py +++ b/src/lumo/utils/memory_grab.py @@ -16,10 +16,22 @@ class DeviceMem: + """ + A class that represents device memory usage. + """ def __init__(self): self.line_mem = tree() def _parse_device_pid_mem_pair(self, lines): + """ + Parses device ID, process ID, and memory usage (in MiB) from the given list of strings. + + Args: + lines (List[str]): List of strings to parse. + + Yields: + Tuple[int, int, int]: Tuple containing device ID, process ID, and memory usage (in MiB) for each match found. + """ for lid, line in enumerate(lines): res = re.search(match_mem, line) if res is not None: @@ -28,12 +40,18 @@ def _parse_device_pid_mem_pair(self, lines): yield _device, _pid, _mib def try_parse(self, lines, pid, device): - """ try parse mem from cached lid directly. - Returns: - -1 means failed. - others means successd and its memory. + """ + Attempts to parse memory usage (in MiB) for a process running on a specific device using the cached line ID. + Args: + lines (List[str]): List of strings to parse. + pid (int): Process ID to look for. + device (int or str or torch.device): Device ID to look for. + + Returns: + int: Memory usage in MiB for the specified process and device if found, -1 otherwise. """ + lid = self.line_mem[device][pid] if isinstance(lid, dict): return -1 @@ -51,6 +69,17 @@ def try_parse(self, lines, pid, device): return -1 def re_parse(self, lines, pid, device): + """ + Parses memory usage (in MiB) for a process running on a specific device by searching through the list of strings. + + Args: + lines (List[str]): List of strings to parse. + pid (int): Process ID to look for. + device (int or str or torch.device): Device ID to look for. + + Returns: + int: Memory usage in MiB for the specified process and device if found, 0 otherwise. + """ _res = self.try_parse(lines, pid, device) if _res != -1: return _res @@ -62,11 +91,27 @@ def re_parse(self, lines, pid, device): return 0 def _get_nvidia_smi(self): + """ + Executes the 'nvidia-smi' command and returns the output as a list of strings. + + Returns: + List[str]: List of strings representing the output of the 'nvidia-smi' command. + """ proc = subprocess.Popen(['nvidia-smi'], stdout=subprocess.PIPE) lines = proc.stdout.readlines() return [i.decode() for i in lines] def _device_equal(self, da, db): + """ + Compares two device IDs or names for equality. + + Args: + da (int or str or torch.device): First device ID or name to compare. + db (int or str or torch.device): Second device ID or name to compare. + + Returns: + bool: True if the two device IDs or names are equal, False otherwise. + """ if isinstance(da, (int, str)): da = torch.device(da) if isinstance(db, (int, str)): @@ -74,7 +119,15 @@ def _device_equal(self, da, db): return da == db def get_device_release_mem(self, device): - """ get device memory left.""" + """ + Returns the amount of free memory (in MiB) on a specified device. + + Args: + device (int or str or torch.device): Device ID or name to look up. + + Returns: + int: Amount of free memory (in MiB) on the specified device. + """ s_pid = os.getpid() total = self.get_device_mem(device) for _device, _pid, _mib in self._parse_device_pid_mem_pair(self._get_nvidia_smi()): @@ -84,15 +137,27 @@ def get_device_release_mem(self, device): return total def get_device_mem(self, device): - """ returns device total memory(unit: MB) """ + """ + Returns the total amount of memory (in MiB) on a specified device. + + Args: + device (int or str or torch.device): Device ID or name to look up. + + Returns: + int: Total amount of memory (in MiB) on the specified device. + """ return torch.cuda.get_device_properties(device).total_memory // (1024 * 1024) def get_pid_device_mem(self, pid, device): """ - 尽可能有效率的得到进程在某设备下占用的显存(通过命令行程序调用获取) - :param pid: - :param device: - :return: + Attempts to obtain the memory usage (in MiB) for a specific process running on a specific device. + + Args: + pid (int): Process ID to look up. + device (int or str or torch.device): Device ID or name to look up. + + Returns: + int: Memory usage in MiB for the specified process and device if found, -1 otherwise. """ if isinstance(device, torch.device): device = device.index @@ -115,17 +180,20 @@ def get_pid_device_mem(self, pid, device): class memory(object): - r""" - 优雅的抢卡 + """ + A graceful memory allocator that optimizes GPU memory allocation by incrementally increasing the memory + footprint to minimize fragmentation. + Args: - memory: 需要占用的内存,以 MB 为单位 - device: 需要占用内存的设备 - hold: - unit: - Example:: + memory (int): Memory size to be allocated in MB. + device (str or int, optional): Device to allocate memory on. Defaults to the current CUDA device. + hold (bool, optional): Whether to hold the memory after allocation. Defaults to False. + invade (bool, optional): Whether to use aggressive memory allocation. Defaults to False. + + Examples: >>> import lumo >>> with lumo.memory(5000): - ... y = x * 2 + ... y = x * 2 >>> @lumo.memory(1024) ... def doubler(x): @@ -134,11 +202,21 @@ class memory(object): >>> lumo.memory(10000).start() ... # do something - Why use nvidia-smi to get memory useage? see: + References: + To get GPU memory usage, we use nvidia-smi. Refer to this link for details: https://github.com/pytorch/pytorch/issues/12873 """ def __init__(self, memory, device=None, hold=False, invade=False) -> None: + """ + Initialize the memory allocator. + + Args: + memory (int): Memory size to be allocated in MB. + device (str or int, optional): Device to allocate memory on. Defaults to the current CUDA device. + hold (bool, optional): Whether to hold the memory after allocation. Defaults to False. + invade (bool, optional): Whether to use aggressive memory allocation. Defaults to False. + """ super().__init__() if device is None: device = torch.cuda.current_device() @@ -155,6 +233,14 @@ def __init__(self, memory, device=None, hold=False, invade=False) -> None: self.last_success = _memer.get_pid_device_mem(_pid, self.device) def copy(self, pid: int, wait: bool = True): + """ + Copy memory allocation parameters from another process. + + Args: + pid (int): Process ID to copy memory parameters from. + wait (bool, optional): Whether to wait until the other process has finished before starting allocation. + Defaults to True. + """ self.need = _memer.get_pid_device_mem(pid, self.device) if wait: self.wait(pid) @@ -162,16 +248,25 @@ def copy(self, pid: int, wait: bool = True): self.start() def wait(self, pid): + """ + Wait for the other process to finish before starting allocation. + + Args: + pid (int): Process ID to wait for. + """ while _memer.get_pid_device_mem(pid, self.device) > 0: time.sleep(0.5) self.start() def immediately(self, pre_init=False): """ - 等待,直到内存有空间后,开始申请相应显存,优雅,礼貌,推荐 + Wait until there is enough memory available, then allocate the necessary memory. + This is the recommended way to allocate memory. + Args: - pre_init: 是否初始化 CUDA(这将在一开始消耗一定显存),默认为 False,即不抢占任何内存, - 直到设备释放足够空间后开始抢占。 + pre_init (bool, optional): Whether to initialize CUDA (which will consume a certain amount of memory) + before allocating. Defaults to False, meaning memory will only be allocated after enough memory is + released by other processes. """ while True: _left = _memer.get_device_release_mem(self.device) @@ -202,7 +297,16 @@ def immediately(self, pre_init=False): _left), end='\r') def _malloc(self, size, init=False): - """ unit: mb """ + """ + Allocate memory of the given size (in MB) on the specified device. + + Args: + size (int): Memory size to be allocated, in MB. + init (bool, optional): Whether to initialize CUDA. Defaults to False. + + Returns: + bool: True if the memory allocation was successful, False otherwise. + """ try: tmp = torch.rand(size, 1048576 // 4, device=self.device) if not init: @@ -212,17 +316,33 @@ def _malloc(self, size, init=False): return False def end(self): + """ + Release allocated memory and empty the CUDA cache. + """ del self.mem[:] torch.cuda.empty_cache() def start(self, immediately=True): + """ + Start memory allocation. + + Args: + immediately (bool, optional): Whether to use the recommended memory allocation method. + Defaults to True. + """ if immediately: self.immediately() else: self.invade() def invade(self, unit=5): - """一点一点的侵占,有多少占用多少,直到申请满为止,比较粗鲁,不友好,不推荐""" + """ + Aggressively allocate memory, increasing the memory footprint until the necessary amount is allocated. + This method is not recommended. + + Args: + unit (int, optional): Incremental size of memory allocation (in MB). Defaults to 5. + """ try: while self.last_success < self.need: res = self._malloc(unit + self.acc) @@ -243,15 +363,35 @@ def invade(self, unit=5): print('\nabort.') def __enter__(self): + """ + Start memory allocation when entering the 'with' block. + """ self.start(immediately=not self.is_invade) def __exit__(self, *args): + """ + Release memory and empty the CUDA cache when exiting the 'with' block. + + Returns: + bool: Always returns True. + """ self.end() return True def __call__(self, func): + """ + Decorator to use with functions that require memory allocation. + + Args: + func (callable): The function to decorate. + + Returns: + callable: The decorated function. + """ + @functools.wraps(func) def decorate_no_grad(*args, **kwargs): + """decorate""" with self: return func(*args, **kwargs) @@ -259,6 +399,9 @@ def decorate_no_grad(*args, **kwargs): @staticmethod def hold_current(): + """ + Hold the currently allocated memory. + """ count = torch.cuda.device_count() mems = [_memer.get_pid_device_mem(_pid, i) for i in range(count)] for i, mem in enumerate(mems): diff --git a/src/lumo/utils/repository.py b/src/lumo/utils/repository.py index 0bebbef..01546de 100644 --- a/src/lumo/utils/repository.py +++ b/src/lumo/utils/repository.py @@ -8,6 +8,7 @@ import git from git import Repo, Commit from joblib import hash + from .filelock import Lock @@ -184,23 +185,6 @@ def git_commit(repo=None, key=None, branch_name=None, info: str = None, filter_f return commit_ -def git_checkout(repo=None, commit_hex=None, commit: Commit = None): - if repo is None: - repo = load_repo() - - if commit is None and commit_hex is not None: - commit = repo.commit(commit_hex) - - old_path = os.getcwd() - os.chdir(commit.tree.abspath) - - # with branch(commit.repo, LUMO_BRANCH) as new_branch: - repo.git.checkout('-b', commit.hexsha[:8], commit.hexsha) - - os.chdir(old_path) - return commit.hexsha[:8] - - def git_archive(repo=None, commit_hex=None, commit: Commit = None): """ git archive -o @@ -231,8 +215,49 @@ def git_archive(repo=None, commit_hex=None, commit: Commit = None): return exp +def git_checkout(repo=None, commit_hex=None, commit: Commit = None): + """ + Checkout a specific commit in a Git repository. + + Args: + repo (git.Repo, optional): The Git repository to use. Defaults to None, in which case the repository is loaded using `load_repo()`. + commit_hex (str, optional): The hash of the commit to check out. Defaults to None. + commit (git.Commit, optional): The commit object to check out. Defaults to None. + + Returns: + str: The abbreviated hash of the checked-out commit. + + Raises: + git.InvalidGitRepositoryError: If the specified repository is invalid or not found. + git.BadName: If the specified branch name is invalid or not found. + """ + if repo is None: + repo = load_repo() + + if commit is None and commit_hex is not None: + commit = repo.commit(commit_hex) + + old_path = os.getcwd() + os.chdir(commit.tree.abspath) + + # with branch(commit.repo, LUMO_BRANCH) as new_branch: + repo.git.checkout('-b', commit.hexsha[:8], commit.hexsha) + + os.chdir(old_path) + return commit.hexsha[:8] + + @lru_cache(1) def git_enable(): + """ + Check if Git is installed and a repository is present. + + Returns: + bool: True if Git is installed and a repository is present, False otherwise. + + Raises: + ImportError: If the `gitpython` library is not installed. + """ try: import git except ImportError: diff --git a/src/lumo/utils/safe_io.py b/src/lumo/utils/safe_io.py index c0d4684..f589b46 100644 --- a/src/lumo/utils/safe_io.py +++ b/src/lumo/utils/safe_io.py @@ -48,6 +48,7 @@ def dump_yaml(obj, fn): def dump_state_dict(obj, fn): + """Saves a PyTorch state dictionary object to disk.""" torch.save(obj, fn) @@ -65,6 +66,7 @@ def load_yaml(fn): def load_state_dict(fn: str, map_location='cpu'): + """Loads a PyTorch model checkpoint from the specified file path and returns its state dictionary.""" ckpt = torch.load(fn, map_location=map_location) return ckpt @@ -76,8 +78,17 @@ def load_text(fn): with open(fn, 'r', encoding='utf-8') as r: return ''.join(r.readlines()) - def dump_text(string: str, fn, append=False): + """Write the given string to a file. + + Args: + string (str): The string to write. + fn (str): The filename to write to. + append (bool, optional): If True, append the string to the file. Otherwise, overwrite the file. Defaults to False. + + Returns: + str: The filename that was written to. + """ mode = 'w' if append: mode = 'a' @@ -87,6 +98,16 @@ def dump_text(string: str, fn, append=False): def safe_getattr(self, key, default=None): + """Get an attribute of an object safely. + + Args: + self (object): The object to get the attribute from. + key (str): The name of the attribute to get. + default (object, optional): The value to return if the attribute does not exist. Defaults to None. + + Returns: + object: The value of the attribute if it exists, otherwise the default value. + """ try: return getattr(self, key, default) except: @@ -94,6 +115,19 @@ def safe_getattr(self, key, default=None): def dump_pkl(obj, file, make_path=True, protocol=None, *, fix_imports=True): + """Save an object to a pickle file. + + Args: + obj (object): The object to save. + file (str or FileIO): The filename or file object to save to. + make_path (bool, optional): If True and file is a filename, create the directory path if it does not exist. Defaults to True. + protocol (int, optional): The pickle protocol to use. Defaults to None. + fix_imports (bool, optional): Whether to fix Python 2 to 3 pickle incompatibilities. Defaults to True. + + Raises: + NotImplementedError: If the file type is not supported. + + """ if isinstance(file, str): if make_path: os.makedirs(os.path.dirname(os.path.abspath(file)), exist_ok=True) @@ -103,10 +137,25 @@ def dump_pkl(obj, file, make_path=True, protocol=None, *, fix_imports=True): elif isinstance(file, FileIO): _pickle.dump(obj, file, protocol=protocol, fix_imports=fix_imports) else: - raise NotImplementedError() + raise NotImplementedError("File type not supported.") def load_pkl(file, *, fix_imports=True, encoding="ASCII", errors="strict"): + """Load an object from a pickle file. + + Args: + file (str or FileIO): The filename or file object to load from. + fix_imports (bool, optional): Whether to fix Python 2 to 3 pickle incompatibilities. Defaults to True. + encoding (str, optional): The character encoding to use. Defaults to "ASCII". + errors (str, optional): The error handling scheme to use. Defaults to "strict". + + Returns: + object: The object that was loaded from the file. + + Raises: + NotImplementedError: If the file type is not supported. + + """ if isinstance(file, str): file = open(file, 'rb') res = _pickle.load(file, fix_imports=fix_imports, encoding=encoding, errors=errors) @@ -115,8 +164,7 @@ def load_pkl(file, *, fix_imports=True, encoding="ASCII", errors="strict"): elif isinstance(file, FileIO): return _pickle.load(file, fix_imports=fix_imports, encoding=encoding, errors=errors) else: - raise NotImplementedError() - + raise NotImplementedError("File type not supported.") @contextmanager def cached(fn): diff --git a/src/lumo/utils/screen.py b/src/lumo/utils/screen.py index 8c01622..a994a98 100644 --- a/src/lumo/utils/screen.py +++ b/src/lumo/utils/screen.py @@ -52,10 +52,15 @@ def _is_jupyter() -> bool: # pragma: no cover class ScreenStr: """ - A ScreenStr start with '\r' won't overflow, any string outside the screen width will be cut. + A class representing a string that can be displayed on a console screen. + + Attributes: + content (str): The string content to be displayed on the screen. + leftoffset (int): The number of characters to shift the content to the left. Notes: - If output consolo support multiline(like pycharm or jupyter notebook) return, all string will be represented. + A ScreenStr starting with '\r' will not overflow, and any string longer than the screen width will be cut off. + If the console supports multiline output (like PyCharm or Jupyter notebook), all the string will be represented. """ t = 0 dt = 0.7 @@ -70,23 +75,33 @@ class ScreenStr: multi_mode = support_multiline() def __init__(self, content="", leftoffset=0) -> None: + """Initializes a new instance of the ScreenStr class.""" self.content = content ScreenStr.left = leftoffset def __repr__(self) -> str: + """Returns the string representation of the ScreenStr object.""" if ScreenStr.multi_mode: return self.content return self._screen_str() + def __len__(self) -> int: + """Returns the length of the string content.""" + txt = self.content.encode("gbk", errors='ignore') + return len(txt) + def tostr(self): + """Returns the string content.""" return self.content @classmethod def set_speed(cls, dt: float = 0.05): + """Sets the speed of the text scrolling animation.""" cls.dt = dt @classmethod def deltatime(cls): + """Calculates the time elapsed since the last update.""" if cls.last == 0: cls.last = time.time() return 0 @@ -98,6 +113,7 @@ def deltatime(cls): @classmethod def cacu_offset_(cls, out_width): + """Calculates the offset for scrolling the text.""" delta = cls.deltatime() cls.t += delta * cls.dt @@ -115,11 +131,8 @@ def cacu_offset_(cls, out_width): a = 1 - def __len__(self) -> int: - txt = self.content.encode("gbk", errors='ignore') - return len(txt) - def _decode_sub(self, txt, left, right): + """Decodes a part of a byte string to a Unicode string.""" try: txt = txt[left:right].decode("gbk", errors='ignore') except: @@ -135,11 +148,13 @@ def _decode_sub(self, txt, left, right): @staticmethod def consolo_width(): + """Returns the width of the console.""" width = get_consolo_width() return width @staticmethod def split(txt, len): + """Splits a string into two parts.""" try: return txt[:len], txt[len:] except: @@ -149,6 +164,7 @@ def split(txt, len): return txt[:len - 1], txt[len - 1:] def _screen_str(self, margin="..."): + """Returns the string content formatted for display on the screen.""" width = self.consolo_width() txt = self.content.encode("gbk", errors='ignore').strip() @@ -187,7 +203,25 @@ class inlinetqdm(tqdm): """ def full_str(self): + """ + Returns a formatted string representing the full progress bar, including the progress bar itself and any additional information (such as elapsed time or estimated remaining time). + + Args: + None + + Returns: + str: A formatted string representing the full progress bar. + """ return self.format_meter(**self.format_dict) def __str__(self): + """ + Overrides the `__str__` method of the `tqdm` class to display the progress bar as a single-line string. + + Args: + None + + Returns: + str: A single-line string representation of the progress bar. + """ return ScreenStr(self.full_str())._screen_str() From 7ad1f4a2a4bde15eafef266f137fc8b13c28e5d7 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 15:50:19 +0800 Subject: [PATCH 37/99] Add docstr_coverage badge --- images/docstr_coverage_badge.svg | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 images/docstr_coverage_badge.svg diff --git a/images/docstr_coverage_badge.svg b/images/docstr_coverage_badge.svg new file mode 100644 index 0000000..d8ec4de --- /dev/null +++ b/images/docstr_coverage_badge.svg @@ -0,0 +1,20 @@ + + + + + + + + + + + + + + + docstr-coverage + docstr-coverage + 65% + 65% + + \ No newline at end of file From 778e1864887a3c0f28d826334399b8e07f271ef8 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 15:50:22 +0800 Subject: [PATCH 38/99] Add docstr_coverage badge --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 13fa000..31492ea 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [![PyPI version](https://badge.fury.io/py/lumo.svg)](https://badge.fury.io/py/lumo) ![Python-Test](https://github.com/pytorch-lumo/lumo/actions/workflows/python-test.yml/badge.svg) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lightning/blob/master/LICENSE) +[Python-doc](./images/docstr_coverage_badge.svg) `lumo` is a light-weight library to help construct your experiment code, record your experiment results, especially in the field of deep learning. From 830b4a11c7c0425228808957f8c0bd3b146c5571 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 15:51:14 +0800 Subject: [PATCH 39/99] Add docstr_coverage badge --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 31492ea..8d3e851 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![PyPI version](https://badge.fury.io/py/lumo.svg)](https://badge.fury.io/py/lumo) ![Python-Test](https://github.com/pytorch-lumo/lumo/actions/workflows/python-test.yml/badge.svg) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lightning/blob/master/LICENSE) -[Python-doc](./images/docstr_coverage_badge.svg) +![Python-doc](./images/docstr_coverage_badge.svg) `lumo` is a light-weight library to help construct your experiment code, record your experiment results, especially in the field of deep learning. From 767ea77e1be584f1704566c80d71bddbc1113050 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 18:39:16 +0800 Subject: [PATCH 40/99] Toggle default assert_type to False --- src/lumo/core/attr.py | 4 ++-- src/lumo/core/params.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lumo/core/attr.py b/src/lumo/core/attr.py index bf6c8c5..6f1c8cd 100644 --- a/src/lumo/core/attr.py +++ b/src/lumo/core/attr.py @@ -36,7 +36,7 @@ def __getitem__(self, key): return get_item_iterative(self, key.split('.')) -def safe_update_dict(src: dict, kwargs: dict, assert_type=True): +def safe_update_dict(src: dict, kwargs: dict, assert_type=False): """ Updates the source dictionary with the key-value pairs from the kwargs dictionary in a safe manner. @@ -56,7 +56,7 @@ def safe_update_dict(src: dict, kwargs: dict, assert_type=True): for ks, v in walk_dict(kwargs): try: old_v = get_item_iterative(src, ks) - if old_v is None or isinstance(old_v, type(v)): + if old_v is None or isinstance(old_v, type(v)) or not assert_type: set_item_iterative(src, ks, v) else: raise TypeError(ks, type(old_v), type(v)) diff --git a/src/lumo/core/params.py b/src/lumo/core/params.py index c4fd4df..0f6839c 100644 --- a/src/lumo/core/params.py +++ b/src/lumo/core/params.py @@ -307,7 +307,7 @@ def choice(self, *choices) -> Choices: """ return Choices(choices[0], choices) - def safe_update(self, dic, assert_type=True): + def safe_update(self, dic, assert_type=False): """ Merge `dict` object into the config object, safely updating the values. @@ -357,7 +357,7 @@ def from_json(self, file): Returns: updated `self` object """ - self.safe_update(json.loads(Path(file).read_text()), assert_type=True) + self.safe_update(json.loads(Path(file).read_text()), assert_type=False) return self def from_yaml(self, file): @@ -370,7 +370,7 @@ def from_yaml(self, file): Returns: updated `self` object """ - self.safe_update(dict(OmegaConf.load(file)), assert_type=True) + self.safe_update(dict(OmegaConf.load(file)), assert_type=False) return self def from_args(self, argv: list = None): @@ -386,7 +386,7 @@ def from_args(self, argv: list = None): if argv is None: argv = sys.argv - def func(**kwargs): + def func(*args, **kwargs): if 'help' in kwargs: print(self) exit() From ab1a78ac5f58dce1bd2b19352275fd1edc28192f Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 19:08:44 +0800 Subject: [PATCH 41/99] Fix problem of update in Params --- src/lumo/core/attr.py | 2 +- src/lumo/core/params.py | 125 ++++++++++++++++++++++++++++++++-- tests/core/test_params.py | 4 +- tests/trainer/test_trainer.py | 27 +++++--- 4 files changed, 140 insertions(+), 18 deletions(-) diff --git a/src/lumo/core/attr.py b/src/lumo/core/attr.py index 6f1c8cd..325cf5c 100644 --- a/src/lumo/core/attr.py +++ b/src/lumo/core/attr.py @@ -122,7 +122,7 @@ def set_item_iterative(dic: dict, keys: List[str], value): raise ValueError(keys[0], nex) # dict.__setitem__(dic, keys[0], nex) except KeyError: - nex = dict() + nex = {} dict.__setitem__(dic, keys[0], nex) set_item_iterative(nex, keys[1:], value) diff --git a/src/lumo/core/params.py b/src/lumo/core/params.py index 0f6839c..7414b83 100644 --- a/src/lumo/core/params.py +++ b/src/lumo/core/params.py @@ -4,14 +4,14 @@ import textwrap from pathlib import Path from pprint import pformat -from typing import Any, List, NewType +from typing import Any, List, NewType, MutableMapping import fire from joblib import hash from omegaconf import DictConfig, OmegaConf, DictKeyType from omegaconf._utils import _ensure_container -from .attr import safe_update_dict, set_item_iterative +# from .attr import safe_update_dict, set_item_iterative from .raises import BoundCheckError # arange_param = namedtuple('arange_param', ['default', 'left', 'right'], defaults=[None, float('-inf'), float('inf')]) @@ -20,6 +20,123 @@ __all__ = ['BaseParams', 'Params', 'ParamsType'] +def safe_update_dict(src: dict, kwargs: dict, assert_type=False): + """ + Updates the source dictionary with the key-value pairs from the kwargs dictionary in a safe manner. + + This function iterates over the items in the kwargs dictionary and updates the corresponding items in the + source dictionary, making sure that the types of the values being updated match the types of the values + already in the source dictionary. + + Args: + src (dict): The dictionary to update. + kwargs (dict): The dictionary containing the new key-value pairs to add to the source dictionary. + assert_type (bool): A flag indicating whether to check that the types of the values being updated match + the types of the values already in the source dictionary. Defaults to True. + + Returns: + dict: The updated source dictionary. + """ + for ks, v in walk_dict(kwargs): + try: + old_v = get_item_iterative(src, ks) + if old_v is None or isinstance(old_v, type(v)) or not assert_type: + set_item_iterative(src, ks, v) + else: + raise TypeError(ks, type(old_v), type(v)) + except KeyError: + set_item_iterative(src, ks, v) + return src + + +def walk_dict(dic: dict, root=None): + """ + Recursively walks through a dictionary and yields keys and values in a flattened format. + + Args: + - dic (dict): The dictionary to be walked through. + - root (list): The root keys to be used in the resulting flattened format. Defaults to None. + + Yields: + - A tuple containing a list of keys and a value. The list of keys is composed of the root keys and the current keys in the dictionary, split by '.' if there are any. The value is the corresponding value in the dictionary. + + Example: + ```python + d = {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3} + for k, v in walk_dict(d): + print(k, v) + # Output: + # (['a', 'b'], 1) + # (['a', 'c', 'd'], 2) + # (['e'], 3) + ``` + """ + if root is None: + root = [] + for k, v in dic.items(): + if isinstance(v, dict): + yield from walk_dict(v, [*root, *k.split('.')]) + else: + yield [*root, *k.split('.')], v + + +def set_item_iterative(dic: dict, keys: List[str], value): + """ + Sets the value of a nested key in a dictionary using an iterative approach. + + Args: + dic (dict): The dictionary to update. + keys (List[str]): A list of keys representing the path to the nested key in the dictionary. + value: The value to set for the nested key. + + Raises: + ValueError: If a key in the path exists in the dictionary but the corresponding value is not a dictionary. + + """ + if len(keys) == 1: + if isinstance(value, MutableMapping): + for ks, v in walk_dict(value): + set_item_iterative(dic, [*keys, *ks], v) + else: + dic.__setitem__(keys[0], value) + else: + try: + nex = dic.__getitem__(keys[0]) + if not isinstance(nex, MutableMapping): + raise ValueError(keys[0], nex) + # dict.__setitem__(dic, keys[0], nex) + except KeyError: + nex = BaseParams() + dic.__setitem__(keys[0], nex) + + set_item_iterative(nex, keys[1:], value) + + +def get_item_iterative(dic: dict, keys: List[str]): + """ + Gets the value of a nested key in a dictionary using an iterative approach. + + Args: + dic (dict): The dictionary to retrieve the value from. + keys (List[str]): A list of keys representing the path to the nested key in the dictionary. + + Raises: + KeyError: If the nested key does not exist in the dictionary. + + Returns: + The value of the nested key in the dictionary. + + """ + if len(keys) == 1: + return dic.__getitem__(keys[0]) + else: + nex = dic.__getitem__(keys[0]) + if isinstance(nex, dict): + return get_item_iterative(nex, keys[1:]) + else: + raise KeyError(keys) + + class Arange: """A class representing a range of numeric values with a default, left and right boundaries. @@ -318,9 +435,7 @@ def safe_update(self, dic, assert_type=False): Returns: None """ - self.update( - safe_update_dict(self.to_dict(), dic, assert_type=assert_type) - ) + safe_update_dict(self, dic, assert_type=assert_type) def from_dict(self, dic: dict): """ diff --git a/tests/core/test_params.py b/tests/core/test_params.py index 5510a78..3b31237 100644 --- a/tests/core/test_params.py +++ b/tests/core/test_params.py @@ -1,9 +1,10 @@ import json import tempfile + from omegaconf import DictConfig -from lumo.core.raises import BoundCheckError from lumo import BaseParams +from lumo.core.raises import BoundCheckError def test_params(): @@ -49,6 +50,7 @@ def get_res(): def test_argv(): params = get_res() params.from_args(['--a', '1', '--d.c.d=2']) + print(params) assert params.a == 1 assert params.d.c.d == 2 assert isinstance(params.kk, DictConfig) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index bcb7b4e..fec64ca 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1,22 +1,18 @@ -from typing import Union, Optional, Sequence, Mapping, Any - -import numpy as np - -from lumo.proc.config import debug_mode -from lumo.utils.repository import git_dir import os +from typing import Union, Optional, Sequence, Mapping -import tempfile - +import numpy as np import torch from torch import nn from torch.utils.data import DataLoader from lumo import ParamsType, TrainerParams -from lumo import Trainer, DataModule, Meter, TrainStage, MetricType, Record, DatasetBuilder +from lumo import Trainer, DataModule, TrainStage, MetricType, Record, DatasetBuilder from lumo.data.loader import DataLoaderType -from lumo.trainer import callbacks from lumo.proc import glob +from lumo.proc.config import debug_mode +from lumo.trainer import callbacks +from lumo.utils.repository import git_dir def create_dataset_builder(): @@ -171,6 +167,16 @@ def test_trainer(): raise AssertionError(str(trainer.callback_function - trainer.lf.functions)) +def test_trainer_params(): + params = TrainerParams() + params.optim = params.OPTIM.create_optim('SGD', lr=0.9) + params.optim.lr = 3 + print(type(params.optim)) + print(params.optim) + module = nn.Linear(10, 10) + optim = params.optim.build(module.parameters()) + + def test_trainer_state_dict(): trainer = Trainer(TrainerParams()) device_a = trainer.device_a = torch.device('cpu') @@ -193,7 +199,6 @@ def test_trainer_state_dict(): trainer.module = nn.Linear(10, 10) trainer.optim_a = TrainerParams.OPTIM.create_optim('SGD', lr=0.9).build(trainer.module.parameters()) - trainer.load_state_dict(torch.load(fn, map_location='cpu')) assert state_dict['optims']['optim_a'] == optim_a.state_dict() assert all([(i == j).all() From ab000b259e9ef05a6fa691e27dc54c9522e760da Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 19:11:06 +0800 Subject: [PATCH 42/99] Fix problem of update in Params --- tests/core/test_params.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/core/test_params.py b/tests/core/test_params.py index 3b31237..0c4dc37 100644 --- a/tests/core/test_params.py +++ b/tests/core/test_params.py @@ -1,8 +1,6 @@ import json import tempfile -from omegaconf import DictConfig - from lumo import BaseParams from lumo.core.raises import BoundCheckError @@ -49,12 +47,13 @@ def get_res(): def test_argv(): params = get_res() - params.from_args(['--a', '1', '--d.c.d=2']) + params.from_args(['--a', '1', '--d.c.d=2', '--kk.c=3']) print(params) assert params.a == 1 assert params.d.c.d == 2 - assert isinstance(params.kk, DictConfig) - assert isinstance(params.d.c, DictConfig) + assert isinstance(params.kk, MyParams) + assert params.kk.c == 3 + assert isinstance(params.d.c, BaseParams) def test_dict(): From ad5a70571af78328a6b9513b8b8b8ea2a1257c34 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 19:20:43 +0800 Subject: [PATCH 43/99] Fix problem of update in Params --- src/lumo/core/params.py | 246 ++++++++++++++++++++-------------------- 1 file changed, 124 insertions(+), 122 deletions(-) diff --git a/src/lumo/core/params.py b/src/lumo/core/params.py index 7414b83..25c682f 100644 --- a/src/lumo/core/params.py +++ b/src/lumo/core/params.py @@ -20,123 +20,6 @@ __all__ = ['BaseParams', 'Params', 'ParamsType'] -def safe_update_dict(src: dict, kwargs: dict, assert_type=False): - """ - Updates the source dictionary with the key-value pairs from the kwargs dictionary in a safe manner. - - This function iterates over the items in the kwargs dictionary and updates the corresponding items in the - source dictionary, making sure that the types of the values being updated match the types of the values - already in the source dictionary. - - Args: - src (dict): The dictionary to update. - kwargs (dict): The dictionary containing the new key-value pairs to add to the source dictionary. - assert_type (bool): A flag indicating whether to check that the types of the values being updated match - the types of the values already in the source dictionary. Defaults to True. - - Returns: - dict: The updated source dictionary. - """ - for ks, v in walk_dict(kwargs): - try: - old_v = get_item_iterative(src, ks) - if old_v is None or isinstance(old_v, type(v)) or not assert_type: - set_item_iterative(src, ks, v) - else: - raise TypeError(ks, type(old_v), type(v)) - except KeyError: - set_item_iterative(src, ks, v) - return src - - -def walk_dict(dic: dict, root=None): - """ - Recursively walks through a dictionary and yields keys and values in a flattened format. - - Args: - - dic (dict): The dictionary to be walked through. - - root (list): The root keys to be used in the resulting flattened format. Defaults to None. - - Yields: - - A tuple containing a list of keys and a value. The list of keys is composed of the root keys and the current keys in the dictionary, split by '.' if there are any. The value is the corresponding value in the dictionary. - - Example: - ```python - d = {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3} - for k, v in walk_dict(d): - print(k, v) - # Output: - # (['a', 'b'], 1) - # (['a', 'c', 'd'], 2) - # (['e'], 3) - ``` - """ - if root is None: - root = [] - for k, v in dic.items(): - if isinstance(v, dict): - yield from walk_dict(v, [*root, *k.split('.')]) - else: - yield [*root, *k.split('.')], v - - -def set_item_iterative(dic: dict, keys: List[str], value): - """ - Sets the value of a nested key in a dictionary using an iterative approach. - - Args: - dic (dict): The dictionary to update. - keys (List[str]): A list of keys representing the path to the nested key in the dictionary. - value: The value to set for the nested key. - - Raises: - ValueError: If a key in the path exists in the dictionary but the corresponding value is not a dictionary. - - """ - if len(keys) == 1: - if isinstance(value, MutableMapping): - for ks, v in walk_dict(value): - set_item_iterative(dic, [*keys, *ks], v) - else: - dic.__setitem__(keys[0], value) - else: - try: - nex = dic.__getitem__(keys[0]) - if not isinstance(nex, MutableMapping): - raise ValueError(keys[0], nex) - # dict.__setitem__(dic, keys[0], nex) - except KeyError: - nex = BaseParams() - dic.__setitem__(keys[0], nex) - - set_item_iterative(nex, keys[1:], value) - - -def get_item_iterative(dic: dict, keys: List[str]): - """ - Gets the value of a nested key in a dictionary using an iterative approach. - - Args: - dic (dict): The dictionary to retrieve the value from. - keys (List[str]): A list of keys representing the path to the nested key in the dictionary. - - Raises: - KeyError: If the nested key does not exist in the dictionary. - - Returns: - The value of the nested key in the dictionary. - - """ - if len(keys) == 1: - return dic.__getitem__(keys[0]) - else: - nex = dic.__getitem__(keys[0]) - if isinstance(nex, dict): - return get_item_iterative(nex, keys[1:]) - else: - raise KeyError(keys) - - class Arange: """A class representing a range of numeric values with a default, left and right boundaries. @@ -437,7 +320,7 @@ def safe_update(self, dic, assert_type=False): """ safe_update_dict(self, dic, assert_type=assert_type) - def from_dict(self, dic: dict): + def from_dict(self, dic: MutableMapping): """ Update the config object from a dictionary. @@ -506,17 +389,19 @@ def func(*args, **kwargs): print(self) exit() return - config = kwargs.get('config') if config is None: config = kwargs.get('c') if config is not None and isinstance(config, str) and os.path.exists(config): - self.from_yaml(config) + if config.endswith('yaml'): + self.from_yaml(config) + elif config.endswith('json'): + self.from_json(config) - dic = {} + dic = BaseParams() for k, v in kwargs.items(): set_item_iterative(dic, k.split('.'), v) - self.safe_update(dic) + self.safe_update(dic.to_dict()) fire.Fire(func, command=argv) return self @@ -626,3 +511,120 @@ class Params(BaseParams): ParamsType = NewType('ParamsType', Params) + + +def safe_update_dict(src: BaseParams, kwargs: dict, assert_type=False): + """ + Updates the source dictionary with the key-value pairs from the kwargs dictionary in a safe manner. + + This function iterates over the items in the kwargs dictionary and updates the corresponding items in the + source dictionary, making sure that the types of the values being updated match the types of the values + already in the source dictionary. + + Args: + src (dict): The dictionary to update. + kwargs (dict): The dictionary containing the new key-value pairs to add to the source dictionary. + assert_type (bool): A flag indicating whether to check that the types of the values being updated match + the types of the values already in the source dictionary. Defaults to True. + + Returns: + dict: The updated source dictionary. + """ + for ks, v in walk_dict(kwargs): + try: + old_v = get_item_iterative(src, ks) + if old_v is None or isinstance(old_v, type(v)) or not assert_type: + set_item_iterative(src, ks, v) + else: + raise TypeError(ks, type(old_v), type(v)) + except KeyError: + set_item_iterative(src, ks, v) + return src + + +def walk_dict(dic: MutableMapping, root=None): + """ + Recursively walks through a dictionary and yields keys and values in a flattened format. + + Args: + - dic (dict): The dictionary to be walked through. + - root (list): The root keys to be used in the resulting flattened format. Defaults to None. + + Yields: + - A tuple containing a list of keys and a value. The list of keys is composed of the root keys and the current keys in the dictionary, split by '.' if there are any. The value is the corresponding value in the dictionary. + + Example: + ```python + d = {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3} + for k, v in walk_dict(d): + print(k, v) + # Output: + # (['a', 'b'], 1) + # (['a', 'c', 'd'], 2) + # (['e'], 3) + ``` + """ + if root is None: + root = [] + for k, v in dic.items(): + if isinstance(v, dict): + yield from walk_dict(v, [*root, *k.split('.')]) + else: + yield [*root, *k.split('.')], v + + +def set_item_iterative(dic: BaseParams, keys: List[str], value): + """ + Sets the value of a nested key in a dictionary using an iterative approach. + + Args: + dic (dict): The dictionary to update. + keys (List[str]): A list of keys representing the path to the nested key in the dictionary. + value: The value to set for the nested key. + + Raises: + ValueError: If a key in the path exists in the dictionary but the corresponding value is not a dictionary. + + """ + if len(keys) == 1: + if isinstance(value, MutableMapping): + for ks, v in walk_dict(value): + set_item_iterative(dic, [*keys, *ks], v) + else: + dic.__setitem__(keys[0], value) + else: + try: + nex = dic.__getitem__(keys[0]) + if not isinstance(nex, MutableMapping): + raise ValueError(keys[0], nex) + # dict.__setitem__(dic, keys[0], nex) + except KeyError: + nex = BaseParams() + dic.__setitem__(keys[0], nex) + + set_item_iterative(nex, keys[1:], value) + + +def get_item_iterative(dic: MutableMapping, keys: List[str]): + """ + Gets the value of a nested key in a dictionary using an iterative approach. + + Args: + dic (dict): The dictionary to retrieve the value from. + keys (List[str]): A list of keys representing the path to the nested key in the dictionary. + + Raises: + KeyError: If the nested key does not exist in the dictionary. + + Returns: + The value of the nested key in the dictionary. + + """ + if len(keys) == 1: + return dic.__getitem__(keys[0]) + else: + nex = dic.__getitem__(keys[0]) + if isinstance(nex, MutableMapping): + return get_item_iterative(nex, keys[1:]) + else: + raise KeyError(keys) From cf6723a4693a5fedf9e262a8aa8c5dc5c64521fa Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 19:30:11 +0800 Subject: [PATCH 44/99] Fix problem of update in Params --- tests/core/test_params.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/core/test_params.py b/tests/core/test_params.py index 0c4dc37..fa449f7 100644 --- a/tests/core/test_params.py +++ b/tests/core/test_params.py @@ -68,10 +68,22 @@ def test_json(): fn = tempfile.mktemp() with open(fn, 'w') as w: json.dump({'c': {'a': 2}}, w) + res.c.a = 4 res.from_json(fn) assert res.c.a == 2 +def test_yaml(): + res = get_res() + fn = tempfile.mktemp('.yaml') + res.a = 3 + res.to_yaml(fn) + res.from_args([f'--config={fn}', '--a=2']) + assert res.a == 2 + res.from_args([f'--config={fn}']) + assert res.a == 3 + + def test_copy(): res = get_res() copy = res.copy() From 760b71e6ca56a848d0c44a4b12dc8e095a7467fd Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 19:30:24 +0800 Subject: [PATCH 45/99] git_commit bug --- src/lumo/utils/repository.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lumo/utils/repository.py b/src/lumo/utils/repository.py index 01546de..62470c1 100644 --- a/src/lumo/utils/repository.py +++ b/src/lumo/utils/repository.py @@ -155,7 +155,7 @@ def git_commit(repo=None, key=None, branch_name=None, info: str = None, filter_f # print(diff_uncommit) if filter_files is not None: - diff_from_branches = [i.a_path for i in diff_from_branches if i.a_path in filter_files] + diff_from_branches = [i for i in diff_from_branches if i.a_path in filter_files] if len(diff_from_branches) == 0 and len(diff_uncommit) == 0 and len(repo.untracked_files) == 0: commit_ = exp_head_commit From da432bb6d16fdbbfcaf80e595f97efeb8e2c6133 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 19:40:58 +0800 Subject: [PATCH 46/99] Attr __repr__ problem --- src/lumo/core/attr.py | 12 ++++++------ tests/core/test_attr.py | 10 +++++++++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/lumo/core/attr.py b/src/lumo/core/attr.py index 325cf5c..558e90c 100644 --- a/src/lumo/core/attr.py +++ b/src/lumo/core/attr.py @@ -114,16 +114,16 @@ def set_item_iterative(dic: dict, keys: List[str], value): for ks, v in walk_dict(value): set_item_iterative(dic, [*keys, *ks], v) else: - dict.__setitem__(dic, keys[0], value) + OrderedDict.__setitem__(dic, keys[0], value) else: try: - nex = dict.__getitem__(dic, keys[0]) + nex = OrderedDict.__getitem__(dic, keys[0]) if not isinstance(nex, dict): raise ValueError(keys[0], nex) # dict.__setitem__(dic, keys[0], nex) except KeyError: - nex = {} - dict.__setitem__(dic, keys[0], nex) + nex = Attr() + OrderedDict.__setitem__(dic, keys[0], nex) set_item_iterative(nex, keys[1:], value) @@ -144,9 +144,9 @@ def get_item_iterative(dic: dict, keys: List[str]): """ if len(keys) == 1: - return dict.__getitem__(dic, keys[0]) + return OrderedDict.__getitem__(dic, keys[0]) else: - nex = dict.__getitem__(dic, keys[0]) + nex = OrderedDict.__getitem__(dic, keys[0]) if isinstance(nex, dict): return get_item_iterative(nex, keys[1:]) else: diff --git a/tests/core/test_attr.py b/tests/core/test_attr.py index 0d25076..aed556e 100644 --- a/tests/core/test_attr.py +++ b/tests/core/test_attr.py @@ -1,7 +1,8 @@ -from lumo.core.attr import Attr as attr, set_item_iterative, get_item_iterative import numpy as np import torch +from lumo.core.attr import Attr as attr, set_item_iterative + class NAttr(attr): @@ -22,11 +23,18 @@ def get_res(): res.e = torch.tensor([2, 3, 4]).float() res.f = np.array(2) res.g = np.array([2, 3, 4]) + return res def test_replace(): res = get_res() + print(res) + assert ( + str(res) == "Attr([('a', 1), ('nn', None), ('kk', Attr([('k', None)])), ('st', 'NAttr()'), " + "('b', [2, 3, 4]), ('c', Attr([('a', 1), ('b', [5, 6, 7]), ('c', Attr([('d', [8, 9])]))])), " + "('d', tensor(1.)), ('e', tensor([2., 3., 4.])), ('f', array(2)), ('g', array([2, 3, 4]))])" + ) res.update(a=6, b=[4, 5]) res['c.c.e.f'] = 5 assert res.a == 6 From 4f7b7083d35cd2f87f76ba10c5b4b51981a86e20 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 19:58:08 +0800 Subject: [PATCH 47/99] pass test --- tests/core/test_attr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_attr.py b/tests/core/test_attr.py index aed556e..76fe688 100644 --- a/tests/core/test_attr.py +++ b/tests/core/test_attr.py @@ -46,7 +46,7 @@ def test_replace(): def test_get_set(): - res = {} + res = attr() set_item_iterative(res, ['a', 'b', 'c'], 4) assert isinstance(res['a'], dict) assert isinstance(res['a']['b'], dict) From 9167d912a2a648ccdaf180406713944acd1c036e Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 20:05:18 +0800 Subject: [PATCH 48/99] Add test is_alive for Experiment --- src/lumo/exp/experiment.py | 10 ++++++++++ src/lumo/proc/pid.py | 16 ++++++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 3aeaad8..2fe1a6b 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -450,6 +450,16 @@ def from_disk(cls, path): self.load_prop() return self + @property + def is_alive(self): + pinfo = self.properties['pinfo'] + + hash_obj = runtime_pid_obj(pinfo['pid']) + if hash_obj is None: + return False + + return pid_hash(hash_obj) == pinfo['hash'] + @property def exec_argv(self): execute_info = self.get_prop('execute') diff --git a/src/lumo/proc/pid.py b/src/lumo/proc/pid.py index 23ed726..5427f21 100644 --- a/src/lumo/proc/pid.py +++ b/src/lumo/proc/pid.py @@ -1,7 +1,7 @@ """ Returns information about the specified process or the current process, and computes its hash value. """ -from psutil import Process +from psutil import Process, pid_exists from joblib import hash import os @@ -21,11 +21,15 @@ def runtime_pid_obj(pid=None): """ if pid is None: pid = os.getpid() - p = Process(pid) - obj = { - "pid": p.pid, "pname": p.name(), 'pstart': p.create_time(), 'argv': p.cmdline() - } - return obj + + if pid_exists(pid): + p = Process(pid) + obj = { + "pid": p.pid, "pname": p.name(), 'pstart': p.create_time(), 'argv': p.cmdline() + } + return obj + + return None def pid_hash(pid_obj=None): From 1f4423949e72eb769dcbb3cd3c0b2c8f4a6f0d81 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 21:29:19 +0800 Subject: [PATCH 49/99] Deduplicate table rows --- src/lumo/analyse/collect.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lumo/analyse/collect.py b/src/lumo/analyse/collect.py index 2419382..7cd8920 100644 --- a/src/lumo/analyse/collect.py +++ b/src/lumo/analyse/collect.py @@ -68,14 +68,14 @@ def flatten_metric(df, *keys: str): def collect_table_rows(metric_root=None) -> pd.DataFrame: """Collect all table_row into a pandas.DataFrame""" - res = [] + res = {} logger = Logger() exp_map = list_all_metrics(metric_root) for k, rows in exp_map.items(): # append existing row metrics global_dic = PDict(os.path.join(metricroot(), f'{k}.dict.sqlite')) - for row in global_dic.values(): - res.append(row) + for test_name, row in global_dic.items(): + res[test_name] = row if len(rows) == 0: continue @@ -94,10 +94,10 @@ def collect_table_rows(metric_root=None) -> pd.DataFrame: continue global_dic[test_name] = row shutil.move(row_fn, os.path.join(os.path.dirname(row_fn), f'.{test_name}.pkl')) - res.append(row) + res[test_name] = row global_dic.flush() - return pd.DataFrame(res) + return pd.DataFrame(list(res.values())) def replac(df: pd.DataFrame): From 5e35119c12953ef0de40ee2b661419c2e0c26a72 Mon Sep 17 00:00:00 2001 From: sailist Date: Mon, 6 Mar 2023 22:24:36 +0800 Subject: [PATCH 50/99] File suffix --- src/lumo/core/params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lumo/core/params.py b/src/lumo/core/params.py index 25c682f..b47f39a 100644 --- a/src/lumo/core/params.py +++ b/src/lumo/core/params.py @@ -393,7 +393,7 @@ def func(*args, **kwargs): if config is None: config = kwargs.get('c') if config is not None and isinstance(config, str) and os.path.exists(config): - if config.endswith('yaml'): + if config.endswith('yaml') or config.endswith('yml'): self.from_yaml(config) elif config.endswith('json'): self.from_json(config) From 4019c734d91e7dbf0df797e3cef2ef2725a245a8 Mon Sep 17 00:00:00 2001 From: sailist Date: Tue, 7 Mar 2023 14:52:23 +0800 Subject: [PATCH 51/99] Fix record update --- src/lumo/core/record.py | 5 ++-- tests/core/test_record.py | 49 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 tests/core/test_record.py diff --git a/src/lumo/core/record.py b/src/lumo/core/record.py index c29f2e3..d17c818 100644 --- a/src/lumo/core/record.py +++ b/src/lumo/core/record.py @@ -94,12 +94,13 @@ def record(self, metric, global_step=None): """Records a metric for the current step.""" meter = wrap_result(metric) agg = meter._avg - + # print(agg) + # print(meter) for k, v in meter.items(): stg = agg.get(k, 'last') item = self._agg.get(k, None) if item is None: - item = ReduceItem(stg) + item = ReduceItem(gb_method=stg) item.update(v) self._agg[k] = item diff --git a/tests/core/test_record.py b/tests/core/test_record.py new file mode 100644 index 0000000..3e51729 --- /dev/null +++ b/tests/core/test_record.py @@ -0,0 +1,49 @@ +import pytest + +from lumo import Record, Meter + + +@pytest.fixture +def record(): + return Record(stage='test') + + +def test_stage(record): + assert record.stage == 'test' + + +def test_record(record): + for i in range(10): + m = Meter() + m.sum.C = 512 + record.record(m) + record.record({'loss': 0.5, 'accuracy': 0.8}) + assert record._agg['loss'].res == 0.5 + assert record._agg['accuracy'].res == 0.8 + + +def test_record_meter(record): + for i in range(10): + m = Meter() + m.sum.C = 512 + record.record(m) + assert record.agg()['C'] == 512 * 10 + # assert record._agg['accuracy'].res == 0.8 + + +def test_clear(record): + record.record({'loss': 0.5, 'accuracy': 0.8}) + record.clear() + assert len(record._agg) == 0 + assert len(record._cache) == 0 + + +def test_flush(record): + record.record({'loss': 0.5, 'accuracy': 0.8}) + record.flush() + assert len(record._cache) == 0 + + +def test_str(record): + record.record({'loss': 0.5, 'accuracy': 0.8}) + assert str(record) == 'loss=0.5, accuracy=0.8' From fecfa600749f392ff1b7382fa4b115ef9d973faf Mon Sep 17 00:00:00 2001 From: sailist Date: Wed, 8 Mar 2023 18:58:58 +0800 Subject: [PATCH 52/99] Ignore doc test for callbacks.py --- .docstr.yaml | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/.docstr.yaml b/.docstr.yaml index 1108d91..9d40de3d 100644 --- a/.docstr.yaml +++ b/.docstr.yaml @@ -15,9 +15,5 @@ ignore_names_file: .*/test # regex #fail_under: 90 # int #percentage_only: True # Boolean ignore_patterns: - .*: - - "on.*end" - - "on.*begin" - - "on.*begin" - - "on_first_exception" - - "on_hook.*" \ No newline at end of file + callbacks: + - ".*" \ No newline at end of file From 45fd30dde5f1fe618a3d0c8553279c8f0c090a6e Mon Sep 17 00:00:00 2001 From: sailist Date: Wed, 8 Mar 2023 20:09:02 +0800 Subject: [PATCH 53/99] Added doc string for all core code powered by chatGPT. (except for contrib, sketch, and decorators) --- src/lumo/core/interp.py | 82 +++++- src/lumo/core/params.py | 5 + src/lumo/core/raises.py | 3 +- src/lumo/core/record_backend/abc.py | 41 +++ src/lumo/data/builder.py | 196 ++++++++++++-- src/lumo/data/collate.py | 1 + src/lumo/data/datamodule.py | 9 + src/lumo/data/loader.py | 20 ++ src/lumo/exp/base.py | 70 ++++- src/lumo/exp/experiment.py | 388 ++++++++++++++++++++++++++-- src/lumo/exp/exphook.py | 12 +- src/lumo/exp/finder.py | 140 +++++++++- src/lumo/proc/config.py | 27 +- src/lumo/proc/dependency.py | 16 -- src/lumo/trainer/base.py | 1 + src/lumo/trainer/callbacks.py | 61 +---- 16 files changed, 923 insertions(+), 149 deletions(-) diff --git a/src/lumo/core/interp.py b/src/lumo/core/interp.py index c9f34c8..ad7a2c1 100644 --- a/src/lumo/core/interp.py +++ b/src/lumo/core/interp.py @@ -281,7 +281,15 @@ def __call__(self, cur): class Cos(ABCContinuous): - """one cycle cosine functoin""" + """one cycle cosine functoin + + end -> ,--------- + / + start -> ______/ + ↑ ↑ + left right + + """ @classmethod def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs): @@ -301,7 +309,16 @@ def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs): class Linear(ABCContinuous): - """linear schedule""" + """linear schedule + + ^ + end | .* + | .* + | .* + |.* + start +-----------------> + left right + """ @classmethod def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs): @@ -342,7 +359,20 @@ def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs): class Log(ABCContinuous): - """quick to slow""" + """ + quick to slow + + end | * + | * + | * + | * + | * + | * + | * + start |* + ------------------------------------------- + left right + """ @classmethod def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs): @@ -365,7 +395,14 @@ def interp(cls, cur, start=0., end=1., left=0., right=1., *args, **kwargs): class Constant(ABCContinuous): - """A scheduler representing a constant value""" + """ + A scheduler representing a constant value + | + constant |-------------- + | + |________________ + ... any ... + """ def __init__(self, value=0.5, *args, **kwargs): super().__init__(start=value, end=value, left=0, right=1, *args, **kwargs) @@ -375,6 +412,13 @@ def __init__(self, value=0.5, *args, **kwargs): class PeriodCos(ABCPeriod): """ periodic cosine schedule + + end -> ,-. ,-. ,-. ,-. + / \ / \ / \ / \ + start -> ______/ \_/ \_/ \_/ \_________ + ratio 0 1 2 3 ..... + \----| + period """ @classmethod @@ -390,10 +434,18 @@ def interp(cls, cur, start=0., end=1., left=0., period=1., *args, **kwargs): class PeriodHalfCos(ABCPeriod): """ half periodic cosine schedule, period is (right-left) + + end -> ,- ,- ,- ,- + / / / / + start -> ______/ / / / + ratio 0 1 2 3 ... + \--| + period """ @classmethod def interp(cls, cur, start=0., end=1., left=0., period=1., *args, **kwargs): + """interp with period halfcos method""" constant = kwargs.get('constant', False) ratio = cls.ratio(cur, left=left, period=period, constant=constant) cos_ratio = 0.5 * (1 + np.cos(ratio * np.pi)) @@ -401,6 +453,17 @@ def interp(cls, cur, start=0., end=1., left=0., period=1., *args, **kwargs): class PeriodTriangle(ABCPeriod): + """ + A interp class to simulate a periodic triangle waveform. + + + end /\ /\ /\ /\ + start / \/ \/ \/ \ + 0 1 2 3 4 + \--| + period + """ + def __init__(self, start=0, end=1, left=0, left_period=1, right_period=1, *args, **kwargs): super().__init__(start=start, end=end, left=left, period=(left_period + right_period), @@ -428,6 +491,12 @@ def interp(cls, cur, start=0., end=1., left=0., left_period=0., right_period=1., class PeriodLinear(ABCPeriod): """ sawtooth wave, like a period line schedule + end / / / / + / / / / + start / / / / + 0 1 2 3 .... + \----| + period """ @classmethod @@ -458,6 +527,8 @@ def __call__(self, cur): class PowerDecay2(Interpolate): + """A class for implementing Power Decay Interpolation for a given schedule.""" + def __init__(self, start, schedules, gammas): super().__init__() self.start = start @@ -485,6 +556,8 @@ def __call__(self, cur): class InterpolateList(Interpolate): + """Concat different interpolation functions""" + def __init__(self, schedules: List[Interpolate]): super().__init__() self.schedules = schedules @@ -515,6 +588,7 @@ def __repr__(self): return '{}({})'.format(self.__class__.__name__, content) def plot(self, num=1000, left=None, right=None, show=True): + """plot""" if left is None: left = self.left diff --git a/src/lumo/core/params.py b/src/lumo/core/params.py index b47f39a..c3384f9 100644 --- a/src/lumo/core/params.py +++ b/src/lumo/core/params.py @@ -385,6 +385,7 @@ def from_args(self, argv: list = None): argv = sys.argv def func(*args, **kwargs): + """function to process arg list""" if 'help' in kwargs: print(self) exit() @@ -407,11 +408,13 @@ def func(*args, **kwargs): return self def from_hydra(self, config_path, config_name): + """load from hydra config mode""" import hydra hydra.compose() @hydra.main(config_path=config_path, config_name=config_name) def inner(cfg): + """inner function""" return cfg self.update(inner()) @@ -490,6 +493,7 @@ def hash(self) -> str: return hash(self.to_dict()) def iparams(self): + """Initialization method, mostly used in Trainer""" pass @classmethod @@ -507,6 +511,7 @@ def init_from_kwargs(cls, **kwargs): class Params(BaseParams): + """A class representing parameters""" pass diff --git a/src/lumo/core/raises.py b/src/lumo/core/raises.py index c547063..65759d2 100644 --- a/src/lumo/core/raises.py +++ b/src/lumo/core/raises.py @@ -1 +1,2 @@ -class BoundCheckError(BaseException): pass +class BoundCheckError(BaseException): + """Exception raised when a bound check fails.""" diff --git a/src/lumo/core/record_backend/abc.py b/src/lumo/core/record_backend/abc.py index 1a46d63..b1a81b1 100644 --- a/src/lumo/core/record_backend/abc.py +++ b/src/lumo/core/record_backend/abc.py @@ -3,6 +3,15 @@ class RecordBackend(ABC): + """ + Defines an abstract base class for recording and logging data. + + This module provides an abstract base class, RecordBackend, + that defines the interface for recording and logging data. + + Subclasses of RecordBackend must implement the log(), log_image(), log_audio(), and log_video() methods. + """ + def __init__(self, location, *args, **kwargs): self.location = location @@ -10,11 +19,43 @@ def log(self, data: Dict[str, Any], step: Optional[int] = None, commit: Optional[bool] = None, sync: Optional[bool] = None): + """ + Logs data to the backend. + + Args: + data (Dict[str, Any]): A dictionary containing the data to be logged. + step (Optional[int]): The step number associated with the data. + commit (Optional[bool]): Whether to commit the data to storage. + sync (Optional[bool]): Whether to synchronize the data across multiple devices. + """ raise NotImplementedError() def log_image(self, image_array, caption=None): + """ + Logs an image to the backend. + + Args: + image_array (ndarray): A NumPy array representing the image to be logged. + caption (Optional[str]): A caption describing the image. + """ raise NotImplementedError() + def log_audio(self, image_array, caption=None): + """ + Logs an audio clip to the backend. + + Args: + audio_array (ndarray): A NumPy array representing the audio clip to be logged. + caption (Optional[str]): A caption describing the audio clip. + """ raise NotImplementedError() + def log_image(self, image_array, caption=None): + """ + Logs a video clip to the backend. + + Args: + video_array (ndarray): A NumPy array representing the video clip to be logged. + caption (Optional[str]): A caption describing the video clip. + """ raise NotImplementedError() diff --git a/src/lumo/data/builder.py b/src/lumo/data/builder.py index 7893a07..1522bc6 100644 --- a/src/lumo/data/builder.py +++ b/src/lumo/data/builder.py @@ -212,6 +212,12 @@ def __len__(self): return self._prop['__clen__'] def _update_len(self): + """ + Update the length of the dataset. + + Returns: + res (int): The length of the dataset after updating. + """ if not self.sized: self._prop['__clen__'] = None return @@ -233,10 +239,22 @@ def _update_len(self): @property def inputs(self): + """ + Property getter for the input sources. + + Returns: + self._data (dict): The input data. + """ return self._data @property def outputs(self): + """ + Property getter for the output data. + + Returns: + mapping (dict): A mapping of output keys to their corresponding sources. + """ mapping = {} for key, outkeys in self._outs.items(): if key == '::idx::': @@ -250,25 +268,61 @@ def outputs(self): @property def mode(self): + """ + Property getter for the mode of the dataset. + + Returns: + self._prop.get('mode', 'zip'): The mode of the dataset. + """ return self._prop.get('mode', 'zip') @property def iterable(self): + """ + Property getter for the iterability of the dataset. + + Returns: + self._prop.get('iterable', False): Whether the dataset is iterable. + """ return self._prop.get('iterable', False) @property def subindices(self): + """ + Property getter for the sub-indices of the dataset. + + Returns: + self._prop.get('subindices', None): The sub-indices of the dataset. + """ return self._prop.get('subindices', None) @property def pseudo_length(self) -> int: + """ + Property getter for the pseudo-length of the dataset. + + Returns: + self._prop.get('pseudo_length', None): The pseudo-length of the dataset. + """ return self._prop.get('pseudo_length', None) @property def pseudo_repeat(self) -> int: + """ + Property getter for the pseudo-repeat of the dataset. + + Returns: + self._prop.get('pseudo_repeat', None): The pseudo-repeat of the dataset. + """ return self._prop.get('pseudo_repeat', None) def copy(self): + """ + Create a copy of the dataset builder. + + Returns: + builder (DatasetBuilder): a copy of the dataset builder. + """ builder = DatasetBuilder() builder._prop = builtin_copy(self._prop) builder._idx_keys = builtin_copy(self._idx_keys) @@ -280,6 +334,16 @@ def copy(self): return builder def subset(self, indices: Sequence[int], copy=False): + """ + Create a subset of the dataset builder by selecting indices. + + Args: + indices (Sequence[int]): a sequence of indices to select. + copy (bool): whether to create a copy of the dataset builder or modify it in place. + + Returns: + builder (DatasetBuilder): a new dataset builder containing only the selected indices. + """ if copy: builder = self.copy() else: @@ -289,6 +353,15 @@ def subset(self, indices: Sequence[int], copy=False): return builder def scale_to_size(self, size: int): + """ + Scale the dataset builder to a certain size by repeating or randomly truncating the data. + + Args: + size (int): the size to scale the dataset builder to. + + Returns: + self (DatasetBuilder): the scaled dataset builder. + """ assert isinstance(size, int) assert 'pseudo_repeat' not in self._prop assert 'pseudo_length' not in self._prop @@ -298,6 +371,15 @@ def scale_to_size(self, size: int): return self def repeat(self, multiple: int): + """ + Repeat the dataset builder multiple times to increase its size. + + Args: + multiple (int): the number of times to repeat the dataset builder. + + Returns: + self (DatasetBuilder): the repeated dataset builder. + """ assert isinstance(multiple, int) assert 'pseudo_length' not in self._prop assert 'pseudo_repeat' not in self._prop @@ -308,12 +390,13 @@ def repeat(self, multiple: int): def map_index(self, index): """ - Map the raw index to the final index for source data. + Map the raw index to the final index for the source data. + Args: - index: + index: the raw index to map. Returns: - + index (int): the final index for the source data. """ if self.pseudo_length is not None or self.pseudo_repeat is not None: if self.subindices is not None: @@ -328,21 +411,46 @@ def map_index(self, index): @property def sized(self): + """ + Check if the dataset builder is sized. + + Returns: + bool: True if the dataset builder is sized, False otherwise. + """ return self._prop.get('sized', False) def chain(self): + """ + Set the mode of the dataset builder to chain. + + Returns: + self (DatasetBuilder): the dataset builder with the chain mode set. + """ self._prop['mode'] = 'chain' return self def item(self): + """ + Set the mode of the dataset builder to item. + + Returns: + self (DatasetBuilder): the dataset builder with the item mode set. + """ self._prop['mode'] = 'item' return self def zip(self): + """ + Set the mode of the dataset builder to zip. + + Returns: + self (DatasetBuilder): the dataset builder with the zip mode set. + """ self._prop['mode'] = 'zip' return self def _check_source(self, name, source): + """Check the input source for validity and compatibility with the dataset builder.""" # source is sized can be itered # source can be itered not meant it is sizable. if self.subindices is not None: @@ -374,6 +482,15 @@ def _check_source(self, name, source): raise TypeError(f'Source {name} must be an iterable or sized object, but got {type(source)}.') def add_idx(self, name): + """ + Add an index pseudo source to the dataset builder. + + Args: + name (str): the name of the index. + + Returns: + self (DatasetBuilder): the dataset builder with the index added. + """ outkeys = self._outs.setdefault(f"::idx::", []) assert name not in self._outkeys, f'Output key {name} duplicated.' outkeys.append(name) @@ -382,12 +499,15 @@ def add_idx(self, name): def add_input(self, name: str, source, transform: SingleValueTransform = None): """ - Register a input source with the transform (if provided). + Add an input source to the dataset builder. + Args: - name: source name - source: source, should be a sized object. - transform: + name (str): the name of the input source. + source: the input source to add. + transform (SingleValueTransform): the transform to apply to the input source. + Returns: + self (DatasetBuilder): the dataset builder with the input source added. Notes: Iterable object without `__len__` method currently are not well-tested. Be careful to use them in DatasetBuilder. @@ -401,20 +521,31 @@ def add_input(self, name: str, source, transform: SingleValueTransform = None): return self def add_input_transform(self, name: str, transform: SingleValueTransform = None): + """ + Add a transform to an existing input source. + + Args: + name (str): the name of the input source to add the transform to. + transform (SingleValueTransform): the transform to add. + + Returns: + self (DatasetBuilder): the dataset builder with the transform added. + """ assert name in self._data, f'Source {name} should be added.' warnings.warn('`add` may cause confusion, use set_input_transform ') return self.set_input_transform(name, transform) def add_output(self, name: str, outkey: str, transform: SingleValueTransform = None): """ - Add a data flow from inputs[name] to outputs[outkey] with the transform (if provided). + Add an output flow from input source to the dataset builder. + Args: - name: source name of inputs - outkey: output name of output - transform: a callable function + name (str): the name of the input source for the output. + outkey (str): the name of the output. + transform (SingleValueTransform): the transform to apply to the output. Returns: - + self (DatasetBuilder): the dataset builder with the output added. """ assert name in self._data, f'Must have data source {name} first.' @@ -429,37 +560,56 @@ def add_output(self, name: str, outkey: str, transform: SingleValueTransform = N def add_output_transform(self, outkey: str, transform: SingleValueTransform = None): """ - Add or **replace** transform of the output name. + Add a transform to an existing output. + Args: - outkey: output name. - transform: a callable function + outkey (str): the name of the output to add the transform to. + transform (SingleValueTransform): the transform to add. + + Returns: + self (DatasetBuilder): the dataset builder with the transform added. """ assert outkey in self._outkeys, f'Output key {outkey} should be added.' warnings.warn('add may cause confusion, use set_output_transform ') return self.set_output_transform(outkey, transform) def add_global_transform(self, transform: DictTransform): + """ + Add a global transform to the dataset builder. + + Args: + transform (DictTransform): the global transform to apply to the dataset. + + Returns: + self (DatasetBuilder): the dataset builder with the global transform added. + """ self._transforms['::global::'] = transform return self def set_input_transform(self, name, transform: SingleValueTransform = None): """ - Add or **replace** transform of the input source {name}. + Set the transform for an input source. + Args: - name: source name. - transform: a callable function + name (str): the name of the input source to set the transform for. + transform (SingleValueTransform): the transform to set. + Returns: + self (DatasetBuilder): the dataset builder with the transform set. """ self._transforms[name] = transform return self def set_output_transform(self, outkey, transform: SingleValueTransform = None): """ - Add or **replace** transform of the output {name}. + Set the transform for an output. + Args: - outkey: output name. - transform: a callable function + outkey (str): the name of the output to set the transform for. + transform (SingleValueTransform): the transform to set. + Returns: + self (DatasetBuilder): the dataset builder with the transform set. """ self._transforms[f'::{outkey}'] = transform return self @@ -474,4 +624,6 @@ def __getattribute__(self, item): return res def get_source(self, name): - return self._data[name] + """Get the input source with the given name.""" + warnings.warn('Use inputs[name] instead.') + return self.inputs[name] diff --git a/src/lumo/data/collate.py b/src/lumo/data/collate.py index e786083..3dd301c 100644 --- a/src/lumo/data/collate.py +++ b/src/lumo/data/collate.py @@ -132,6 +132,7 @@ def _filter_none(self, item): return True def before_collate(self, sample_list): + """ before collate""" return list(filter(self._filter_none, sample_list)) diff --git a/src/lumo/data/datamodule.py b/src/lumo/data/datamodule.py index eae4d83..eee3214 100644 --- a/src/lumo/data/datamodule.py +++ b/src/lumo/data/datamodule.py @@ -107,6 +107,15 @@ def regist_dataloader(self, train=None, test=None, val=None, **kwargs): """ def regist_dataloader(self, **kwargs: dict): + """ + Registers the given dataloaders under the given keys. + + Args: + train: A DataLoaderType object for the train set. + test: A DataLoaderType object for the test set. + val: A DataLoaderType object for the validation set. + **kwargs: A DataLoaderType object for other stage + """ for k, v in kwargs.items(): self.prop[k] = v diff --git a/src/lumo/data/loader.py b/src/lumo/data/loader.py index f887200..db11d35 100644 --- a/src/lumo/data/loader.py +++ b/src/lumo/data/loader.py @@ -6,6 +6,7 @@ class LumoDataLoader(DataLoader): + """This module defines the LumoDataLoader class that inherits from the DataLoader class.""" pass @@ -92,18 +93,29 @@ def __init__(self): @property def dataset(self): + """Returns a dictionary that maps the name of the DataLoader to its corresponding dataset.""" return {k: v.dataset for k, v in self.source.items()} @property def source(self): + """Returns the _loaders dictionary.""" return self._loaders def add(self, name, loader: DataLoader, cycle=False): + """ + Adds a DataLoader instance to the _loaders dictionary. + Args: + name (str): The name of the DataLoader. + loader (DataLoader): The DataLoader instance to be added. + cycle (bool): A boolean indicating whether the DataLoader should be cycled. Defaults to False. + + """ self._loaders[name] = loader self._cycle[name] = cycle return self def copy(self): + """Returns a new DataLoaderSide instance with the same _loaders, _cycle, and _state attributes as the original.""" loader = DataLoaderSide() loader._loaders = self._loaders loader._cycle = self._cycle @@ -111,18 +123,26 @@ def copy(self): return loader def zip(self): + """Sets the _state attribute to 'zip', which means the batches are zipped together. + If _state is 'zip', the batches are returned as an ordered dictionary.""" self._state = 'zip' return self def chain(self): + """ + Sets the _state attribute to 'chain', which means the batches are concatenated. + If _state is 'chain', the batches are returned as a list. + """ self._state = 'chain' return self def __len__(self): + """Returns the minimum length of all the DataLoaders that do not have the cycle flag set to True.""" valid_keys = [k for k, cycle in self._cycle.items() if not cycle] return min([len(self._loaders[k]) for k in valid_keys]) def __iter__(self): + """Returns an iterator that generates batches from the DataLoaders in the _loaders dictionary.""" iters = {k: iter(v) for k, v in self._loaders.items()} stop = None diff --git a/src/lumo/exp/base.py b/src/lumo/exp/base.py index 2c90810..fbfb056 100644 --- a/src/lumo/exp/base.py +++ b/src/lumo/exp/base.py @@ -18,21 +18,81 @@ def __new__(cls): @property def config_name(self): + """Get the configuration name for the hook. + + Returns: + A string representing the configuration name for the hook. + + """ return f'HOOK_{self.name.upper()}' @property def config_string(self): + """Get the configuration string for the hook. + + Returns: + A string representing the configuration string for the hook. + + """ + return ', '.join(f'{k}={glob.get(k, v)}' for k, v in self.configs.items()) - def regist(self, exp): self.exp = exp + def regist(self, exp): + """Register the hook with an experiment. + + Args: + exp: The experiment to register the hook with. + + """ + self.exp = exp + + def on_start(self, exp, *args, **kwargs): + """Execute when the experiment starts. + + Args: + exp: The experiment that has started. + *args: Any additional arguments passed to the method. + **kwargs: Any additional keyword arguments passed to the method. - def on_start(self, exp, *args, **kwargs): pass + """ - def on_end(self, exp, end_code=0, *args, **kwargs): pass + def on_end(self, exp, end_code=0, *args, **kwargs): + """Execute when the experiment ends. - def on_progress(self, exp, step, *args, **kwargs): pass + Args: + exp: The experiment that has ended. + end_code (int): The exit code for the experiment. + *args: Any additional arguments passed to the method. + **kwargs: Any additional keyword arguments passed to the method. - def on_newpath(self, exp, *args, **kwargs): pass + """ + + def on_progress(self, exp, step, *args, **kwargs): + """Execute when the experiment makes progress. + + Args: + exp: The experiment that is making progress. + step: The current step of the experiment. + *args: Any additional arguments passed to the method. + **kwargs: Any additional keyword arguments passed to the method. + + """ + + def on_newpath(self, exp, *args, **kwargs): + """Execute when the experiment creates a new path. + + Args: + exp: The experiment that is creating a new path. + *args: Any additional arguments passed to the method. + **kwargs: Any additional keyword arguments passed to the method. + + """ def __repr__(self): + """Return a string representation of the hook. + + Returns: + A string representation of the hook. + + """ return f"Hook(name={self.__class__.name}, switch={self.config_name}, {self.config_string})" diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 2fe1a6b..7dee762 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -19,6 +19,15 @@ def checkdir(path: Union[Path, str]): + """ + Create a directory at the specified path if it does not already exist. + + Args: + path (Union[Path, str]): The path to the directory to be created. + + Returns: + Path: The Path object representing the created directory. + """ if isinstance(path, str): os.makedirs(path, exist_ok=True) elif isinstance(path, Path): @@ -28,7 +37,10 @@ def checkdir(path: Union[Path, str]): class Experiment: """ - (by default), the directory structure is as following: + Represents an experiment and manages its directory structure. An experiment consists of multiple tests, each of which + has its own directory to store information related to that test. + + (By default), the directory structure is as following: .lumo (libroot) - progress - ".{pid}" -> hash @@ -64,6 +76,18 @@ class Experiment: """ def __init__(self, exp_name: str, root=None): + """ + Initializes a new instance of the Experiment class. + + Args: + exp_name (str): The name of the experiment. This should be a legal filename and contain only letters or + underscores. + root (str, optional): The root directory where the experiment's directories will be created. Defaults to + None, in which case the root directory is set to the library's home directory. + + Raises: + ValueError: If the experiment name is not a legal filename. + """ if not can_be_filename(exp_name): raise ValueError(f'Experiment name should be a ligal filename(bettor only contain letter or underline),' f'but got {exp_name}.') @@ -79,19 +103,34 @@ def __init__(self, exp_name: str, root=None): @property def exp_name(self): + """ + str: Gets the name of the experiment. + """ return self._prop['exp_name'] @property def _test_name(self): + """ + str: Gets the name of the current test being run. + """ return self._prop.get('test_name', None) @_test_name.setter def _test_name(self, value): + """ + Sets the name of the current test being run. + + Args: + value (str): The name of the current test. + """ self._prop['test_name'] = value @property def test_name_with_dist(self): """ + str: Gets the name of the current test with the local rank number + appended to it if running in distributed mode. + Create different test_name for each process in multiprocess training. Returns: in main process, will just return test_name itself, in subprocess, "{test_name}.{local_rank()}" @@ -107,7 +146,11 @@ def test_name_with_dist(self): @property def test_name(self): - """Assign unique space(directory) for this test""" + """ + str: Gets the name of the current test being run. + + If the test name is not set, generates a new unique name and sets it. + """ if self._test_name is None: if is_dist(): # if train in distribute mode, subprocess will wait a few seconds to wait main process. flag_fn = f'.{os.getppid()}' @@ -129,7 +172,11 @@ def test_name(self): def _create_test_name(self): """ - [0-9]{6}.[0-9]{3}.[a-z0-9]{3}t + Generates a unique test name based on the current date and time. + regex pattern: [0-9]{6}.[0-9]{3}.[a-z0-9]{3}t + + Returns: + str: The generated test name. """ from lumo.proc.date import timehash from ..utils.fmt import strftime @@ -141,40 +188,76 @@ def _create_test_name(self): @property def root_branch(self): + """ + Path: Gets the root branch directory of the experiment. + """ val = self._root return checkdir(val) @property def lib_root(self): + """ + str: Gets the path of the library's root directory. + """ return self.root_branch.as_posix() @property def exp_branch(self): + """ + Path: Gets the experiment branch directory. + """ val = Path(exproot()).joinpath(self.exp_name) return checkdir(val) @property def blob_branch(self): + """ + Path: Gets the blob branch directory, which is used to store big binary files like model state dicts. + """ val = Path(blobroot()).joinpath(self.exp_name, self.test_name) return checkdir(val) @property def progress_branch(self): + """ + Path: Gets the progress branch directory, which is used to store progress information about running processes. + """ val = Path(progressroot()) return checkdir(val) @property def test_branch(self): + """ + Path: Gets the test branch directory, which is used to store information related to the current test being run. + """ val = self.exp_branch.joinpath(self.test_name) return checkdir(val) def dump_progress(self, ratio: float, update_from=None): + """ + Saves progress information about the experiment. + + Args: + ratio (float): The progress ratio as a number between 0 and 1. + update_from: The process from which the progress update came from. + """ res = {'ratio': max(min(ratio, 1), 0)} if update_from is None: res['update_from'] = update_from self.dump_info('progress', res, append=True) def dump_info(self, key: str, info: Any, append=False, info_dir='info', set_prop=True): + """ + Saves information about the experiment to a file. + + Args: + key (str): The key under which the information will be stored. + info (Any): The information to store. + append (bool, optional): Whether to append to the file or overwrite it. Defaults to False. + info_dir (str, optional): The name of the directory where the file will be stored. Defaults to 'info'. + set_prop (bool, optional): Whether to set the experiment property with the same key to the saved information. + Defaults to True. + """ fn = self.test_file(f'{key}.json', info_dir) if append: old_info = self.load_info(key, info_dir=info_dir) @@ -185,17 +268,43 @@ def dump_info(self, key: str, info: Any, append=False, info_dir='info', set_prop io.dump_json(info, fn) def load_info(self, key: str, info_dir='info'): + """ + Loads information about the experiment from a file. + + Args: + key (str): The key under which the information is stored. + info_dir (str, optional): The name of the directory where the file is stored. Defaults to 'info'. + + Returns: + Any: The information stored under the specified key. + """ fn = self.test_file(f'{key}.json', info_dir) if not os.path.exists(fn): return {} return io.load_json(fn) def dump_string(self, key: str, info: str): + """ + Saves a string to a file. + + Args: + key (str): The key under which the string will be stored. + info (str): The string to store. + """ fn = self.test_file(f'{key}.str', 'text') io.dump_text(info, fn) self.set_prop(key, info) def load_string(self, key: str): + """ + Loads a string from a file. + + Args: + key (str): The key under which the string is stored. + + Returns: + str: The string stored under the specified key. + """ fn = self.test_file(f'{key}.str', 'text') if not os.path.exists(fn): return '' @@ -203,6 +312,9 @@ def load_string(self, key: str): @property def tags(self): + """ + dict: Gets the tags associated with the experiment. + """ tags = {} for path in self.test_branch.joinpath('tags').glob('tag.*.json'): ptags = io.load_json(path.as_posix()) # type: dict @@ -210,90 +322,155 @@ def tags(self): return tags def add_tag(self, tag: str, name_space: str = 'default'): + """ + Adds a tag to the experiment. + Args: + tag (str): The tag to add. + name_space (str, optional): The namespace under which to + add the tag. Defaults to 'default'. + """ self.dump_info(f'tag.{name_space}', { tag: None }, append=True, info_dir='tags', set_prop=False) def exp_file(self, filename, *args): """ + Gets the path to a file in the experiment directory. Args: - filename: - *args: - mkdir: + filename (str): The name of the file. + *args: Any additional subdirectory names to include in the file path. Returns: - + str: The path to the specified file. """ parent = self.exp_branch.joinpath(*args) return checkdir(parent).joinpath(filename).as_posix() def test_file(self, filename, *args): + """ + Gets the path to a file in the test directory. + + Args: + filename (str): The name of the file. + *args: Any additional subdirectory names to include in the file path. + + Returns: + str: The path to the specified file. + """ parent = self.test_branch.joinpath(*args) return checkdir(parent).joinpath(filename).as_posix() def exp_dir(self, *args): """ + Gets the path to a directory in the experiment directory. Args: - filename: - *args: - mkdir: + *args: Any subdirectory names to include in the directory path. Returns: - + str: The path to the specified directory. """ parent = self.exp_branch.joinpath(*args) return checkdir(parent).as_posix() def root_file(self, filename, *args): + """ + Gets the path to a file in the library's root directory. + + Args: + filename (str): The name of the file. + *args: Any additional subdirectory names to include in the file path. + + Returns: + str: The path to the specified file. + """ parent = self.root_branch.joinpath(*args) return checkdir(parent).joinpath(filename).as_posix() def root_dir(self, *args): """ + Gets the path to a directory in the library's root directory. Args: - filename: - *args: - mkdir: + *args: Any subdirectory names to include in the directory path. Returns: - + str: The path to the specified directory. """ parent = self.root_branch.joinpath(*args) return checkdir(parent).as_posix() def test_dir(self, *args): + """ + Gets the path to a directory in the test directory. + + Args: + *args: Any subdirectory names to include in the directory path. + + Returns: + str: The path to the specified directory. + """ parent = self.test_branch.joinpath(*args) return checkdir(parent).as_posix() def blob_file(self, filename, *args): + """ + Gets the path to a file in the blob directory. + + Args: + filename (str): The name of the file. + *args: Any additional subdirectory names to include in the file path. + + Returns: + str: The path to the specified file. + """ parent = self.blob_branch.joinpath(*args) return checkdir(parent).joinpath(filename).as_posix() def progress_file(self, filename): + """ + Gets the path to a file in the progress directory. + + Args: + filename (str): The name of the file. + + Returns: + str: The path to the specified file. + """ return self.progress_branch.joinpath(filename).as_posix() def blob_dir(self, *args): """ - + Gets the path to a directory in the blob directory. Args: - filename: - *args: - mkdir: + *args: Any subdirectory names to include in the directory path. Returns: - + str: The path to the specified directory. """ parent = self.blob_branch.joinpath(*args) return checkdir(parent).as_posix() def __enter__(self): + """ + Starts the experiment when the Experiment object is used as a context manager using the 'with' statement. + + Returns: + Experiment: The Experiment object. + """ self.start() return self def __exit__(self, exc_type, exc_val, exc_tb): + """ + Ends the experiment when the 'with' statement block exits. + + Args: + exc_type (type): The type of the exception that occurred, if any. + exc_val (Exception): The exception object that was raised, if any. + exc_tb (traceback): The traceback object for the exception, if any. + """ extra = {} if exc_type is not None: exc_type = traceback.format_exception_only(exc_type, exc_val)[-1].strip() @@ -305,14 +482,24 @@ def __exit__(self, exc_type, exc_val, exc_tb): @call_on_main_process_wrap def add_exit_hook(self, func): + """ + Registers a function to be called when the program exits. + + Args: + func (callable): The function to register. + """ import atexit def exp_func(): + """Function executed before process exit.""" func(self) atexit.register(exp_func) @call_on_main_process_wrap def initial(self): + """ + Initializes the experiment by setting up progress, information, and PID tracking. + """ self.add_tag(self.__class__.__name__, 'exp_type') self.dump_progress(0) self.dump_info('execute', { @@ -333,6 +520,9 @@ def initial(self): @call_on_main_process_wrap def start(self): + """ + Starts the experiment. + """ if self.get_prop('start', False): return self.initial() @@ -343,6 +533,14 @@ def start(self): @call_on_main_process_wrap def end(self, end_code=0, *args, **extra): + """ + Ends the experiment. + + Args: + end_code (int): The exit code to set for the experiment. + *args: Additional arguments to pass to the end hooks. + **extra: Additional keyword arguments to pass to the end hooks. + """ if not self.get_prop('start', False): return if self.get_prop('end', False): @@ -355,54 +553,139 @@ def end(self, end_code=0, *args, **extra): @property def repo_name(self): - """repository name""" + """ + Gets the name of the repository associated with the experiment. + + Returns: + str: The name of the repository. + """ return self.project_name @property def project_name(self): - """same as repository name, directory name of project root""" + """ + Gets the name of the project associated with the experiment. + + Returns: + str: The name of the project. + """ return os.path.basename(self.project_root) @property def project_root(self): + """ + Gets the path to the root directory of the project associated with the experiment. + + Returns: + str: The path to the root directory of the project. + """ return local_dir() @property def exp_root(self): - """path to multiple tests of this experiment""" + """ + Gets the path to the directory containing the tests for the experiment. + + Returns: + str: The path to the experiment root directory. + """ return self.exp_branch.as_posix() @property def test_root(self): - """path to record information of one experiment""" + """ + Gets the path to the directory containing information about the current test. + + Returns: + str: The path to the test root directory. + """ return self.test_branch.as_posix() @property def blob_root(self): - """path to storing big binary files""" + """ + Gets the path to the directory containing large binary files associated with the experiment. + + Returns: + str: The path to the blob root directory. + """ return self.blob_branch.as_posix() def __getitem__(self, item): + """ + Gets a property of the experiment. + + Args: + item (str): The name of the property to get. + + Returns: + Any: The value of the property. + """ return self._prop[item] def __setitem__(self, key, value): + """ + Sets a property of the experiment. + + Args: + key (str): The name of the property to set. + value (Any): The value to set the property to. + """ self._prop[key] = value def get_prop(self, key, default=None): + """ + Gets the value of a property of the experiment. + + Args: + key (str): The name of the property to get. + default (Any, optional): The default value to return if the property does not exist. Defaults to None. + + Returns: + Any: The value of the property, or the default value if the property does not exist. + """ return self._prop.get(key, default) def has_prop(self, key): + """ + Determines whether the experiment has a certain property. + + Args: + key (str): The name of the property to check for. + + Returns: + bool: True if the experiment has the property, False otherwise. + """ return key in self._prop def set_prop(self, key, value): + """ + Sets a property of the experiment. + + Args: + key (str): The name of the property to set. + value (Any): The value to set the property to. + """ self._prop[key] = value @property def properties(self): + """ + Gets a dictionary containing all properties of the experiment. + + Returns: + dict: A dictionary containing all properties of the experiment. + """ return self._prop @property def paths(self) -> dict: + """ + Gets a dictionary containing the paths to various directories associated with the experiment. + + Returns: + dict: A dictionary containing the paths to various directories associated with the experiment. + """ return { 'root': self.root_branch.as_posix(), 'exp_root': self.exp_root, @@ -412,10 +695,22 @@ def paths(self) -> dict: @property def enable_properties(self) -> set: + """ + Gets a set of the names of all properties that have been set for the experiment. + + Returns: + set: A set of the names of all properties that have been set for the experiment. + """ return set(self._prop.keys()) @call_on_main_process_wrap def set_hook(self, hook: BaseExpHook): + """ + Registers a hook to be executed during the experiment. + + Args: + hook (BaseExpHook): The hook to register. + """ hook.regist(self) if not glob.get(hook.config_name, True): self.dump_info(hook.name, { @@ -429,6 +724,9 @@ def set_hook(self, hook: BaseExpHook): return self def load_prop(self): + """ + Loads all properties associated with the experiment from disk. + """ for f in os.listdir(self.test_dir('info')): key = os.path.splitext(f)[0] self.set_prop(key, self.load_info(key)) @@ -439,6 +737,18 @@ def load_prop(self): @classmethod def from_disk(cls, path): + """ + Creates an Experiment object from a test root directory on disk. + + Args: + path (str): The path to the test root directory. + + Returns: + Experiment: An Experiment object created from the test root directory. + + Raises: + ValueError: If the path is not a valid test root directory. + """ from .finder import is_test_root if not is_test_root(path): raise ValueError(f'{path} is not a valid test_root') @@ -452,6 +762,12 @@ def from_disk(cls, path): @property def is_alive(self): + """ + Determines whether the process associated with the experiment is still running. + + Returns: + bool: True if the process is still running, False otherwise. + """ pinfo = self.properties['pinfo'] hash_obj = runtime_pid_obj(pinfo['pid']) @@ -462,6 +778,12 @@ def is_alive(self): @property def exec_argv(self): + """ + Gets the arguments used to execute the script associated with the experiment. + + Returns: + List[str]: A list of arguments used to execute the script. + """ execute_info = self.get_prop('execute') try: return [os.path.basename(execute_info['exec_bin']), *execute_info['exec_argv']] @@ -469,13 +791,29 @@ def exec_argv(self): return [] def __repr__(self): + """ + Returns a string representation of the Experiment object. + + Returns: + str: A string representation of the Experiment object. + """ return f'{self.exp_name}->({self.test_name})' def __str__(self): + """ + Returns a string representation of the Experiment object. + + Returns: + str: A string representation of the Experiment object. + """ return self.__repr__() class SimpleExperiment(Experiment): + """ + A simple to use experiment subclass that extends the base `Experiment` class and sets up some useful hooks to + execute before and after the experiment. + """ def __init__(self, exp_name: str, root=None): super().__init__(exp_name, root) diff --git a/src/lumo/exp/exphook.py b/src/lumo/exp/exphook.py index ae6a378..86c1434 100644 --- a/src/lumo/exp/exphook.py +++ b/src/lumo/exp/exphook.py @@ -194,15 +194,17 @@ class LockFile(ExpHook): def on_start(self, exp: Experiment, *args, **kwargs): basic = get_lock('lumo', - 'numpy', 'joblib', + 'fire', 'psutil', - 'decorator', - 'torch', - 'numpy', 'accelerate', 'hydra', - 'omegaconf', ) + 'omegaconf', + 'decorator', + + 'numpy', + 'torch', + ) if basic['torch'] is not None: import torch if torch.cuda.is_available(): diff --git a/src/lumo/exp/finder.py b/src/lumo/exp/finder.py index 28c63b4..0f5c53f 100644 --- a/src/lumo/exp/finder.py +++ b/src/lumo/exp/finder.py @@ -9,7 +9,7 @@ """ from pprint import pformat import os -from typing import List, Dict +from typing import List, Dict, Any from lumo.proc.path import libhome, exproot, metricroot from lumo.utils.fmt import indent_print @@ -18,32 +18,86 @@ from . import Experiment -def list_experiment_paths(exp_root=None): +def list_experiment_paths(exp_root=None) -> List[str]: + """ + Returns a list of experiment paths under exp_root directory. + + Args: + exp_root: The root directory to search for experiments. Default is None, which uses the default experiment root directory. + + Returns: + A list of experiment paths. + """ if exp_root is None: exp_root = exproot() return [os.path.join(exp_root, i) for i in os.listdir(exp_root)] -def _get_exp_name(exp_path: str): +def _get_exp_name(exp_path: str) -> str: + """ + Returns the name of the experiment directory. + + Args: + exp_path: The path to the experiment directory. + + Returns: + The name of the experiment directory. + """ return os.path.basename(exp_path.rstrip('/')) def list_all(exp_root=None) -> Dict[str, List[Experiment]]: + """ + Returns a dictionary of all experiments under exp_root directory. + + Args: + exp_root: The root directory to search for experiments. Default is None, which uses the default experiment root directory. + + Returns: + A dictionary of all experiments, where the keys are the names of the experiments and the values are lists of corresponding Experiment objects. + """ return { _get_exp_name(exp_path): retrieval_tests_from_experiment(exp_path) for exp_path in list_experiment_paths(exp_root) } -def retrieval_tests_from_experiment(exp_path) -> List[Experiment]: +def retrieval_tests_from_experiment(exp_path: str) -> List[Experiment]: + """ + Returns a list of Experiment objects found in the specified experiment directory. + + Args: + exp_path: The path to the experiment directory. + + Returns: + A list of Experiment objects. + """ return [retrieval_experiment(os.path.join(exp_path, f)) for f in os.listdir(exp_path)] -def list_test_names_from_experiment(experiment_name): +def list_test_names_from_experiment(experiment_name: str) -> List[str]: + """ + Returns a list of test names found in the specified experiment directory. + + Args: + experiment_name: The name of the experiment directory. + + Returns: + A list of test names. + """ return os.listdir(os.path.join(exproot(), experiment_name)) -def find_path_from_test_name(test_name: str): +def find_path_from_test_name(test_name: str) -> str: + """ + Returns the path of the specified test name. + + Args: + test_name: The name of the test. + + Returns: + The path of the test, or None if not found. + """ if not is_test_name(test_name): return None @@ -58,21 +112,42 @@ def find_path_from_test_name(test_name: str): return None -def is_test_name(test_name: str): +def is_test_name(test_name: str) -> bool: """ - ^[0-9]{6}.[0-9]{3}.[a-z0-9]{2}t$ + Determines if the specified string is a valid test name. + + Args: + test_name: The string to check. + + Returns: + True if the string is a valid test name, False otherwise. """ return re.search(r'^\d{6}\.\d{3}\.[a-z\d]{2}t$', test_name) is not None -def is_test_root(path: str): +def is_test_root(path: str) -> bool: + """ + Determines if the specified path is a valid test root. + + Args: + path: The path to check. + + Returns: + True if the path is a valid test root, False otherwise. + """ test_name = os.path.basename(path.rstrip('/')) return is_test_name(test_name) -def retrieval_test_root(test_flag: str): +def retrieval_test_root(test_flag: str) -> str: """ - test_flag can be a name like `230214.037.62t` or path like `path/to/230214.037.62t` + Returns the test root directory for the specified test name or test root. + + Args: + test_flag: The test name or test root. + like `230214.037.62t` or path like `path/to/230214.037.62t` + Returns: + The test root directory, or None if not found. """ if is_test_name(test_flag): test_root = find_path_from_test_name(test_flag) @@ -87,6 +162,20 @@ def retrieval_test_root(test_flag: str): def retrieval_experiment(test_name=None, test_root: str = None): + """ + Loads an Experiment object from disk for the given test name or test root. + + Args: + test_name (str, optional): The name of the test to load. If not provided, + the test root directory must be provided instead. Defaults to None. + test_root (str, optional): The root directory of the test to load. If not + provided, the root directory is determined from the test name using + the retrieval_test_root function. Defaults to None. + + Returns: + Optional[Experiment]: The loaded Experiment object, or None if the test + root cannot be determined or the Experiment cannot be loaded from disk. + """ if test_root is None: test_root = retrieval_test_root(test_name) if test_root is None: @@ -96,6 +185,13 @@ def retrieval_experiment(test_name=None, test_root: str = None): def summary_experiment(test_name: str = None, test_root: str = None): + """ + Prints a summary of the experiment specified by test_name or test_root. + + Args: + test_name: The name of the test. + test_root: The path to the test root directory. + """ if test_root is None: if test_name is None: raise ValueError() @@ -117,7 +213,16 @@ def summary_experiment(test_name: str = None, test_root: str = None): print('-----------------------------------') -def format_experiment(exp: Experiment): +def format_experiment(exp: Experiment) -> Dict[str, Any]: + """ + Formats the Experiment object into a dictionary. + + Args: + exp: An Experiment object. + + Returns: + A dictionary of the Experiment properties, tags, paths, and execution arguments. + """ return { 'Properties': exp.properties, 'tags': exp.tags, @@ -126,7 +231,16 @@ def format_experiment(exp: Experiment): } -def list_all_metrics(metric_root=None): +def list_all_metrics(metric_root=None) -> Dict[str, List[str]]: + """ + Returns a dictionary of all metrics found under metric_root directory. + + Args: + metric_root: The root directory to search for metrics. Default is None, which uses the default metric root directory. + + Returns: + A dictionary of all metrics, where the keys are the metric names and the values are lists of corresponding metric files. + """ if metric_root is None: metric_root = metricroot() diff --git a/src/lumo/proc/config.py b/src/lumo/proc/config.py index 9733c8a..c6633a5 100644 --- a/src/lumo/proc/config.py +++ b/src/lumo/proc/config.py @@ -39,6 +39,20 @@ def local_config_path(): return None +def local_public_config_path(): + """ + Returns the path to the local configuration file that can be shared and public. + + Returns: + str: The path to the local configuration file, if found. Otherwise, None. + """ + from lumo.utils.repository import git_dir + res = git_dir() + if res: + return os.path.join(res, ".lumorc.public.json") + return None + + def get_config(path, default): """ Reads the configuration file at the given path or creates it if it doesn't exist. @@ -77,11 +91,20 @@ def get_runtime_config(): Returns: dict: The merged runtime configuration. """ - glob_cfg = get_config(global_config_path(), GLOBAL_DEFAULT) - local_cfg = get_config(local_config_path(), {}) + # default cfg = GLOBAL_DEFAULT + + # global config (~/.lumorc.json) + glob_cfg = get_config(global_config_path(), GLOBAL_DEFAULT) cfg.update(glob_cfg) + + # local private config ({repo}/.lumorc.json) + local_cfg = get_config(local_config_path(), {}) cfg.update(local_cfg) + + # local public config ({repo}/.lumorc.public.json) + local_public_cfg = get_config(local_public_config_path(), {}) + cfg.update(local_public_cfg) return cfg diff --git a/src/lumo/proc/dependency.py b/src/lumo/proc/dependency.py index 4f85fda..ae0e632 100644 --- a/src/lumo/proc/dependency.py +++ b/src/lumo/proc/dependency.py @@ -1,23 +1,8 @@ import importlib -from accelerate import __version__ as accelerate_version -from fire import __version__ as fire_version -from joblib import __version__ as joblib_version -from psutil import __version__ as psutil_version - -from lumo import __version__ as lumo_version - __all__ = ['get_lock'] -class Version: - # lumo = lumo_version - joblib = joblib_version - fire = fire_version - psutil = psutil_version - accelerate = accelerate_version - - def get_lock(*others): """ Used to record the specific version of the run-time dependencies to ensure reproducibility. @@ -30,7 +15,6 @@ def get_lock(*others): """ res = {} - res.update({k: v for k, v in Version.__dict__.items() if not k.startswith('__')}) for lib in others: mod = importlib.import_module(lib) diff --git a/src/lumo/trainer/base.py b/src/lumo/trainer/base.py index 4e2e37e..bab315f 100644 --- a/src/lumo/trainer/base.py +++ b/src/lumo/trainer/base.py @@ -296,4 +296,5 @@ def generate_exp_name(cls) -> str: return "{}.{}".format(pre.lower(), exp_name.lower()) def on_trainer_exception(self, func: Callable, exception: BaseException): + """Called when an exception occurs during training.""" pass diff --git a/src/lumo/trainer/callbacks.py b/src/lumo/trainer/callbacks.py index 4125513..03a19af 100644 --- a/src/lumo/trainer/callbacks.py +++ b/src/lumo/trainer/callbacks.py @@ -342,6 +342,7 @@ def on_end(self, source: Trainer, func, params: ParamsType, result, *args, **kwa class LoggerCallback(TrainCallback, InitialCallback): + """A callback for logging the training process.""" priority = 99999 def __init__(self, step_frequence=3, break_in=1000): @@ -385,13 +386,14 @@ def on_train_begin(self, trainer: Trainer, func, params: ParamsType, *args, **kw file=self.temp) def renew(self, stage): - """创建一个新的""" + """Renew when change stage(train/eval/test)""" self.cur_tqdm = inlinetqdm(total=self.stage[stage], position=0, leave=True, bar_format='{desc}{elapsed}<{remaining} ({percentage:3.0f}%){postfix}', file=self.temp) self.record = Record() def update(self, trainer: Trainer): + """Update""" self.c += 1 self.cur_tqdm.update() if self.c % self.step == 0: @@ -403,6 +405,7 @@ def update(self, trainer: Trainer): # trainer.logger.newline() def flush(self, trainer: Trainer): + """Flush""" self.c = 0 trainer.logger.inline(self.cur_tqdm) trainer.logger.newline() @@ -447,6 +450,7 @@ def format_interval(t: float): def format_train_epoch_time(self, n, total, elapsed, ncols=None, prefix='', ascii=False, unit='it', unit_scale=False, rate=None, bar_format=None, postfix=None, unit_divisor=1000, initial=0, colour=None, **extra_kwargs): + """Format""" elapsed_str = self.format_interval(elapsed) remaining = (total - n) / rate if rate and total else 0 remaining_str = self.format_interval(remaining) if rate else '?' @@ -508,61 +512,6 @@ def on_test_step_end(self, trainer: Trainer, func, params: ParamsType, self.update(trainer) -class EpochCheckpoint(TrainCallback): - """ - 在 Trainer 训练过程中定时保存模型 - """ - only_main_process = True - - def __init__(self, per_epoch=50): - self.per_epoch = per_epoch - - def on_train_epoch_end(self, trainer: Trainer, func, params: ParamsType, record: Optional[Record], *args, - **kwargs): - meter = record.agg() - if trainer.eidx % self.per_epoch == 0 and trainer.eidx > 0: - trainer.save_checkpoint(meta_info=Meter.wrap_result(meter)) - - def __repr__(self) -> str: - return self._repr_by_val("per_epoch") - - -class GlobalStepCheckpoint(TrainCallback): - only_main_process = True - - def __init__(self, per_step=2500): - self.per = per_step - - def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: Meter, *args, **kwargs): - super().on_train_step_end(trainer, func, params, metric, *args, **kwargs) - if trainer.global_steps % self.per == 0 and trainer.global_steps > 0: - trainer.save_checkpoint(meta_info=Meter.wrap_result(metric)) - - -class KeyErrorSave(TrainCallback): - """ - Callback to save checkpoints when you interrupt the program. - """ - only_main_process = True - only_single_gpu = True - priority = -1 - - def __init__(self, wait_input=False): - self.wait_input = wait_input - - def on_first_exception(self, source: Trainer, func, params: ParamsType, e: BaseException, *args, **kwargs): - if isinstance(e, KeyboardInterrupt): - source.logger.info("KeyErrorSave trigged, save checkpoint") - source.save_checkpoint({"mode": "KeyboardInterrupt"}) - - tp = "n" - if self.wait_input: - tp = input("continue train step? (y/other)") - - if tp.lower() == "y": - return True - - class EMAUpdate(TrainCallback): """ Callback to update EMA model every train step. From feff99d32c15644105da0a235c34cd91b0e74d0e Mon Sep 17 00:00:00 2001 From: sailist Date: Wed, 8 Mar 2023 20:09:13 +0800 Subject: [PATCH 54/99] ignore callbacks and exphoos --- .docstr.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.docstr.yaml b/.docstr.yaml index 9d40de3d..2fdfcc9 100644 --- a/.docstr.yaml +++ b/.docstr.yaml @@ -16,4 +16,6 @@ ignore_names_file: .*/test # regex #percentage_only: True # Boolean ignore_patterns: callbacks: + - ".*" + exphook: - ".*" \ No newline at end of file From 533e04e874f70dc4e2c20f49cdcc1d6feb71975f Mon Sep 17 00:00:00 2001 From: sailist Date: Wed, 8 Mar 2023 20:09:23 +0800 Subject: [PATCH 55/99] Add .lumorc.public.json in lumo repo --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 397f93f..393a6ab 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,5 @@ lumo_temp wandb .lumorc.json -docs/ \ No newline at end of file +.lumorc.public.json +docs/ From eda619c8cbe093dcf4854b859750f4b94f988bdc Mon Sep 17 00:00:00 2001 From: sailist Date: Wed, 8 Mar 2023 20:09:37 +0800 Subject: [PATCH 56/99] Update docstring coverage rate --- images/docstr_coverage_badge.svg | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/images/docstr_coverage_badge.svg b/images/docstr_coverage_badge.svg index d8ec4de..ede1957 100644 --- a/images/docstr_coverage_badge.svg +++ b/images/docstr_coverage_badge.svg @@ -8,13 +8,13 @@ - + docstr-coverage docstr-coverage - 65% - 65% + 100% + 100% \ No newline at end of file From 695cc96ef710cc87fac083b7a64b8eb79c82346a Mon Sep 17 00:00:00 2001 From: sailist Date: Wed, 8 Mar 2023 20:09:48 +0800 Subject: [PATCH 57/99] Remove useless example --- examples/more/screen_str.py | 20 -------------------- 1 file changed, 20 deletions(-) delete mode 100644 examples/more/screen_str.py diff --git a/examples/more/screen_str.py b/examples/more/screen_str.py deleted file mode 100644 index e3f09e2..0000000 --- a/examples/more/screen_str.py +++ /dev/null @@ -1,20 +0,0 @@ -""" - -""" - -import sys -sys.path.insert(0,"../../") -import time - -from lumo.utils import ScreenStr -s = "\rLong Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text;Long Text" -print(ScreenStr(s,leftoffset=10),end="") -for i in range(100): - time.sleep(0.2) - - -from lumo.contrib.data import AutoCollate -from torch.utils.data.dataloader import DataLoader -import torch -device = torch.device('cuda:1') -DataLoader(...,collate_fn=AutoCollate(device)) From 9f076f2ec2b04c1d3f6279762555b39a51543746 Mon Sep 17 00:00:00 2001 From: sailist Date: Thu, 9 Mar 2023 00:22:02 +0800 Subject: [PATCH 58/99] Update git_dir retrieval manner --- src/lumo/utils/repository.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/lumo/utils/repository.py b/src/lumo/utils/repository.py index 62470c1..73cd33d 100644 --- a/src/lumo/utils/repository.py +++ b/src/lumo/utils/repository.py @@ -275,18 +275,19 @@ def git_enable(): def git_dir(root='./'): """ git repository directory - git rev-parse --git-dir + git rev-parse --show-toplevel Args: root: Returns: + The original command, `git rev-parse --git-dir`, can not find a right path when the repository is a submodule inside another repository. """ if git_enable(): from git import Git cur = os.getcwd() os.chdir(root) - res = Git().execute(['git', 'rev-parse', '--git-dir']) - res = os.path.abspath(os.path.dirname(res)) + res = Git().execute(['git', 'rev-parse', '--show-toplevel']) + res = os.path.abspath(res) os.chdir(cur) return res else: From 867da42e9e2865f7f002e8ae1f4b91a9114b71af Mon Sep 17 00:00:00 2001 From: sailist Date: Thu, 9 Mar 2023 00:22:19 +0800 Subject: [PATCH 59/99] Add method to run command --- src/lumo/utils/subprocess.py | 53 ++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 src/lumo/utils/subprocess.py diff --git a/src/lumo/utils/subprocess.py b/src/lumo/utils/subprocess.py new file mode 100644 index 0000000..d7e8f5a --- /dev/null +++ b/src/lumo/utils/subprocess.py @@ -0,0 +1,53 @@ +import os +import subprocess +import select +import signal + + +def run_command(command, cwd=None): + proc = subprocess.Popen(command, cwd=cwd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + try: + while proc.poll() is None: + # Wait for output from the process + rlist, _, _ = select.select([proc.stdout, proc.stderr], [], [], 0.1) + for stream in rlist: + line = stream.readline().decode('utf-8') + if line: + print(line, end='') + + # Read the remaining output + for stream in [proc.stdout, proc.stderr]: + while True: + line = stream.readline().decode('utf-8') + if not line: + break + print(line, end='') + + # Get the return code of the process + return_code = proc.wait() + + # Raise an exception if the process returned a non-zero return code + # if return_code != 0: + # raise subprocess.CalledProcessError(return_code, command) + except KeyboardInterrupt: + os.kill(proc.pid, signal.SIGINT) + + while proc.poll() is None: + # Wait for output from the process + rlist, _, _ = select.select([proc.stdout, proc.stderr], [], [], 0.1) + for stream in rlist: + line = stream.readline().decode('utf-8').strip() + if line: + print(line) + + # Read the remaining output + for stream in [proc.stdout, proc.stderr]: + while True: + line = stream.readline().decode('utf-8').strip() + if not line: + break + print(line) + + # Get the return code of the process + return_code = proc.wait() + return return_code From 2c12469e8ceee8976d5927f2196d7c2ab52a070b Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 16:57:53 +0800 Subject: [PATCH 60/99] Re-add command line tools for lumo --- src/lumo/cli/__init__.py | 56 ++++++++++++++++++++++++++++++++++++++++ src/lumo/cli/__main__.py | 5 ++++ 2 files changed, 61 insertions(+) create mode 100644 src/lumo/cli/__init__.py create mode 100644 src/lumo/cli/__main__.py diff --git a/src/lumo/cli/__init__.py b/src/lumo/cli/__init__.py new file mode 100644 index 0000000..714545e --- /dev/null +++ b/src/lumo/cli/__init__.py @@ -0,0 +1,56 @@ +import fire + + +def rerun(test_name, **kwarg): + """ + rerun a test + lumo rerun + lumo rerun --device=0 + + Args: + test_name: + + Returns: + + """ + from lumo.exp.finder import retrieval_experiment + exp = retrieval_experiment(test_name) + if exp is not None: + exp.rerun([f'--{k}={v}' for k, v in kwarg.items()]) + else: + exit(1) + + +def note(test_name, description): + """ + Add note to a test: + lumo note description ; + + Args: + test_name: + description: + + Returns: + + """ + print(f"Adding note '{description}' to {test_name}") + + +def server(port=8080): + """ + + Args: + port: + + Returns: + + """ + print(f"Starting server on port {port}") + + +def main(): + fire.Fire({ + 'rerun': rerun, + 'note': note, + 'server': server, + }) diff --git a/src/lumo/cli/__main__.py b/src/lumo/cli/__main__.py new file mode 100644 index 0000000..b678b23 --- /dev/null +++ b/src/lumo/cli/__main__.py @@ -0,0 +1,5 @@ +import fire +from lumo.cli import main + +if __name__ == '__main__': + main() From f917488c47dde53abf26ae1797e98e69e04f64f8 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 16:58:12 +0800 Subject: [PATCH 61/99] Add persistent flag --- src/lumo/core/disk.py | 66 +++++++++++++------------------------------ 1 file changed, 19 insertions(+), 47 deletions(-) diff --git a/src/lumo/core/disk.py b/src/lumo/core/disk.py index 15b2535..baf1942 100644 --- a/src/lumo/core/disk.py +++ b/src/lumo/core/disk.py @@ -1,9 +1,11 @@ import os.path +import warnings from dbrecord import PList from lumo.proc import path from lumo.utils import safe_io as IO +from lumo.decorators.deprecated import DeprecatedWarning class Metrics: @@ -28,10 +30,11 @@ class Metrics: Writes any changes to the metric board SQLite file to disk. """ - def __init__(self, test_path: str): + def __init__(self, test_path: str, persistent=True): os.makedirs(test_path, exist_ok=True) self.fpath = os.path.join(test_path, f'metric_board.sqlite') self.disk = PList(self.fpath) + self.persistent = persistent def append(self, metric: dict, step, stage='train'): """ @@ -59,7 +62,8 @@ def flush(self): Returns: None """ - self.disk.flush() + if self.persistent: + self.disk.flush() class TableRow: @@ -77,6 +81,7 @@ class TableRow: - fpath (str): path of the file that stores the serialized row. - key (str): unique identifier of the row. - value (dict): dictionary representing the row. + - persistent (bool): whether to store in disk. Methods: - __enter__(self): context manager method. Does nothing. @@ -92,12 +97,12 @@ class TableRow: - __getitem__(self, item): returns the value of a key in the row. """ - def __init__(self, table, partition, rowkey): - dirpath = os.path.join(path.metricroot(), table) - os.makedirs(dirpath, exist_ok=True) - self.fpath = os.path.join(dirpath, partition, f'{rowkey}.pkl') - self.key = rowkey + def __init__(self, fn, persistent=True): + os.makedirs(os.path.dirname(os.path.abspath(fn)), exist_ok=True) + self.fpath = fn self.value = {} + self.persistent = persistent + # self.disk = PDict(self.fpath) def __enter__(self): @@ -114,7 +119,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): def flush(self): """Writes the value of the row to a file.""" - IO.dump_pkl(self.value, self.fpath) + if self.persistent: + IO.dump_pkl(self.value, self.fpath) def update_metrics(self, dic: dict, compare=None, flush=False): """ @@ -227,45 +233,11 @@ def update_metric_pair(self, key, value, key2, value2, compare=None, flush=False return {key: old, key2: old2} - def set_params(self, params: dict): - """ - Set the parameters dictionary of the row. - - Args: - params (dict): The parameters dictionary to set. - - Returns: - dict: The parameters dictionary set. - """ - self.value['params'] = params - self.flush() - - def update_dict(self, dic: dict, flush=False): - """ - Update the row with a dictionary. - - Args: - dic (dict): The dictionary to update the row with. - flush (bool, optional): Whether to flush to disk after updating. Default is False. - """ - for k, v in dic.items(): - self.update(k, v) - if flush: - self.flush() - - def update(self, key, value, flush=True): - """ - Update a key-value pair in the row. - - Args: - key (str): The key of the metric to update. - value (float): The value to set the metric to. - flush (bool, optional): Whether to flush to disk after updating. Default is True. - """ - self.value[key] = value - if flush: - self.flush() - def __getitem__(self, item): """Get the value of a key in the row.""" return self.value[item] + + +DeprecatedWarning(TableRow, '0.15.0', '1.0.0', + 'This class is deprecated and will be remove in 1.0.0, ' + 'Please use Experiment.metric to record your best metric.') From eee2fd6904bf2fb466206b28a3fbaa4e55c0df7f Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 16:58:47 +0800 Subject: [PATCH 62/99] Add debounce decorator --- src/lumo/decorators/debounce.py | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 src/lumo/decorators/debounce.py diff --git a/src/lumo/decorators/debounce.py b/src/lumo/decorators/debounce.py new file mode 100644 index 0000000..ed95142 --- /dev/null +++ b/src/lumo/decorators/debounce.py @@ -0,0 +1,38 @@ +import time +from functools import wraps + + +def debounce(wait): + """ + debounce + + Args: + wait (float): seconds + + Returns: + function: + + Examples: + @debounce(5) + def my_func(): + print("my_func called") + + while True: + my_func() + time.sleep(1) + """ + + def decorator(func): + last_time = 0 + + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal last_time + now = time.time() + if now - last_time > wait: + last_time = now + return func(*args, **kwargs) + + return wrapper + + return decorator From 26d650916934dfd1d2c9b9ff4111763af33103e4 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 16:59:39 +0800 Subject: [PATCH 63/99] Unify the functions of Experiment: pathhelper, record info and metric, and ensure reproducibility. --- src/lumo/exp/experiment.py | 908 ++++++++++++++++++++----------------- src/lumo/exp/finder.py | 2 +- src/lumo/exp/metric.py | 60 +++ src/lumo/exp/watch.py | 61 +++ 4 files changed, 619 insertions(+), 412 deletions(-) create mode 100644 src/lumo/exp/metric.py create mode 100644 src/lumo/exp/watch.py diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 7dee762..00c1d1d 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -1,21 +1,28 @@ +""" +Experiment 负责的内容 + - 管理路径 PathHelper + - 记录信息 InfoIO 和度量 Metric + - 快照 snap 和复现 rerun +""" import os import random import sys import time import traceback from pathlib import Path -from typing import Union, Any - +from typing import Union, Any, List +from functools import wraps from lumo.decorators.process import call_on_main_process_wrap from lumo.proc import glob from lumo.proc.dist import is_dist, is_main, local_rank from lumo.proc.path import blobroot, libhome, progressroot from lumo.proc.path import exproot, local_dir from lumo.utils import safe_io as io -from lumo.utils.fmt import can_be_filename +from lumo.utils.fmt import can_be_filename, strftime from lumo.utils.logger import Logger from .base import BaseExpHook from ..proc.pid import pid_hash, runtime_pid_obj +from .metric import Metric def checkdir(path: Union[Path, str]): @@ -49,12 +56,19 @@ class Experiment: - experiments # (exp_root) record information (e.g., .log, params files, etc.) - {experiment-name-1} - {test-1} - # infomation - { - progress - pid_hash (for lumo.client monitor) - other_info: git, file, version_lock, etc. - } + metric_board.sqlite (metrics in training ,powered by dbrecord) + metric.pkl (final metrics) + params.yaml (hyper parameter) + note.md (manually note) + l.0.2303062216.log (log file) + text/ + exception.str + ... + info/ + *.json + git.json + execute.json + lock.json - {test-2} - {experiment-name-2} - {test-1} @@ -66,16 +80,20 @@ class Experiment: - {experiment-name-2} - {test-1} - {test-2} - - metric # (metric_root) record metrics (by trainer.database) - - {experiment-name-1} - - {test-1} - - {test-2} - - {experiment-name-2} - - {test-1} - - {test-2} + {lumo.cache_dir} + - progress # (metric_root) record metrics (by trainer.database) + - hb # trigger for test information update + {experiment-name} + - {test-1} -> timestamp + - {test-2} -> timestamp + - pid # link to running process + - {pid1} -> test_root + - {pid2} -> test_root """ - def __init__(self, exp_name: str, root=None): + ENV_TEST_NAME_KEY = 'LUMO_EXP_TEST_NAME' + + def __init__(self, exp_name: str, root=None, test_name=None): """ Initializes a new instance of the Experiment class. @@ -94,36 +112,102 @@ def __init__(self, exp_name: str, root=None): self._prop = {} self._prop['exp_name'] = exp_name + self._prop['test_name'] = test_name self._hooks = {} + + self._metric = Metric(self.metrics_fn) + + # wrap + self._metric.dump_metrics = self._trigger_change(self._metric.dump_metrics) + self._metric.dump_metric = self._trigger_change(self._metric.dump_metric) + self.dump_string = self._trigger_change(self.dump_string) + self.dump_note = self._trigger_change(self.dump_note) + self.dump_info = self._trigger_change(self.dump_info) + if root is None: root = libhome() self._root = Path(os.path.abspath(root)) self.add_exit_hook(self.end) self.logger = Logger() - @property - def exp_name(self): + def __getitem__(self, item): """ - str: Gets the name of the experiment. + Gets a property of the experiment. + + Args: + item (str): The name of the property to get. + + Returns: + Any: The value of the property. """ - return self._prop['exp_name'] + return self._prop[item] - @property - def _test_name(self): + def __setitem__(self, key, value): """ - str: Gets the name of the current test being run. + Sets a property of the experiment. + + Args: + key (str): The name of the property to set. + value (Any): The value to set the property to. """ - return self._prop.get('test_name', None) + self._prop[key] = value - @_test_name.setter - def _test_name(self, value): + def __enter__(self): """ - Sets the name of the current test being run. + Starts the experiment when the Experiment object is used as a context manager using the 'with' statement. - Args: - value (str): The name of the current test. + Returns: + Experiment: The Experiment object. """ - self._prop['test_name'] = value + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Ends the experiment when the 'with' statement block exits. + + Args: + exc_type (type): The type of the exception that occurred, if any. + exc_val (Exception): The exception object that was raised, if any. + exc_tb (traceback): The traceback object for the exception, if any. + """ + extra = {} + if exc_type is not None: + exc_type = traceback.format_exception_only(exc_type, exc_val)[-1].strip() + extra['exc_type'] = exc_type + extra['end_info'] = str(exc_type) + extra['end_code'] = 1 + extra['exc_stack'] = "".join(traceback.format_exception(exc_type, exc_val, exc_tb)) + self.end(**extra) + + def __repr__(self): + """ + Returns a string representation of the Experiment object. + + Returns: + str: A string representation of the Experiment object. + """ + return f'{self.exp_name}->({self.test_name})' + + def __str__(self): + """ + Returns a string representation of the Experiment object. + + Returns: + str: A string representation of the Experiment object. + """ + return self.__repr__() + + def _repr_html_(self): + """Return a html representation for a particular DataFrame.""" + return self.__repr__() + + @property + def exp_name(self): + """ + str: Gets the name of the experiment. + """ + return self._prop['exp_name'] @property def test_name_with_dist(self): @@ -170,21 +254,23 @@ def test_name(self): return self._test_name - def _create_test_name(self): + @property + def _test_name(self): """ - Generates a unique test name based on the current date and time. - regex pattern: [0-9]{6}.[0-9]{3}.[a-z0-9]{3}t + str: Gets the name of the current test being run. + """ + given_test_name = os.environ.get(Experiment.ENV_TEST_NAME_KEY, None) + return self._prop.get('test_name', given_test_name) - Returns: - str: The generated test name. + @_test_name.setter + def _test_name(self, value): """ - from lumo.proc.date import timehash - from ..utils.fmt import strftime - fs = os.listdir(self.exp_root) - date_str = strftime('%y%m%d') - fs = [i for i in fs if i.startswith(date_str)] - _test_name = f"{date_str}.{len(fs):03d}.{timehash()[-6:-4]}t" - return _test_name + Sets the name of the current test being run. + + Args: + value (str): The name of the current test. + """ + self._prop['test_name'] = value @property def root_branch(self): @@ -233,475 +319,507 @@ def test_branch(self): val = self.exp_branch.joinpath(self.test_name) return checkdir(val) - def dump_progress(self, ratio: float, update_from=None): + @property + def tags(self): """ - Saves progress information about the experiment. - - Args: - ratio (float): The progress ratio as a number between 0 and 1. - update_from: The process from which the progress update came from. + dict: Gets the tags associated with the experiment. """ - res = {'ratio': max(min(ratio, 1), 0)} - if update_from is None: - res['update_from'] = update_from - self.dump_info('progress', res, append=True) + tags = {} + for path in self.test_branch.joinpath('tags').glob('tag.*.json'): + ptags = io.load_json(path.as_posix()) # type: dict + tags.setdefault(path.suffixes[0].strip('.'), []).extend(ptags.keys()) + return tags - def dump_info(self, key: str, info: Any, append=False, info_dir='info', set_prop=True): + @property + def repo_name(self): """ - Saves information about the experiment to a file. + Gets the name of the repository associated with the experiment. - Args: - key (str): The key under which the information will be stored. - info (Any): The information to store. - append (bool, optional): Whether to append to the file or overwrite it. Defaults to False. - info_dir (str, optional): The name of the directory where the file will be stored. Defaults to 'info'. - set_prop (bool, optional): Whether to set the experiment property with the same key to the saved information. - Defaults to True. + Returns: + str: The name of the repository. """ - fn = self.test_file(f'{key}.json', info_dir) - if append: - old_info = self.load_info(key, info_dir=info_dir) - old_info.update(info) - info = old_info - if set_prop: - self.set_prop(key, info) - io.dump_json(info, fn) + return self.project_name - def load_info(self, key: str, info_dir='info'): + @property + def project_name(self): """ - Loads information about the experiment from a file. - - Args: - key (str): The key under which the information is stored. - info_dir (str, optional): The name of the directory where the file is stored. Defaults to 'info'. + Gets the name of the project associated with the experiment. Returns: - Any: The information stored under the specified key. + str: The name of the project. """ - fn = self.test_file(f'{key}.json', info_dir) - if not os.path.exists(fn): - return {} - return io.load_json(fn) + return os.path.basename(self.project_root) - def dump_string(self, key: str, info: str): + @property + def project_root(self): """ - Saves a string to a file. + Gets the path to the root directory of the project associated with the experiment. - Args: - key (str): The key under which the string will be stored. - info (str): The string to store. + Returns: + str: The path to the root directory of the project. """ - fn = self.test_file(f'{key}.str', 'text') - io.dump_text(info, fn) - self.set_prop(key, info) + return local_dir() - def load_string(self, key: str): + @property + def exp_root(self): """ - Loads a string from a file. - - Args: - key (str): The key under which the string is stored. + Gets the path to the directory containing the tests for the experiment. Returns: - str: The string stored under the specified key. + str: The path to the experiment root directory. """ - fn = self.test_file(f'{key}.str', 'text') - if not os.path.exists(fn): - return '' - return io.load_text(fn) + return self.exp_branch.as_posix() @property - def tags(self): - """ - dict: Gets the tags associated with the experiment. + def test_root(self): """ - tags = {} - for path in self.test_branch.joinpath('tags').glob('tag.*.json'): - ptags = io.load_json(path.as_posix()) # type: dict - tags.setdefault(path.suffixes[0].strip('.'), []).extend(ptags.keys()) - return tags + Gets the path to the directory containing information about the current test. - def add_tag(self, tag: str, name_space: str = 'default'): - """ - Adds a tag to the experiment. - Args: - tag (str): The tag to add. - name_space (str, optional): The namespace under which to - add the tag. Defaults to 'default'. + Returns: + str: The path to the test root directory. """ - self.dump_info(f'tag.{name_space}', { - tag: None - }, append=True, info_dir='tags', set_prop=False) + return self.test_branch.as_posix() - def exp_file(self, filename, *args): + @property + def blob_root(self): """ - Gets the path to a file in the experiment directory. - - Args: - filename (str): The name of the file. - *args: Any additional subdirectory names to include in the file path. + Gets the path to the directory containing large binary files associated with the experiment. Returns: - str: The path to the specified file. + str: The path to the blob root directory. """ - parent = self.exp_branch.joinpath(*args) - return checkdir(parent).joinpath(filename).as_posix() + return self.blob_branch.as_posix() - def test_file(self, filename, *args): + @property + def properties(self): """ - Gets the path to a file in the test directory. - - Args: - filename (str): The name of the file. - *args: Any additional subdirectory names to include in the file path. + Gets a dictionary containing all properties of the experiment. Returns: - str: The path to the specified file. + dict: A dictionary containing all properties of the experiment. """ - parent = self.test_branch.joinpath(*args) - return checkdir(parent).joinpath(filename).as_posix() + return self._prop - def exp_dir(self, *args): - """ - Gets the path to a directory in the experiment directory. + @property + def metrics_fn(self): + return self.test_file('metric.pkl') - Args: - *args: Any subdirectory names to include in the directory path. + @property + def metric(self): + """ + Gets a dictionary containing all metrics of the experiment. Returns: - str: The path to the specified directory. + Metric: A dictionary containing all metrics of the experiment. """ - parent = self.exp_branch.joinpath(*args) - return checkdir(parent).as_posix() + return self._metric - def root_file(self, filename, *args): - """ - Gets the path to a file in the library's root directory. + @property + def note(self): + fn = self.test_file('note.md') + if os.path.exists(fn): + return io.load_text(fn) + return fn - Args: - filename (str): The name of the file. - *args: Any additional subdirectory names to include in the file path. + @property + def paths(self) -> dict: + """ + Gets a dictionary containing the paths to various directories associated with the experiment. Returns: - str: The path to the specified file. + dict: A dictionary containing the paths to various directories associated with the experiment. """ - parent = self.root_branch.joinpath(*args) - return checkdir(parent).joinpath(filename).as_posix() + return { + 'root': self.root_branch.as_posix(), + 'exp_root': self.exp_root, + 'test_root': self.test_root, + 'blob_root': self.blob_root, + } - def root_dir(self, *args): + @property + def is_alive(self): """ - Gets the path to a directory in the library's root directory. - - Args: - *args: Any subdirectory names to include in the directory path. + Determines whether the process associated with the experiment is still running. Returns: - str: The path to the specified directory. + bool: True if the process is still running, False otherwise. """ - parent = self.root_branch.joinpath(*args) - return checkdir(parent).as_posix() + pinfo = self.properties['pinfo'] - def test_dir(self, *args): - """ - Gets the path to a directory in the test directory. + hash_obj = runtime_pid_obj(pinfo['pid']) + if hash_obj is None: + return False - Args: - *args: Any subdirectory names to include in the directory path. + return pid_hash(hash_obj) == pinfo['hash'] - Returns: - str: The path to the specified directory. + @property + def exec_argv(self): """ - parent = self.test_branch.joinpath(*args) - return checkdir(parent).as_posix() + Gets the arguments used to execute the script associated with the experiment. - def blob_file(self, filename, *args): + Returns: + List[str]: A list of arguments used to execute the script. """ - Gets the path to a file in the blob directory. + execute_info = self.properties.get('execute') + try: + return [os.path.basename(execute_info['exec_bin']), *execute_info['exec_argv']] + except: + return [] - Args: - filename (str): The name of the file. - *args: Any additional subdirectory names to include in the file path. + def _trigger_change(self, func): + # test_root update some files + @wraps(func) + def inner(*args, **kwargs): + fn = self.progress_file(f'{self.test_name}.heartbeat', 'hb', self.exp_name) + io.dump_text(strftime(), fn) + func(*args, **kwargs) + + return inner + + def _create_test_name(self): + """ + Generates a unique test name based on the current date and time. + regex pattern: [0-9]{6}.[0-9]{3}.[a-z0-9]{3}t Returns: - str: The path to the specified file. + str: The generated test name. """ - parent = self.blob_branch.joinpath(*args) - return checkdir(parent).joinpath(filename).as_posix() + from lumo.proc.date import timehash + from ..utils.fmt import strftime + fs = os.listdir(self.exp_root) + date_str = strftime('%y%m%d') + fs = [i for i in fs if i.startswith(date_str)] + _test_name = f"{date_str}.{len(fs):03d}.{timehash()[-6:-4]}t" + return _test_name - def progress_file(self, filename): + def get_prop(self, key, default=None): """ - Gets the path to a file in the progress directory. + Gets the value of a property of the experiment. Args: - filename (str): The name of the file. + key (str): The name of the property to get. + default (Any, optional): The default value to return if the property does not exist. Defaults to None. Returns: - str: The path to the specified file. + Any: The value of the property, or the default value if the property does not exist. """ - return self.progress_branch.joinpath(filename).as_posix() + return self._prop.get(key, default) - def blob_dir(self, *args): + def has_prop(self, key): """ - Gets the path to a directory in the blob directory. + Determines whether the experiment has a certain property. + Args: - *args: Any subdirectory names to include in the directory path. + key (str): The name of the property to check for. Returns: - str: The path to the specified directory. + bool: True if the experiment has the property, False otherwise. """ - parent = self.blob_branch.joinpath(*args) - return checkdir(parent).as_posix() + return key in self._prop - def __enter__(self): + def set_prop(self, key, value): """ - Starts the experiment when the Experiment object is used as a context manager using the 'with' statement. + Sets a property of the experiment. - Returns: - Experiment: The Experiment object. + Args: + key (str): The name of the property to set. + value (Any): The value to set the property to. """ - self.start() - return self + self._prop[key] = value - def __exit__(self, exc_type, exc_val, exc_tb): + def dump_progress(self, ratio: float, update_from=None): """ - Ends the experiment when the 'with' statement block exits. + Saves progress information about the experiment. Args: - exc_type (type): The type of the exception that occurred, if any. - exc_val (Exception): The exception object that was raised, if any. - exc_tb (traceback): The traceback object for the exception, if any. + ratio (float): The progress ratio as a number between 0 and 1. + update_from: The process from which the progress update came from. """ - extra = {} - if exc_type is not None: - exc_type = traceback.format_exception_only(exc_type, exc_val)[-1].strip() - extra['exc_type'] = exc_type - extra['end_info'] = str(exc_type) - extra['end_code'] = 1 - extra['exc_stack'] = "".join(traceback.format_exception(exc_type, exc_val, exc_tb)) - self.end(**extra) + res = {'ratio': max(min(ratio, 1), 0)} + if update_from is None: + res['update_from'] = update_from + res['last_edit_time'] = strftime() + self.dump_info('progress', res, append=True) - @call_on_main_process_wrap - def add_exit_hook(self, func): + def dump_info(self, key: str, info: Any, append=False, info_dir='info', set_prop=True): """ - Registers a function to be called when the program exits. + Saves information about the experiment to a file. Args: - func (callable): The function to register. + key (str): The key under which the information will be stored. + info (Any): The information to store. + append (bool, optional): Whether to append to the file or overwrite it. Defaults to False. + info_dir (str, optional): The name of the directory where the file will be stored. Defaults to 'info'. + set_prop (bool, optional): Whether to set the experiment property with the same key to the saved information. + Defaults to True. """ - import atexit - def exp_func(): - """Function executed before process exit.""" - func(self) - - atexit.register(exp_func) + fn = self.test_file(f'{key}.json', info_dir) + if append: + old_info = self.load_info(key, info_dir=info_dir) + old_info.update(info) + info = old_info + if set_prop: + self[key] = info + # self.set_prop(key, info) + io.dump_json(info, fn) - @call_on_main_process_wrap - def initial(self): - """ - Initializes the experiment by setting up progress, information, and PID tracking. + def load_info(self, key: str, info_dir='info'): """ - self.add_tag(self.__class__.__name__, 'exp_type') - self.dump_progress(0) - self.dump_info('execute', { - 'repo': self.project_root, - 'cwd': os.getcwd(), - 'exec_file': sys.argv[0], - 'exec_bin': sys.executable, - 'exec_argv': sys.argv - }) - self.dump_info('pinfo', { - 'pid': os.getpid(), - 'hash': pid_hash(), - 'obj': runtime_pid_obj(), - }) + Loads information about the experiment from a file. - # register progress - io.dump_text(self.test_root, self.progress_file(f'{os.getpid()}')) + Args: + key (str): The key under which the information is stored. + info_dir (str, optional): The name of the directory where the file is stored. Defaults to 'info'. - @call_on_main_process_wrap - def start(self): - """ - Starts the experiment. + Returns: + Any: The information stored under the specified key. """ - if self.get_prop('start', False): - return - self.initial() - self.dump_info('start', True) - for hook in self._hooks.values(): # type: BaseExpHook - hook.on_start(self) - return self + fn = self.test_file(f'{key}.json', info_dir) + if not os.path.exists(fn): + return {} + return io.load_json(fn) - @call_on_main_process_wrap - def end(self, end_code=0, *args, **extra): + def dump_note(self, note: str): + fn = self.test_file('note.md') + io.dump_text(note, fn) + + def dump_string(self, key: str, info: str, append=False): """ - Ends the experiment. + Saves a string to a file. Args: - end_code (int): The exit code to set for the experiment. - *args: Additional arguments to pass to the end hooks. - **extra: Additional keyword arguments to pass to the end hooks. + key (str): The key under which the string will be stored. + info (str): The string to store. """ - if not self.get_prop('start', False): - return - if self.get_prop('end', False): - return - self.dump_progress(1) - self.dump_info('end', True) - for hook in self._hooks.values(): # type: BaseExpHook - hook.on_end(self, end_code=end_code, *args, **extra) - return self + fn = self.test_file(f'{key}.str', 'text') + io.dump_text(info, fn, append=append) + if not append: + self.set_prop(key, info) - @property - def repo_name(self): + def load_string(self, key: str): """ - Gets the name of the repository associated with the experiment. + Loads a string from a file. + + Args: + key (str): The key under which the string is stored. Returns: - str: The name of the repository. + str: The string stored under the specified key. """ - return self.project_name + fn = self.test_file(f'{key}.str', 'text') + if not os.path.exists(fn): + return '' + return io.load_text(fn) - @property - def project_name(self): - """ - Gets the name of the project associated with the experiment. + def dump_metric(self, key, value, cmp: str, flush=True, **kwargs): + return self.metric.dump_metric(key, value, cmp, flush, **kwargs) - Returns: - str: The name of the project. - """ - return os.path.basename(self.project_root) + def dump_metrics(self, dic: dict, cmp: str): + return self.metric.dump_metrics(dic, cmp) - @property - def project_root(self): + def exp_dir(self, *args): """ - Gets the path to the root directory of the project associated with the experiment. + Gets the path to a directory in the experiment directory. + + Args: + *args: Any subdirectory names to include in the directory path. Returns: - str: The path to the root directory of the project. + str: The path to the specified directory. """ - return local_dir() + parent = self.exp_branch.joinpath(*args) + return checkdir(parent).as_posix() - @property - def exp_root(self): + def exp_file(self, filename, *args): """ - Gets the path to the directory containing the tests for the experiment. + Gets the path to a file in the experiment directory. + + Args: + filename (str): The name of the file. + *args: Any additional subdirectory names to include in the file path. Returns: - str: The path to the experiment root directory. + str: The path to the specified file. """ - return self.exp_branch.as_posix() + parent = self.exp_branch.joinpath(*args) + return checkdir(parent).joinpath(filename).as_posix() - @property - def test_root(self): + def test_dir(self, *args): """ - Gets the path to the directory containing information about the current test. + Gets the path to a directory in the test directory. + + Args: + *args: Any subdirectory names to include in the directory path. Returns: - str: The path to the test root directory. + str: The path to the specified directory. """ - return self.test_branch.as_posix() + parent = self.test_branch.joinpath(*args) + return checkdir(parent).as_posix() - @property - def blob_root(self): + def test_file(self, filename, *args): """ - Gets the path to the directory containing large binary files associated with the experiment. + Gets the path to a file in the test directory. + + Args: + filename (str): The name of the file. + *args: Any additional subdirectory names to include in the file path. Returns: - str: The path to the blob root directory. + str: The path to the specified file. """ - return self.blob_branch.as_posix() + parent = self.test_branch.joinpath(*args) + return checkdir(parent).joinpath(filename).as_posix() - def __getitem__(self, item): + def root_dir(self, *args): """ - Gets a property of the experiment. + Gets the path to a directory in the library's root directory. Args: - item (str): The name of the property to get. + *args: Any subdirectory names to include in the directory path. Returns: - Any: The value of the property. + str: The path to the specified directory. """ - return self._prop[item] + parent = self.root_branch.joinpath(*args) + return checkdir(parent).as_posix() - def __setitem__(self, key, value): + def root_file(self, filename, *args): """ - Sets a property of the experiment. + Gets the path to a file in the library's root directory. Args: - key (str): The name of the property to set. - value (Any): The value to set the property to. - """ - self._prop[key] = value + filename (str): The name of the file. + *args: Any additional subdirectory names to include in the file path. - def get_prop(self, key, default=None): + Returns: + str: The path to the specified file. """ - Gets the value of a property of the experiment. + parent = self.root_branch.joinpath(*args) + return checkdir(parent).joinpath(filename).as_posix() + def blob_dir(self, *args): + """ + Gets the path to a directory in the blob directory. Args: - key (str): The name of the property to get. - default (Any, optional): The default value to return if the property does not exist. Defaults to None. + *args: Any subdirectory names to include in the directory path. Returns: - Any: The value of the property, or the default value if the property does not exist. + str: The path to the specified directory. """ - return self._prop.get(key, default) + parent = self.blob_branch.joinpath(*args) + return checkdir(parent).as_posix() - def has_prop(self, key): + def blob_file(self, filename, *args): """ - Determines whether the experiment has a certain property. + Gets the path to a file in the blob directory. Args: - key (str): The name of the property to check for. + filename (str): The name of the file. + *args: Any additional subdirectory names to include in the file path. Returns: - bool: True if the experiment has the property, False otherwise. + str: The path to the specified file. """ - return key in self._prop + parent = self.blob_branch.joinpath(*args) + return checkdir(parent).joinpath(filename).as_posix() - def set_prop(self, key, value): + def progress_file(self, filename, *args): """ - Sets a property of the experiment. + Gets the path to a file in the progress directory. Args: - key (str): The name of the property to set. - value (Any): The value to set the property to. - """ - self._prop[key] = value - - @property - def properties(self): - """ - Gets a dictionary containing all properties of the experiment. + filename (str): The name of the file. Returns: - dict: A dictionary containing all properties of the experiment. + str: The path to the specified file. """ - return self._prop + parent = self.progress_branch.joinpath(*args) + return checkdir(parent).joinpath(filename).as_posix() + @classmethod @property - def paths(self) -> dict: + def Class(cls): + return cls + + def rerun(self, arg_list: List[str]): + """rerun this test in another""" + # self.properties[''] + new_test_name = self._create_test_name() + new_exp = Experiment(self.exp_name, root=self._root, test_name=new_test_name) + self.dump_info('deprecated', {'rerun_at': new_exp.test_name}) + old_rerun_info = self.properties.get('rerun', None) + count = 1 + if old_rerun_info is not None: + count += old_rerun_info['count'] + new_exp.dump_info('rerun', {'from': self.test_name, 'repeat': count}) + from lumo.utils.subprocess import run_command + old_exec = self.properties['execute'] + command = ' '.join([old_exec['exec_bin'], old_exec['exec_file'], *old_exec['exec_argv'], *arg_list]) + env = os.environ.copy() + env[Experiment.ENV_TEST_NAME_KEY] = new_exp.test_name + return run_command(command, cwd=old_exec['cwd']) + + @call_on_main_process_wrap + def initial(self): """ - Gets a dictionary containing the paths to various directories associated with the experiment. + Initializes the experiment by setting up progress, information, and PID tracking. + """ + self.dump_info('progress', {'start': strftime(), 'finished': False}, append=True) + self.dump_progress(0) + self.dump_info('execute', { + 'repo': self.project_root, + 'cwd': os.getcwd(), + 'exec_file': sys.argv[0], + 'exec_bin': sys.executable, + 'exec_argv': sys.argv + }) + self.dump_info('pinfo', { + 'pid': os.getpid(), + 'hash': pid_hash(), + 'obj': runtime_pid_obj(), + }) - Returns: - dict: A dictionary containing the paths to various directories associated with the experiment. + # register progress + # register this process + io.dump_text(self.test_root, self.progress_file(f'{os.getpid()}', 'pid')) + + @call_on_main_process_wrap + def start(self): """ - return { - 'root': self.root_branch.as_posix(), - 'exp_root': self.exp_root, - 'test_root': self.test_root, - 'blob_root': self.blob_root, - } + Starts the experiment. + """ + if self.properties.get('start', False): + return + self.initial() + self.set_prop('start', True) + for hook in self._hooks.values(): # type: BaseExpHook + hook.on_start(self) + return self - @property - def enable_properties(self) -> set: + @call_on_main_process_wrap + def end(self, end_code=0, *args, **extra): """ - Gets a set of the names of all properties that have been set for the experiment. + Ends the experiment. - Returns: - set: A set of the names of all properties that have been set for the experiment. + Args: + end_code (int): The exit code to set for the experiment. + *args: Additional arguments to pass to the end hooks. + **extra: Additional keyword arguments to pass to the end hooks. """ - return set(self._prop.keys()) + if not self.is_alive: + return + if not self.properties.get('start', False): + return + if self.properties.get('end', False): + return + self.set_prop('end', True) + self.dump_progress(1) + + self.dump_info('progress', {'end': strftime(), 'finished': end_code == 0}, append=True) + for hook in self._hooks.values(): # type: BaseExpHook + hook.on_end(self, end_code=end_code, *args, **extra) + return self @call_on_main_process_wrap def set_hook(self, hook: BaseExpHook): @@ -711,29 +829,34 @@ def set_hook(self, hook: BaseExpHook): Args: hook (BaseExpHook): The hook to register. """ - hook.regist(self) if not glob.get(hook.config_name, True): - self.dump_info(hook.name, { - 'code': -1, - 'msg': f'{hook.name} disabled' - }) + self.dump_info('hooks', { + hook.__class__.__name__: {'loaded': False, 'msg': 'disabled by config'} + }, append=True) + return self + else: + hook.regist(self) + self.dump_info('hooks', { + hook.__class__.__name__: {'loaded': True, 'msg': ''} + }, append=True) + self.logger.info(f'Register {hook}.') + self._hooks[hook.__class__.__name__] = hook return self - self.logger.info(f'Register {hook}.') - self._hooks[hook.__class__.__name__] = hook - self.add_tag(hook.__class__.__name__, 'hooks') - return self - def load_prop(self): + @call_on_main_process_wrap + def add_exit_hook(self, func): """ - Loads all properties associated with the experiment from disk. + Registers a function to be called when the program exits. + + Args: + func (callable): The function to register. """ - for f in os.listdir(self.test_dir('info')): - key = os.path.splitext(f)[0] - self.set_prop(key, self.load_info(key)) + import atexit + def exp_func(): + """Function executed before process exit.""" + func(self) - for f in os.listdir(self.test_dir('text')): - key = os.path.splitext(f)[0] - self.set_prop(key, self.load_string(key)) + atexit.register(exp_func) @classmethod def from_disk(cls, path): @@ -757,56 +880,19 @@ def from_disk(cls, path): root = test_root.parent.parent.parent.as_posix() self = cls(test_root.parent.name, root=root) self._test_name = test_root.name - self.load_prop() - return self - - @property - def is_alive(self): - """ - Determines whether the process associated with the experiment is still running. - - Returns: - bool: True if the process is still running, False otherwise. - """ - pinfo = self.properties['pinfo'] - - hash_obj = runtime_pid_obj(pinfo['pid']) - if hash_obj is None: - return False - - return pid_hash(hash_obj) == pinfo['hash'] - - @property - def exec_argv(self): - """ - Gets the arguments used to execute the script associated with the experiment. - Returns: - List[str]: A list of arguments used to execute the script. - """ - execute_info = self.get_prop('execute') - try: - return [os.path.basename(execute_info['exec_bin']), *execute_info['exec_argv']] - except: - return [] - - def __repr__(self): - """ - Returns a string representation of the Experiment object. - - Returns: - str: A string representation of the Experiment object. - """ - return f'{self.exp_name}->({self.test_name})' + # load prop + for f in os.listdir(self.test_dir('info')): + key = os.path.splitext(f)[0] + self.set_prop(key, self.load_info(key)) - def __str__(self): - """ - Returns a string representation of the Experiment object. + for f in os.listdir(self.test_dir('text')): + key = os.path.splitext(f)[0] + self.set_prop(key, self.load_string(key)) - Returns: - str: A string representation of the Experiment object. - """ - return self.__repr__() + # load metric + self._metric = Metric(self.metrics_fn) + return self class SimpleExperiment(Experiment): diff --git a/src/lumo/exp/finder.py b/src/lumo/exp/finder.py index 0f5c53f..5398917 100644 --- a/src/lumo/exp/finder.py +++ b/src/lumo/exp/finder.py @@ -161,7 +161,7 @@ def retrieval_test_root(test_flag: str) -> str: return test_root -def retrieval_experiment(test_name=None, test_root: str = None): +def retrieval_experiment(test_name=None, test_root: str = None) -> Experiment: """ Loads an Experiment object from disk for the given test name or test root. diff --git a/src/lumo/exp/metric.py b/src/lumo/exp/metric.py new file mode 100644 index 0000000..abf5f61 --- /dev/null +++ b/src/lumo/exp/metric.py @@ -0,0 +1,60 @@ +import os +from lumo.utils import safe_io as IO + + +class Metric: + """ + """ + + def __init__(self, metric_fn, persistent=True): + os.makedirs(os.path.dirname(os.path.abspath(metric_fn)), exist_ok=True) + self.fn = metric_fn + self._metric = {} + if os.path.exists(metric_fn): + self._metric = IO.load_pkl(metric_fn) + self.persistent = persistent + + @property + def value(self): + """ + A property that returns the metric values of the row. + + Returns: + dict: A dictionary containing the metric values of the row. + """ + return self._metric + + def dump_metric(self, key, value, cmp: str, flush=True, **kwargs): + dic = self.value + older = dic.setdefault(key, None) + + update = False + if older is None or cmp is None: + update = True + else: + if cmp == 'max': + if older < value: + update = True + elif cmp == 'min': + if older > value: + update = True + else: + raise NotImplementedError() + + if update: + dic[key] = value + for kk, vv in kwargs.items(): + dic[kk] = vv + + if flush: + self.flush() + return value + + def dump_metrics(self, dic: dict, cmp: str): + for k, v in dic.items(): + self.dump_metric(k, v, cmp) + + def flush(self): + """Writes the value of the row to a file.""" + if self.persistent: + IO.dump_pkl(self.value, self.fn) diff --git a/src/lumo/exp/watch.py b/src/lumo/exp/watch.py new file mode 100644 index 0000000..e07e2dd --- /dev/null +++ b/src/lumo/exp/watch.py @@ -0,0 +1,61 @@ +""" +Watcher 可以在运行实验后在 jupyter 或者网页上展示正在运行和已经运行结束的实验(按时间顺序?) +以及可以简化记录实验的烦恼 + +现在的核心痛点是 + - [ ] 所有元信息都有了,但是找不到哪个实验是哪个实验 + - [ ] 同时跑的多个实验有一个失败了,重跑时会混淆,或许需要一种覆盖手段 -> + - > 怎么 rerun? + lumo rerun test_name √ + lumo note html (用 streamlit 之类的生成动态网页) + lumo note cmd (类似 top 的视角,按时间顺序排列) +- > rerun 将已经跑的实验 move + +可以代替 analysis 的作用。主要有 + +-> 按照 progress 目录,获取所有的实验 +-> 根据获取的实验,按顺序记录 +-> 每次只记录 + +""" +import os.path + +from lumo.proc.path import progressroot + +PID_ROOT = os.path.join(progressroot(), 'pid') +HB_ROOT = os.path.join(progressroot(), 'hb') +EXP_ROOT = os.path.join(progressroot()) + + +# class Watcher: +# """List and watch experiments with time order +# +# Cache test_information in +# metrics/.sqlite +# """ +# +# def load(self): +# pass +# +# def interactive(self): +# """interactive, mark, label, note in ipython environment.""" +# pass +# +# def server(self): +# """simple server which make you note your experiments""" +# pass +# +# def list_all(self, exp_root=None, limit=100) -> Dict[str, List[Experiment]]: +# """ +# Returns a dictionary of all experiments under exp_root directory. +# +# Args: +# exp_root: The root directory to search for experiments. Default is None, which uses the default experiment root directory. +# +# Returns: +# A dictionary of all experiments, where the keys are the names of the experiments and the values are lists of corresponding Experiment objects. +# """ +# return { +# _get_exp_name(exp_path): retrieval_tests_from_experiment(exp_path) +# for exp_path in list_experiment_paths(exp_root) +# } From c330f761dde686756de693a42e5f29cefce8285d Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 16:59:51 +0800 Subject: [PATCH 64/99] Progress dir --- src/lumo/proc/path.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lumo/proc/path.py b/src/lumo/proc/path.py index e90618b..2a7bf18 100644 --- a/src/lumo/proc/path.py +++ b/src/lumo/proc/path.py @@ -76,7 +76,7 @@ def progressroot(): if PROGRESS_ROOT: res = PROGRESS_ROOT else: - res = os.path.join(libhome(), 'progress') + res = os.path.join(cache_dir(), 'progress') os.makedirs(res, exist_ok=True) return res From b43ee01a1213000164588df0a1c8d74770893e64 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 17:00:02 +0800 Subject: [PATCH 65/99] Some updates --- src/lumo/sketch/vis/__main__.py | 12 ++++++------ src/lumo/sketch/vis/parser.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/lumo/sketch/vis/__main__.py b/src/lumo/sketch/vis/__main__.py index d3f6069..526c02b 100644 --- a/src/lumo/sketch/vis/__main__.py +++ b/src/lumo/sketch/vis/__main__.py @@ -34,10 +34,10 @@ def make_test(self, test_root: str): st.write(finder.format_experiment(exp)) # with st.expander("Visualize Metrics"): if exp.has_prop('tensorboard_args'): - tb = exp.get_prop('tensorboard_args') + tb = exp.properties.get('tensorboard_args') metrics = parser.parse_fron_tensorboard(tb['log_dir']) elif exp.has_prop('logger_args'): - tb = exp.get_prop('logger_args') + tb = exp.properties.get('logger_args') metrics = parser.parse_from_log(tb['log_dir']) else: metrics = {} @@ -52,10 +52,10 @@ def make_test(self, test_root: str): k, v = metrics[i + 1] m.write(k) m.line_chart(np.array([vv.value for vv in v])) - # if i + 2 >= len(metrics): - # break - # k, v = metrics[i + 2] - # r.line_chart({'k': np.array([vv.value for vv in v])}) + # if i + 2 >= len(metrics): + # break + # k, v = metrics[i + 2] + # r.line_chart({'k': np.array([vv.value for vv in v])}) def select_head(self): left, right = st.columns([1, 3]) diff --git a/src/lumo/sketch/vis/parser.py b/src/lumo/sketch/vis/parser.py index c549d9c..b1f9406 100644 --- a/src/lumo/sketch/vis/parser.py +++ b/src/lumo/sketch/vis/parser.py @@ -30,10 +30,10 @@ def find_metric_fron_test_root(test_root): exp = Experiment.from_disk(test_root) if exp.has_prop('tensorboard_args'): - tb = exp.get_prop('tensorboard_args') + tb = exp.properties.get('tensorboard_args') metrics = parse_fron_tensorboard(tb['log_dir']) elif exp.has_prop('logger_args'): - tb = exp.get_prop('logger_args') + tb = exp.properties.get('logger_args') metrics = parse_from_log(tb['log_dir']) else: fs = [i for i in os.listdir(exp.test_root)] From 2be4cf5330653286992e6a9a39cbabfe3d73cc26 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 17:00:35 +0800 Subject: [PATCH 66/99] Deprecate database, use exp.metric instead. --- src/lumo/trainer/callbacks.py | 19 +++++++++------ src/lumo/trainer/trainer.py | 45 +++++++++++++---------------------- 2 files changed, 29 insertions(+), 35 deletions(-) diff --git a/src/lumo/trainer/callbacks.py b/src/lumo/trainer/callbacks.py index 03a19af..1139347 100644 --- a/src/lumo/trainer/callbacks.py +++ b/src/lumo/trainer/callbacks.py @@ -10,6 +10,7 @@ from functools import wraps from typing import NewType, Any, Optional, Dict, Union +import psutil from torch.utils.data import DataLoader from lumo.core import Meter, MetricType, Record, TrainStage, wrap_result, ParamsType @@ -605,7 +606,7 @@ def log_matrix(self, metrics: Dict, step: int, namespace: str): def on_hooked(self, source: Trainer, params: ParamsType): super().on_hooked(source, params) - source.exp.set_prop('AutoRecord', self.__class__.__name__) + source.exp.dump_string('AutoRecord', self.__class__.__name__) def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: MetricType, *args, **kwargs): super().on_train_step_end(trainer, func, params, metric, *args, **kwargs) @@ -806,7 +807,7 @@ def on_first_exception(self, source: Trainer, func, params: ParamsType, e: BaseE })) -class CUDAMemoryRecord(TrainCallback): +class ResourceRecord(TrainCallback): """ Record CUDA GPU maximum memory used during training. @@ -817,16 +818,20 @@ class CUDAMemoryRecord(TrainCallback): def on_hooked(self, source: Trainer, params: ParamsType): super().on_hooked(source, params) - self.max_memory = 0 self.device = source.device self.pid = os.getpid() self.mem = DeviceMem() def on_train_epoch_end(self, trainer: Trainer, func, params: ParamsType, record: Record, *args, **kwargs): super().on_train_epoch_end(trainer, func, params, record, *args, **kwargs) - self.max_memory = max(self.max_memory, self.mem.get_pid_device_mem(self.pid, self.device)) - trainer.database.update('memory', self.max_memory) - trainer.exp.dump_info('max_memory', self.max_memory) + trainer.exp.dump_metric('CUDA_memory', self.mem.get_pid_device_mem(self.pid, self.device), cmp='max') + + # 获取进程的内存信息 + memory_info = psutil.Process(self.pid).memory_info() + + # 打印内存信息 + trainer.exp.dump_metric('CPU_memory_rss', memory_info.rss / 1024 / 1024) + trainer.exp.dump_metric('CPU_memory_vms', memory_info.vms / 1024 / 1024) class SkipWhenParamsEq(TrainCallback, InitialCallback): @@ -845,7 +850,7 @@ def on_hooked(self, source: Trainer, params: ParamsType): if isinstance(old, str) and is_test_root(old) and os.path.exists(old): source.stop_train() source.stop_train_epoch() - source.database.update('skiped', True, flush=True) + source.exp.dump_info('early_stop', True) source.logger.info( f'Find finished test with equal params ({current}) from {olds[current]}. ' f'To runStored in:\n {self.fn}') diff --git a/src/lumo/trainer/trainer.py b/src/lumo/trainer/trainer.py index 63c0662..58d63c7 100644 --- a/src/lumo/trainer/trainer.py +++ b/src/lumo/trainer/trainer.py @@ -79,10 +79,12 @@ def __init__(self, params: ParamsType, dm: DataModule = None): self.params.iparams() self.exp = TrainerExperiment(self.generate_exp_name()) - self.database = TableRow(self.exp.project_name, self.exp.exp_name, self.exp.test_name_with_dist) - self.metric_board = Metrics(self.exp.test_root) + self._database = TableRow(self.exp.metrics_fn, persistent=self.is_main) + self.metric_board = Metrics(self.exp.test_root, persistent=self.is_main) + self.metric = self.exp.metric + self.exp.dump_info('metric_board', self.metric_board.fpath) - self.exp.dump_info('table_row', self.database.fpath) + self.exp.dump_info('table_row', self._database.fpath) self.rnd = RndManager() self.train_epoch_toggle = False @@ -104,7 +106,7 @@ def __init__(self, params: ParamsType, dm: DataModule = None): self.set_epoch_idx(0) self.set_idx(0) if params.get('debug', False): - self.exp.set_prop('debug', True) + self.exp.dump_info('debug', True) @property def metrics(self): @@ -112,7 +114,12 @@ def metrics(self): @property def db(self): - return self.database + return self._database + + @property + def database(self): + warnings.warn('TableRow is deprecated and will be removed soon, please use self.metric instead') + return self._database @property def saver(self) -> Saver: @@ -500,11 +507,10 @@ def to_device(self, item: Optional[Union[nn.Module, torch.Tensor, Sequence, Mapp def on_trainer_exception(self, func: Callable, exception: BaseException): """Updates database with error information when an exception occurs during training.""" - self.database.update_dict(dict(end=datetime.now(), - finished=False, - error=str(exception), - trainer_frame=str(func)), - flush=True) + self.exp.dump_info('exception', dict(end=datetime.now(), + finished=False, + error=str(exception), + trainer_frame=str(func))) @property def is_initialized(self): @@ -526,22 +532,9 @@ def initialize(self): return self.exp.start() - commit_info = self.exp.get_prop('git') - commit_hex = None - if commit_info is not None and 'commit' in commit_info: - commit_hex = commit_info['commit'] - self.database.update('commit_hex', commit_hex) - self.database.update_dict(dict( - test_name=self.exp.test_name, - exp_name=self.exp.exp_name, - project=self.exp.project_name, - path=self.exp.test_root, - start=datetime.now())) - self.database.set_params(self.params.to_dict()) - self.database.update('command', ' '.join(sys.argv)) params_hash = self.params.hash() - self.database.update('params_hash', params_hash) self.exp.dump_string('params_hash', params_hash) + self.icallbacks(self.params) self.set_property('initial.callbacks', True) self.imodels(self.params) @@ -631,8 +624,6 @@ def train(self, dm: Union[DataModule, DataLoaderType] = None, params: ParamsType # update when train finished self.exp.end() - self.database.update_dict(dict(end=datetime.now(), finished=True), flush=True) - self.database.flush() return self._prop def train_epoch(self, loader: DataLoaderType, params: ParamsType = None, @@ -673,10 +664,8 @@ def train_epoch(self, loader: DataLoaderType, params: ParamsType = None, self._prop['global_steps'] += 1 metric = self.train_step(batch, params) record.record(metric) - self.database.flush() record.flush() - self.database.update_dict(dict(eidx=self.eidx, end=datetime.now())) return record def set_property(self, key: str, value: any) -> None: From d0f71efbc117babcd11ee0b3a81ba3d3be9eaee1 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 17:01:04 +0800 Subject: [PATCH 67/99] Add function to run command in subprocess --- src/lumo/utils/subprocess.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lumo/utils/subprocess.py b/src/lumo/utils/subprocess.py index d7e8f5a..39da1f3 100644 --- a/src/lumo/utils/subprocess.py +++ b/src/lumo/utils/subprocess.py @@ -5,7 +5,9 @@ def run_command(command, cwd=None): - proc = subprocess.Popen(command, cwd=cwd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + proc = subprocess.Popen(command, + cwd=cwd, + shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) try: while proc.poll() is None: # Wait for output from the process From 30fd6ae98ee8382406e44034290e269038627c79 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 17:01:10 +0800 Subject: [PATCH 68/99] Fix tests --- tests/core/test_disk.py | 10 +++------- tests/proc/test_proc.py | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/core/test_disk.py b/tests/core/test_disk.py index c8ba42d..a06bed8 100644 --- a/tests/core/test_disk.py +++ b/tests/core/test_disk.py @@ -1,3 +1,5 @@ +import os.path + from lumo.core.disk import Metrics, TableRow import numpy as np import tempfile @@ -5,15 +7,10 @@ from lumo.utils import safe_io as IO glob['metric_root'] = tempfile.mkdtemp() -from lumo.utils.fmt import strftime def test_table_row(): - row = TableRow('lumo-test', 'core', strftime()) - - # test update - row.update('a', 'b') - assert row['a'] == 'b' + row = TableRow(os.path.join(tempfile.mkdtemp(), 'lumo-test')) # test update_metric ## test max @@ -38,6 +35,5 @@ def test_table_row(): # test storage row.flush() storage = IO.load_pkl(row.fpath) - assert storage['a'] == row['a'] assert storage['metric']['acc'] == row['metric']['acc'] assert (storage['metric']['clsAcc'] == row['metric']['clsAcc']).all() diff --git a/tests/proc/test_proc.py b/tests/proc/test_proc.py index a620ea3..6335e47 100644 --- a/tests/proc/test_proc.py +++ b/tests/proc/test_proc.py @@ -25,7 +25,7 @@ def test_exproot(): def test_progressroot(): PROGRESS_ROOT = glob.get('progress_root', None) - expected = PROGRESS_ROOT or os.path.join(libhome(), 'progress') + expected = PROGRESS_ROOT or os.path.join(cache_dir(), 'progress') assert progressroot() == expected From 435f62ac399d58d8c6da068782c11d8ed1f29d6f Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 17:01:27 +0800 Subject: [PATCH 69/99] Add cli --- setup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.py b/setup.py index f8f6165..a0d9f4c 100644 --- a/setup.py +++ b/setup.py @@ -30,5 +30,8 @@ def extract_version(): keywords='lumo', packages=find_packages('src'), entry_points={ + 'console_scripts': [ + 'lumo = lumo.cli:main', + ] }, ) From 959480819743547dac61390102db53233a1bb982 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 17:52:28 +0800 Subject: [PATCH 70/99] Add Trainer params to Experiment default properties --- src/lumo/trainer/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lumo/trainer/trainer.py b/src/lumo/trainer/trainer.py index 58d63c7..3d85955 100644 --- a/src/lumo/trainer/trainer.py +++ b/src/lumo/trainer/trainer.py @@ -101,6 +101,7 @@ def __init__(self, params: ParamsType, dm: DataModule = None): if dist.is_main(): self.params.to_yaml(self.exp.params_fn) + self.exp.dump_info('params', self.params.to_dict()) self.set_global_steps(0) self.set_epoch_idx(0) From 8e72a74b8a0caf16b0516d5654318637e4fbf064 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 17:53:08 +0800 Subject: [PATCH 71/99] Add db_root for storage --- src/lumo/proc/config.py | 4 +++- src/lumo/proc/path.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/lumo/proc/config.py b/src/lumo/proc/config.py index c6633a5..925c4a3 100644 --- a/src/lumo/proc/config.py +++ b/src/lumo/proc/config.py @@ -119,11 +119,13 @@ def debug_mode(base_dir=None, disable_git=True): None """ glob['exp_root'] = tempfile.mkdtemp(dir=base_dir) + glob['db_root'] = tempfile.mkdtemp(dir=base_dir) glob['progress_root'] = tempfile.mkdtemp(dir=base_dir) + glob['metric_root'] = tempfile.mkdtemp(dir=base_dir) + glob['home'] = tempfile.mkdtemp(dir=base_dir) glob['cache_dir'] = tempfile.mkdtemp(dir=base_dir) glob['blob_root'] = tempfile.mkdtemp(dir=base_dir) - glob['metric_root'] = tempfile.mkdtemp(dir=base_dir) # glob['HOOK_LOCKFILE'] = False glob['HOOK_LASTCMD_DIR'] = tempfile.mkdtemp(dir=base_dir) # glob['HOOK_RECORDABORT'] = False diff --git a/src/lumo/proc/path.py b/src/lumo/proc/path.py index 2a7bf18..8cff01b 100644 --- a/src/lumo/proc/path.py +++ b/src/lumo/proc/path.py @@ -111,6 +111,16 @@ def metricroot(): return res +def dbroot(): + DB_ROOT = glob.get('db_root', None) + if DB_ROOT: + res = DB_ROOT + else: + res = os.path.join(libhome(), 'database') + os.makedirs(res, exist_ok=True) + return res + + def local_dir(): """ Project root, default is the parent directory of .git. From 019ab1ba08ea9c2f3283fd5c4419a8315ea6e751 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 17:53:31 +0800 Subject: [PATCH 72/99] Add full and progress methods --- src/lumo/exp/experiment.py | 9 ++- src/lumo/exp/watch.py | 127 +++++++++++++++++++++++++++---------- 2 files changed, 102 insertions(+), 34 deletions(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 00c1d1d..549961d 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -471,7 +471,7 @@ def _trigger_change(self, func): @wraps(func) def inner(*args, **kwargs): fn = self.progress_file(f'{self.test_name}.heartbeat', 'hb', self.exp_name) - io.dump_text(strftime(), fn) + io.dump_text(self.test_root, fn) func(*args, **kwargs) return inner @@ -894,6 +894,13 @@ def from_disk(cls, path): self._metric = Metric(self.metrics_fn) return self + def dict(self): + return { + **self.properties, + 'is_alive': self.is_alive, + 'metrics': self.metric.value, + } + class SimpleExperiment(Experiment): """ diff --git a/src/lumo/exp/watch.py b/src/lumo/exp/watch.py index e07e2dd..916ef81 100644 --- a/src/lumo/exp/watch.py +++ b/src/lumo/exp/watch.py @@ -19,43 +19,104 @@ """ import os.path +from typing import List, Dict -from lumo.proc.path import progressroot +import pandas as pd +from dbrecord import PDict + +from lumo.proc.path import progressroot, exproot, dbroot +from .experiment import Experiment +from .finder import is_test_name +from lumo.utils import safe_io as IO +from lumo.analyse.collect import collect_table_rows PID_ROOT = os.path.join(progressroot(), 'pid') HB_ROOT = os.path.join(progressroot(), 'hb') EXP_ROOT = os.path.join(progressroot()) -# class Watcher: -# """List and watch experiments with time order -# -# Cache test_information in -# metrics/.sqlite -# """ -# -# def load(self): -# pass -# -# def interactive(self): -# """interactive, mark, label, note in ipython environment.""" -# pass -# -# def server(self): -# """simple server which make you note your experiments""" -# pass -# -# def list_all(self, exp_root=None, limit=100) -> Dict[str, List[Experiment]]: -# """ -# Returns a dictionary of all experiments under exp_root directory. -# -# Args: -# exp_root: The root directory to search for experiments. Default is None, which uses the default experiment root directory. -# -# Returns: -# A dictionary of all experiments, where the keys are the names of the experiments and the values are lists of corresponding Experiment objects. -# """ -# return { -# _get_exp_name(exp_path): retrieval_tests_from_experiment(exp_path) -# for exp_path in list_experiment_paths(exp_root) -# } +class Watcher: + """List and watch experiments with time order + + Cache test_information in + metrics/.sqlite + """ + + def __init__(self, exp_root=None, hb_root=None, pid_root=None, db_root=None): + if exp_root is None: + exp_root = os.path.join(exproot(), 'hb') + + if hb_root is None: + hb_root = os.path.join(progressroot(), 'hb') + + if pid_root is None: + pid_root = os.path.join(progressroot(), 'pid') + + if db_root is None: + db_root = dbroot() + self.db_root = db_root + self.exp_root = exp_root + self.hb_root = hb_root + self.pid_root = pid_root + + def load(self): + res = {} + updates = {} + for root, dirs, fs in os.walk(self.hb_root): + if root == self.hb_root: + continue + for f in fs: + if f.endswith('heartbeat'): + hb_file = os.path.join(root, f) + test_root = IO.load_text(hb_file) + try: + exp = Experiment.from_disk(test_root) + updates.setdefault(exp.exp_name, []).append(exp.dict()) + except KeyboardInterrupt as e: + raise e + except: + continue + + for exp_name, tests in updates.items(): + dic = PDict(os.path.join(self.db_root, f'{exp_name}.sqlite')) + for test_name, test_prop in dic.items(): + res[test_name] = test_prop + + for test in tests: + dic[test['test_name']] = test + res[test['test_name']] = test + dic.flush() + + df = pd.DataFrame(res.values()) + return df + + def progress(self): + """return the alive process""" + res = [] + for pid in os.listdir(self.pid_root): + try: + test_root = IO.load_text(pid) + exp = Experiment.from_disk(test_root) + res.append(exp.dict()) + except: + continue + return pd.DataFrame(res) + + def interactive(self): + """interactive, mark, label, note in ipython environment.""" + pass + + def server(self): + """simple server which make you note your experiments""" + pass + + def list_all(self, exp_root=None, limit=100) -> Dict[str, List[Experiment]]: + """ + Returns a dictionary of all experiments under exp_root directory. + + Args: + exp_root: The root directory to search for experiments. Default is None, which uses the default experiment root directory. + + Returns: + A dictionary of all experiments, where the keys are the names of the experiments and the values are lists of corresponding Experiment objects. + """ From d99709a5cdcfb5f0b01f4833619dff4c1a8ab234 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 21:59:10 +0800 Subject: [PATCH 73/99] Fix read path --- src/lumo/exp/watch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lumo/exp/watch.py b/src/lumo/exp/watch.py index 916ef81..a7ef62c 100644 --- a/src/lumo/exp/watch.py +++ b/src/lumo/exp/watch.py @@ -95,7 +95,7 @@ def progress(self): res = [] for pid in os.listdir(self.pid_root): try: - test_root = IO.load_text(pid) + test_root = IO.load_text(os.path.join(self.pid_root, pid)) exp = Experiment.from_disk(test_root) res.append(exp.dict()) except: From 513ed19381a7086f5c991050c68d1d852fa72d3b Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 21:59:23 +0800 Subject: [PATCH 74/99] file hint --- src/lumo/utils/safe_io.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/lumo/utils/safe_io.py b/src/lumo/utils/safe_io.py index f589b46..3bfa8b4 100644 --- a/src/lumo/utils/safe_io.py +++ b/src/lumo/utils/safe_io.py @@ -54,8 +54,12 @@ def dump_state_dict(obj, fn): def load_json(fn): """Loads JSON data from the given file path and returns the resulting object.""" - with open(fn, 'r', encoding='utf-8') as r: - return json.load(r) + try: + with open(fn, 'r', encoding='utf-8') as r: + return json.load(r) + except json.JSONDecodeError as e: + e.msg = f'Error in file {fn}: {e.msg}' + raise e def load_yaml(fn): @@ -78,6 +82,7 @@ def load_text(fn): with open(fn, 'r', encoding='utf-8') as r: return ''.join(r.readlines()) + def dump_text(string: str, fn, append=False): """Write the given string to a file. @@ -166,6 +171,7 @@ def load_pkl(file, *, fix_imports=True, encoding="ASCII", errors="strict"): else: raise NotImplementedError("File type not supported.") + @contextmanager def cached(fn): """ From 3213d2cfbb68dc60d896ce689198fc74b31692f4 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:03:25 +0800 Subject: [PATCH 75/99] Add test_name --- src/lumo/exp/experiment.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 549961d..686c59b 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -896,6 +896,11 @@ def from_disk(cls, path): def dict(self): return { + 'path': { + 'test_root': self.test_root, + 'exp_root': self.exp_root, + 'blob_root': self.blob_root, + }, **self.properties, 'is_alive': self.is_alive, 'metrics': self.metric.value, From 83757ee3d679d1adbac99580c794ccfdb7bcb09d Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:08:20 +0800 Subject: [PATCH 76/99] Add additional exception info from json --- src/lumo/utils/safe_io.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lumo/utils/safe_io.py b/src/lumo/utils/safe_io.py index 3bfa8b4..80435a5 100644 --- a/src/lumo/utils/safe_io.py +++ b/src/lumo/utils/safe_io.py @@ -58,8 +58,7 @@ def load_json(fn): with open(fn, 'r', encoding='utf-8') as r: return json.load(r) except json.JSONDecodeError as e: - e.msg = f'Error in file {fn}: {e.msg}' - raise e + raise ValueError(f'Error in file {fn}') from e def load_yaml(fn): From b3e7451d65fe739603658e74465b806af1f6e350 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:13:15 +0800 Subject: [PATCH 77/99] Add additional exception info from json --- src/lumo/utils/safe_io.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/lumo/utils/safe_io.py b/src/lumo/utils/safe_io.py index 80435a5..b87d588 100644 --- a/src/lumo/utils/safe_io.py +++ b/src/lumo/utils/safe_io.py @@ -53,7 +53,17 @@ def dump_state_dict(obj, fn): def load_json(fn): - """Loads JSON data from the given file path and returns the resulting object.""" + """ + Loads JSON data from the given file path and returns the resulting object. + + Args: + fn: file name + + Returns: + + Raises: + ValueError + """ try: with open(fn, 'r', encoding='utf-8') as r: return json.load(r) From 49693c2d05a57dd78228cde4fe048e45916e0a05 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:13:36 +0800 Subject: [PATCH 78/99] Skip damaged file --- src/lumo/exp/experiment.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 686c59b..98da5e4 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -577,7 +577,10 @@ def load_info(self, key: str, info_dir='info'): fn = self.test_file(f'{key}.json', info_dir) if not os.path.exists(fn): return {} - return io.load_json(fn) + try: + return io.load_json(fn) + except ValueError as e: + return {} def dump_note(self, note: str): fn = self.test_file('note.md') From c0fb37bab38d0e8084b88e7ab566dcbd92a91069 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:31:35 +0800 Subject: [PATCH 79/99] Add env --- src/lumo/utils/subprocess.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lumo/utils/subprocess.py b/src/lumo/utils/subprocess.py index 39da1f3..8f2409b 100644 --- a/src/lumo/utils/subprocess.py +++ b/src/lumo/utils/subprocess.py @@ -4,9 +4,10 @@ import signal -def run_command(command, cwd=None): +def run_command(command, cwd=None, env=None): proc = subprocess.Popen(command, cwd=cwd, + env=env, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) try: while proc.poll() is None: From bd3ce2887b38993cc7dc6b7ba46dfb9d8fdb5ec6 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:31:56 +0800 Subject: [PATCH 80/99] Add load note --- src/lumo/exp/experiment.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 98da5e4..3227ad2 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -582,8 +582,15 @@ def load_info(self, key: str, info_dir='info'): except ValueError as e: return {} + def load_note(self): + fn = self.test_file('note.md') + if os.path.exists(fn): + return io.load_text(fn) + return '' + def dump_note(self, note: str): fn = self.test_file('note.md') + self.set_prop('note', note) io.dump_text(note, fn) def dump_string(self, key: str, info: str, append=False): @@ -761,7 +768,8 @@ def rerun(self, arg_list: List[str]): command = ' '.join([old_exec['exec_bin'], old_exec['exec_file'], *old_exec['exec_argv'], *arg_list]) env = os.environ.copy() env[Experiment.ENV_TEST_NAME_KEY] = new_exp.test_name - return run_command(command, cwd=old_exec['cwd']) + + return run_command(command, cwd=old_exec['cwd'], env=env) @call_on_main_process_wrap def initial(self): @@ -893,6 +901,8 @@ def from_disk(cls, path): key = os.path.splitext(f)[0] self.set_prop(key, self.load_string(key)) + self.set_prop('note', self.load_note()) + # load metric self._metric = Metric(self.metrics_fn) return self From 87d387027bf8d3cd02214dec6856db3053b376a2 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:34:43 +0800 Subject: [PATCH 81/99] Add load note --- src/lumo/exp/experiment.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 3227ad2..4de3684 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -130,6 +130,8 @@ def __init__(self, exp_name: str, root=None, test_name=None): self.add_exit_hook(self.end) self.logger = Logger() + print(os.environ) + def __getitem__(self, item): """ Gets a property of the experiment. From 8cf3e39e913e26ac0f9b62ff32c57ff2f4880ce0 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:38:12 +0800 Subject: [PATCH 82/99] Add load note --- src/lumo/exp/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 4de3684..e84c057 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -770,7 +770,7 @@ def rerun(self, arg_list: List[str]): command = ' '.join([old_exec['exec_bin'], old_exec['exec_file'], *old_exec['exec_argv'], *arg_list]) env = os.environ.copy() env[Experiment.ENV_TEST_NAME_KEY] = new_exp.test_name - + print(env) return run_command(command, cwd=old_exec['cwd'], env=env) @call_on_main_process_wrap From ce5fde1ab9bc3ef4f899e133da74f60c5eda3cf5 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:40:59 +0800 Subject: [PATCH 83/99] Update test_name assignment manner --- src/lumo/exp/experiment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index e84c057..2db6576 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -112,6 +112,8 @@ def __init__(self, exp_name: str, root=None, test_name=None): self._prop = {} self._prop['exp_name'] = exp_name + if test_name is None: + test_name = os.environ.get(Experiment.ENV_TEST_NAME_KEY, None) self._prop['test_name'] = test_name self._hooks = {} @@ -261,8 +263,7 @@ def _test_name(self): """ str: Gets the name of the current test being run. """ - given_test_name = os.environ.get(Experiment.ENV_TEST_NAME_KEY, None) - return self._prop.get('test_name', given_test_name) + return self._prop.get('test_name') @_test_name.setter def _test_name(self, value): @@ -770,7 +771,6 @@ def rerun(self, arg_list: List[str]): command = ' '.join([old_exec['exec_bin'], old_exec['exec_file'], *old_exec['exec_argv'], *arg_list]) env = os.environ.copy() env[Experiment.ENV_TEST_NAME_KEY] = new_exp.test_name - print(env) return run_command(command, cwd=old_exec['cwd'], env=env) @call_on_main_process_wrap From 711ea3153afdcbc7baf908ca3b0706ae1870d90b Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:44:46 +0800 Subject: [PATCH 84/99] Avoid create empty test --- src/lumo/exp/experiment.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 2db6576..466952d 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -243,7 +243,7 @@ def test_name(self): if is_dist(): # if train in distribute mode, subprocess will wait a few seconds to wait main process. flag_fn = f'.{os.getppid()}' if is_main(): - self._test_name = self._create_test_name() + self._test_name = self._create_test_name(self.exp_root) fn = self.exp_file(flag_fn) with open(fn, 'w') as w: w.write(self._test_name) @@ -254,7 +254,7 @@ def test_name(self): with open(fn, 'r') as r: self._test_name = r.readline().strip() else: - self._test_name = self._create_test_name() + self._test_name = self._create_test_name(self.exp_root) return self._test_name @@ -479,7 +479,8 @@ def inner(*args, **kwargs): return inner - def _create_test_name(self): + @classmethod + def _create_test_name(cls, exp_root): """ Generates a unique test name based on the current date and time. regex pattern: [0-9]{6}.[0-9]{3}.[a-z0-9]{3}t @@ -489,7 +490,7 @@ def _create_test_name(self): """ from lumo.proc.date import timehash from ..utils.fmt import strftime - fs = os.listdir(self.exp_root) + fs = os.listdir(exp_root) date_str = strftime('%y%m%d') fs = [i for i in fs if i.startswith(date_str)] _test_name = f"{date_str}.{len(fs):03d}.{timehash()[-6:-4]}t" @@ -758,7 +759,7 @@ def Class(cls): def rerun(self, arg_list: List[str]): """rerun this test in another""" # self.properties[''] - new_test_name = self._create_test_name() + new_test_name = self._create_test_name(self.exp_root) new_exp = Experiment(self.exp_name, root=self._root, test_name=new_test_name) self.dump_info('deprecated', {'rerun_at': new_exp.test_name}) old_rerun_info = self.properties.get('rerun', None) @@ -891,8 +892,7 @@ def from_disk(cls, path): test_root = Path(path) root = test_root.parent.parent.parent.as_posix() - self = cls(test_root.parent.name, root=root) - self._test_name = test_root.name + self = cls(test_root.parent.name, root=root, test_name=test_root.name) # load prop for f in os.listdir(self.test_dir('info')): From 6bffd276ef8e1ab702426c308d180172a9d751d5 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:45:36 +0800 Subject: [PATCH 85/99] Remove debug lines --- src/lumo/exp/experiment.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 466952d..f565ce2 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -132,8 +132,6 @@ def __init__(self, exp_name: str, root=None, test_name=None): self.add_exit_hook(self.end) self.logger = Logger() - print(os.environ) - def __getitem__(self, item): """ Gets a property of the experiment. From 3b0000c7bd6bdfa530f5c2de7345cae672783d80 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:48:17 +0800 Subject: [PATCH 86/99] Update manner --- src/lumo/exp/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index f565ce2..7715cbf 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -759,7 +759,7 @@ def rerun(self, arg_list: List[str]): # self.properties[''] new_test_name = self._create_test_name(self.exp_root) new_exp = Experiment(self.exp_name, root=self._root, test_name=new_test_name) - self.dump_info('deprecated', {'rerun_at': new_exp.test_name}) + self.dump_info('deprecated', {'rerun_at': {new_exp.test_name: True}}, append=True) old_rerun_info = self.properties.get('rerun', None) count = 1 if old_rerun_info is not None: From 46776b3751496f1c68870f1ab5164d673df43312 Mon Sep 17 00:00:00 2001 From: sailist Date: Fri, 10 Mar 2023 22:48:27 +0800 Subject: [PATCH 87/99] No strip --- src/lumo/utils/subprocess.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lumo/utils/subprocess.py b/src/lumo/utils/subprocess.py index 8f2409b..1b5e26d 100644 --- a/src/lumo/utils/subprocess.py +++ b/src/lumo/utils/subprocess.py @@ -39,14 +39,14 @@ def run_command(command, cwd=None, env=None): # Wait for output from the process rlist, _, _ = select.select([proc.stdout, proc.stderr], [], [], 0.1) for stream in rlist: - line = stream.readline().decode('utf-8').strip() + line = stream.readline().decode('utf-8') if line: print(line) # Read the remaining output for stream in [proc.stdout, proc.stderr]: while True: - line = stream.readline().decode('utf-8').strip() + line = stream.readline().decode('utf-8') if not line: break print(line) From 03ddaf8fd8a94c8a67fb252e05dd059aee9953bf Mon Sep 17 00:00:00 2001 From: sailist Date: Sat, 11 Mar 2023 00:05:00 +0800 Subject: [PATCH 88/99] Disable Trainer instance --- src/lumo/trainer/trainer.py | 6 ++++++ tests/exp/test_watcher.py | 27 +++++++++++++++++++++++++++ tests/trainer/test_finder.py | 5 +++-- tests/trainer/test_trainer.py | 6 +++++- 4 files changed, 41 insertions(+), 3 deletions(-) create mode 100644 tests/exp/test_watcher.py diff --git a/src/lumo/trainer/trainer.py b/src/lumo/trainer/trainer.py index 3d85955..7ac6aed 100644 --- a/src/lumo/trainer/trainer.py +++ b/src/lumo/trainer/trainer.py @@ -64,6 +64,12 @@ class Trainer(_BaseTrainer): 'process_loader', 'regist_dataloader' } + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if cls.__name__ == 'Trainer': + raise TypeError( + f"Can't instantiate abstract class {cls.__name__} directly, please create a subclass of it.") + def __init__(self, params: ParamsType, dm: DataModule = None): if dm is None: dm = DataModule(params) diff --git a/tests/exp/test_watcher.py b/tests/exp/test_watcher.py new file mode 100644 index 0000000..e33474e --- /dev/null +++ b/tests/exp/test_watcher.py @@ -0,0 +1,27 @@ +from lumo import Trainer, TrainerParams +from lumo.exp.watch import Watcher +from lumo.proc.config import debug_mode + + +class MyTrainer(Trainer): + pass + + +def trainer(): + params = TrainerParams() + t = MyTrainer(params) + return t + + +def test_exp(): + debug_mode() + for i in range(10): + t = trainer() + t.train() + print(t.exp.test_name) + + w = Watcher() + df = w.load() + print(df.columns) + # print(sorted(list(df['test_name']))) + assert len(df) == 10 diff --git a/tests/trainer/test_finder.py b/tests/trainer/test_finder.py index 2588bae..82781a3 100644 --- a/tests/trainer/test_finder.py +++ b/tests/trainer/test_finder.py @@ -26,9 +26,10 @@ def test_finder(): params.rnd = random.random() ATrainer(params).train() BTrainer(params).train() - all_tests = finder.list_all() - assert len(all_tests) == 2 + # print([ATrainer.generate_exp_name(), BTrainer.generate_exp_name()]) + print(ATrainer.__exp_name__) + assert len(all_tests) == len({ATrainer.generate_exp_name(), BTrainer.generate_exp_name()}) assert ATrainer.generate_exp_name() in all_tests assert BTrainer.generate_exp_name() in all_tests assert len(all_tests[ATrainer.generate_exp_name()]) == 5 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index fec64ca..c7df8fe 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -177,8 +177,12 @@ def test_trainer_params(): optim = params.optim.build(module.parameters()) +class MyTrainer(Trainer): + pass + + def test_trainer_state_dict(): - trainer = Trainer(TrainerParams()) + trainer = MyTrainer(TrainerParams()) device_a = trainer.device_a = torch.device('cpu') ndarray_a = trainer.ndarray_a = np.array([1, 2, 3]) tensor_a = trainer.tensor_a = torch.tensor([1, 2, 3]) From 7c6073e0b70792aadcc0299c50e85f8b792a0b7d Mon Sep 17 00:00:00 2001 From: sailist Date: Sat, 11 Mar 2023 00:42:27 +0800 Subject: [PATCH 89/99] print raw string without breakline --- src/lumo/utils/subprocess.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lumo/utils/subprocess.py b/src/lumo/utils/subprocess.py index 1b5e26d..5ed5e9f 100644 --- a/src/lumo/utils/subprocess.py +++ b/src/lumo/utils/subprocess.py @@ -41,7 +41,7 @@ def run_command(command, cwd=None, env=None): for stream in rlist: line = stream.readline().decode('utf-8') if line: - print(line) + print(line, end='') # Read the remaining output for stream in [proc.stdout, proc.stderr]: @@ -49,7 +49,7 @@ def run_command(command, cwd=None, env=None): line = stream.readline().decode('utf-8') if not line: break - print(line) + print(line, end='') # Get the return code of the process return_code = proc.wait() From 6490469166a2cb7b236332f80b0cf4f0b4ef9d9c Mon Sep 17 00:00:00 2001 From: sailist Date: Sat, 11 Mar 2023 00:54:00 +0800 Subject: [PATCH 90/99] Logger behaviour update --- src/lumo/trainer/callbacks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lumo/trainer/callbacks.py b/src/lumo/trainer/callbacks.py index 1139347..36c7096 100644 --- a/src/lumo/trainer/callbacks.py +++ b/src/lumo/trainer/callbacks.py @@ -402,8 +402,8 @@ def update(self, trainer: Trainer): if self.c % self.breakin == 0 or ( TrainStage.train in self.stage and ((trainer.idx + 1) == self.stage[TrainStage.train])): - trainer.logger.info(self.cur_tqdm.full_str()) - # trainer.logger.newline() + trainer.logger.inline(self.cur_tqdm.full_str()) + trainer.logger.newline() def flush(self, trainer: Trainer): """Flush""" From 308f74720616f55fd2843ef36cfbd6ec12429868 Mon Sep 17 00:00:00 2001 From: sailist Date: Sat, 11 Mar 2023 01:27:26 +0800 Subject: [PATCH 91/99] Experiment tags --- src/lumo/exp/experiment.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 7715cbf..61ce79d 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -416,7 +416,7 @@ def metric(self): return self._metric @property - def note(self): + def note_fn(self): fn = self.test_file('note.md') if os.path.exists(fn): return io.load_text(fn) @@ -590,6 +590,9 @@ def load_note(self): return io.load_text(fn) return '' + def dump_tags(self, *tags): + self.dump_info('tags', tags) + def dump_note(self, note: str): fn = self.test_file('note.md') self.set_prop('note', note) From 08aeb3d6db5fcdada475df4bf85567d5a864716a Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 12 Mar 2023 20:56:37 +0800 Subject: [PATCH 92/99] Deprecated finder --- tests/trainer/test_finder.py | 42 ------------------------------------ 1 file changed, 42 deletions(-) delete mode 100644 tests/trainer/test_finder.py diff --git a/tests/trainer/test_finder.py b/tests/trainer/test_finder.py deleted file mode 100644 index 82781a3..0000000 --- a/tests/trainer/test_finder.py +++ /dev/null @@ -1,42 +0,0 @@ -import random - -from lumo import Trainer, ParamsType, TrainerParams, Experiment -from lumo.exp import finder -from lumo.proc.config import debug_mode - - -class ATrainer(Trainer): - - def icallbacks(self, params: ParamsType): - super().icallbacks(params) - - -class BTrainer(Trainer): - - def icallbacks(self, params: ParamsType): - super().icallbacks(params) - - -def test_finder(): - debug_mode() - - for i in range(5): - params = TrainerParams() - params.epoch = i - params.rnd = random.random() - ATrainer(params).train() - BTrainer(params).train() - all_tests = finder.list_all() - # print([ATrainer.generate_exp_name(), BTrainer.generate_exp_name()]) - print(ATrainer.__exp_name__) - assert len(all_tests) == len({ATrainer.generate_exp_name(), BTrainer.generate_exp_name()}) - assert ATrainer.generate_exp_name() in all_tests - assert BTrainer.generate_exp_name() in all_tests - assert len(all_tests[ATrainer.generate_exp_name()]) == 5 - assert len(all_tests[BTrainer.generate_exp_name()]) == 5 - - assert isinstance(all_tests[ATrainer.generate_exp_name()][0], Experiment) - for exp in all_tests[ATrainer.generate_exp_name()]: - params = TrainerParams().from_yaml(exp.properties['params.yaml']) - assert params.hash() == exp.properties['params_hash'] - assert finder.find_path_from_test_name(exp.test_name) == exp.test_root From 7358bf7590d5714a599a5d08ab22e0e9c52b57d0 Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 12 Mar 2023 20:58:17 +0800 Subject: [PATCH 93/99] Reconstruct Experiment --- src/lumo/exp/experiment.py | 435 +++++++++++-------------------- src/lumo/exp/exphook.py | 17 +- src/lumo/exp/watch.py | 450 +++++++++++++++++++++++++++++++-- src/lumo/trainer/components.py | 8 +- src/lumo/trainer/trainer.py | 85 ++----- src/lumo/utils/fmt.py | 16 +- src/lumo/utils/repository.py | 2 +- tests/exp/test_watcher.py | 2 +- tests/trainer/test_trainer.py | 6 +- 9 files changed, 630 insertions(+), 391 deletions(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 61ce79d..2a904bf 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -9,13 +9,12 @@ import sys import time import traceback -from pathlib import Path -from typing import Union, Any, List +from typing import Any, List from functools import wraps from lumo.decorators.process import call_on_main_process_wrap from lumo.proc import glob from lumo.proc.dist import is_dist, is_main, local_rank -from lumo.proc.path import blobroot, libhome, progressroot +from lumo.proc.path import blobroot, cache_dir, libhome from lumo.proc.path import exproot, local_dir from lumo.utils import safe_io as io from lumo.utils.fmt import can_be_filename, strftime @@ -25,27 +24,26 @@ from .metric import Metric -def checkdir(path: Union[Path, str]): +class Experiment: """ - Create a directory at the specified path if it does not already exist. + Represents an experiment and manages its directory structure. An experiment consists of multiple tests, each of which + has its own directory to store information related to that test. - Args: - path (Union[Path, str]): The path to the directory to be created. - Returns: - Path: The Path object representing the created directory. - """ - if isinstance(path, str): - os.makedirs(path, exist_ok=True) - elif isinstance(path, Path): - path.mkdir(parents=True, exist_ok=True) - return path + - + - progress + - + - {test-1}.hb + - {test-1}.pid + - + - + - (info_dir) + + - + - + - (blob_dir) -class Experiment: - """ - Represents an experiment and manages its directory structure. An experiment consists of multiple tests, each of which - has its own directory to store information related to that test. (By default), the directory structure is as following: .lumo (libroot) @@ -93,15 +91,13 @@ class Experiment: ENV_TEST_NAME_KEY = 'LUMO_EXP_TEST_NAME' - def __init__(self, exp_name: str, root=None, test_name=None): + def __init__(self, exp_name: str, test_name=None, paths=None): """ Initializes a new instance of the Experiment class. Args: exp_name (str): The name of the experiment. This should be a legal filename and contain only letters or underscores. - root (str, optional): The root directory where the experiment's directories will be created. Defaults to - None, in which case the root directory is set to the library's home directory. Raises: ValueError: If the experiment name is not a legal filename. @@ -115,20 +111,18 @@ def __init__(self, exp_name: str, root=None, test_name=None): if test_name is None: test_name = os.environ.get(Experiment.ENV_TEST_NAME_KEY, None) self._prop['test_name'] = test_name - self._hooks = {} + if paths is None: + paths = {} + self._prop['paths'] = paths - self._metric = Metric(self.metrics_fn) + self._hooks = {} + self._metric = None # wrap - self._metric.dump_metrics = self._trigger_change(self._metric.dump_metrics) - self._metric.dump_metric = self._trigger_change(self._metric.dump_metric) self.dump_string = self._trigger_change(self.dump_string) self.dump_note = self._trigger_change(self.dump_note) self.dump_info = self._trigger_change(self.dump_info) - if root is None: - root = libhome() - self._root = Path(os.path.abspath(root)) self.add_exit_hook(self.end) self.logger = Logger() @@ -239,20 +233,21 @@ def test_name(self): """ if self._test_name is None: if is_dist(): # if train in distribute mode, subprocess will wait a few seconds to wait main process. - flag_fn = f'.{os.getppid()}' + flag_fn = os.path.join(self.cache_root, 'dist', f'.{os.getppid()}') + os.makedirs(os.path.dirname(flag_fn), exist_ok=True) + if is_main(): - self._test_name = self._create_test_name(self.exp_root) - fn = self.exp_file(flag_fn) - with open(fn, 'w') as w: + self._test_name = self._create_test_name(self.exp_dir) + + with open(flag_fn, 'w') as w: w.write(self._test_name) else: time.sleep(random.randint(2, 4)) - fn = self.exp_file(flag_fn) - if os.path.exists(fn): - with open(fn, 'r') as r: + if os.path.exists(flag_fn): + with open(flag_fn, 'r') as r: self._test_name = r.readline().strip() else: - self._test_name = self._create_test_name(self.exp_root) + self._test_name = self._create_test_name(self.exp_dir) return self._test_name @@ -273,64 +268,6 @@ def _test_name(self, value): """ self._prop['test_name'] = value - @property - def root_branch(self): - """ - Path: Gets the root branch directory of the experiment. - """ - val = self._root - return checkdir(val) - - @property - def lib_root(self): - """ - str: Gets the path of the library's root directory. - """ - return self.root_branch.as_posix() - - @property - def exp_branch(self): - """ - Path: Gets the experiment branch directory. - """ - val = Path(exproot()).joinpath(self.exp_name) - return checkdir(val) - - @property - def blob_branch(self): - """ - Path: Gets the blob branch directory, which is used to store big binary files like model state dicts. - """ - val = Path(blobroot()).joinpath(self.exp_name, self.test_name) - return checkdir(val) - - @property - def progress_branch(self): - """ - Path: Gets the progress branch directory, which is used to store progress information about running processes. - """ - val = Path(progressroot()) - return checkdir(val) - - @property - def test_branch(self): - """ - Path: Gets the test branch directory, which is used to store information related to the current test being run. - """ - val = self.exp_branch.joinpath(self.test_name) - return checkdir(val) - - @property - def tags(self): - """ - dict: Gets the tags associated with the experiment. - """ - tags = {} - for path in self.test_branch.joinpath('tags').glob('tag.*.json'): - ptags = io.load_json(path.as_posix()) # type: dict - tags.setdefault(path.suffixes[0].strip('.'), []).extend(ptags.keys()) - return tags - @property def repo_name(self): """ @@ -361,36 +298,6 @@ def project_root(self): """ return local_dir() - @property - def exp_root(self): - """ - Gets the path to the directory containing the tests for the experiment. - - Returns: - str: The path to the experiment root directory. - """ - return self.exp_branch.as_posix() - - @property - def test_root(self): - """ - Gets the path to the directory containing information about the current test. - - Returns: - str: The path to the test root directory. - """ - return self.test_branch.as_posix() - - @property - def blob_root(self): - """ - Gets the path to the directory containing large binary files associated with the experiment. - - Returns: - str: The path to the blob root directory. - """ - return self.blob_branch.as_posix() - @property def properties(self): """ @@ -401,10 +308,6 @@ def properties(self): """ return self._prop - @property - def metrics_fn(self): - return self.test_file('metric.pkl') - @property def metric(self): """ @@ -413,15 +316,12 @@ def metric(self): Returns: Metric: A dictionary containing all metrics of the experiment. """ + if self._metric is None: + self._metric = Metric(self.mk_ipath('metric.pkl')) + self._metric.dump_metrics = self._trigger_change(self._metric.dump_metrics) + self._metric.dump_metric = self._trigger_change(self._metric.dump_metric) return self._metric - @property - def note_fn(self): - fn = self.test_file('note.md') - if os.path.exists(fn): - return io.load_text(fn) - return fn - @property def paths(self) -> dict: """ @@ -431,10 +331,9 @@ def paths(self) -> dict: dict: A dictionary containing the paths to various directories associated with the experiment. """ return { - 'root': self.root_branch.as_posix(), - 'exp_root': self.exp_root, - 'test_root': self.test_root, - 'blob_root': self.blob_root, + 'info_root': self._prop['paths'].get('info_root', exproot()), + 'cache_root': self._prop['paths'].get('cache_root', cache_dir()), + 'blob_root': self._prop['paths'].get('blob_root', blobroot()), } @property @@ -471,14 +370,14 @@ def _trigger_change(self, func): # test_root update some files @wraps(func) def inner(*args, **kwargs): - fn = self.progress_file(f'{self.test_name}.heartbeat', 'hb', self.exp_name) - io.dump_text(self.test_root, fn) + fn = self.heartbeat_fn + io.dump_text(self.info_dir, fn) func(*args, **kwargs) return inner @classmethod - def _create_test_name(cls, exp_root): + def _create_test_name(cls, exp_dir): """ Generates a unique test name based on the current date and time. regex pattern: [0-9]{6}.[0-9]{3}.[a-z0-9]{3}t @@ -488,7 +387,7 @@ def _create_test_name(cls, exp_root): """ from lumo.proc.date import timehash from ..utils.fmt import strftime - fs = os.listdir(exp_root) + fs = os.listdir(exp_dir) date_str = strftime('%y%m%d') fs = [i for i in fs if i.startswith(date_str)] _test_name = f"{date_str}.{len(fs):03d}.{timehash()[-6:-4]}t" @@ -543,7 +442,7 @@ def dump_progress(self, ratio: float, update_from=None): res['last_edit_time'] = strftime() self.dump_info('progress', res, append=True) - def dump_info(self, key: str, info: Any, append=False, info_dir='info', set_prop=True): + def dump_info(self, key: str, info: Any, append=False): """ Saves information about the experiment to a file. @@ -551,32 +450,27 @@ def dump_info(self, key: str, info: Any, append=False, info_dir='info', set_prop key (str): The key under which the information will be stored. info (Any): The information to store. append (bool, optional): Whether to append to the file or overwrite it. Defaults to False. - info_dir (str, optional): The name of the directory where the file will be stored. Defaults to 'info'. - set_prop (bool, optional): Whether to set the experiment property with the same key to the saved information. - Defaults to True. """ - fn = self.test_file(f'{key}.json', info_dir) + fn = self.mk_ipath('info', f'{key}.json') if append: - old_info = self.load_info(key, info_dir=info_dir) + old_info = self.load_info(key) old_info.update(info) info = old_info - if set_prop: - self[key] = info - # self.set_prop(key, info) + + self.set_prop(key, info) io.dump_json(info, fn) - def load_info(self, key: str, info_dir='info'): + def load_info(self, key: str): """ Loads information about the experiment from a file. Args: key (str): The key under which the information is stored. - info_dir (str, optional): The name of the directory where the file is stored. Defaults to 'info'. Returns: Any: The information stored under the specified key. """ - fn = self.test_file(f'{key}.json', info_dir) + fn = self.mk_ipath('info', f'{key}.json') if not os.path.exists(fn): return {} try: @@ -585,7 +479,7 @@ def load_info(self, key: str, info_dir='info'): return {} def load_note(self): - fn = self.test_file('note.md') + fn = self.mk_ipath('note.md') if os.path.exists(fn): return io.load_text(fn) return '' @@ -594,7 +488,7 @@ def dump_tags(self, *tags): self.dump_info('tags', tags) def dump_note(self, note: str): - fn = self.test_file('note.md') + fn = self.mk_ipath('note.md') self.set_prop('note', note) io.dump_text(note, fn) @@ -606,7 +500,7 @@ def dump_string(self, key: str, info: str, append=False): key (str): The key under which the string will be stored. info (str): The string to store. """ - fn = self.test_file(f'{key}.str', 'text') + fn = self.mk_ipath('text', f'{key}.str') io.dump_text(info, fn, append=append) if not append: self.set_prop(key, info) @@ -621,7 +515,7 @@ def load_string(self, key: str): Returns: str: The string stored under the specified key. """ - fn = self.test_file(f'{key}.str', 'text') + fn = self.mk_ipath('text', f'{key}.str') if not os.path.exists(fn): return '' return io.load_text(fn) @@ -632,125 +526,73 @@ def dump_metric(self, key, value, cmp: str, flush=True, **kwargs): def dump_metrics(self, dic: dict, cmp: str): return self.metric.dump_metrics(dic, cmp) - def exp_dir(self, *args): - """ - Gets the path to a directory in the experiment directory. - - Args: - *args: Any subdirectory names to include in the directory path. - - Returns: - str: The path to the specified directory. - """ - parent = self.exp_branch.joinpath(*args) - return checkdir(parent).as_posix() - - def exp_file(self, filename, *args): - """ - Gets the path to a file in the experiment directory. - - Args: - filename (str): The name of the file. - *args: Any additional subdirectory names to include in the file path. - - Returns: - str: The path to the specified file. - """ - parent = self.exp_branch.joinpath(*args) - return checkdir(parent).joinpath(filename).as_posix() - - def test_dir(self, *args): - """ - Gets the path to a directory in the test directory. - - Args: - *args: Any subdirectory names to include in the directory path. - - Returns: - str: The path to the specified directory. - """ - parent = self.test_branch.joinpath(*args) - return checkdir(parent).as_posix() - - def test_file(self, filename, *args): - """ - Gets the path to a file in the test directory. - - Args: - filename (str): The name of the file. - *args: Any additional subdirectory names to include in the file path. - - Returns: - str: The path to the specified file. - """ - parent = self.test_branch.joinpath(*args) - return checkdir(parent).joinpath(filename).as_posix() - - def root_dir(self, *args): - """ - Gets the path to a directory in the library's root directory. - - Args: - *args: Any subdirectory names to include in the directory path. + @property + def info_root(self): + return self.paths['info_root'] - Returns: - str: The path to the specified directory. - """ - parent = self.root_branch.joinpath(*args) - return checkdir(parent).as_posix() + @property + def cache_root(self): + return self.paths['cache_root'] - def root_file(self, filename, *args): - """ - Gets the path to a file in the library's root directory. + @property + def blob_root(self): + return self.paths['blob_root'] - Args: - filename (str): The name of the file. - *args: Any additional subdirectory names to include in the file path. + @property + def pid_fn(self): + fn = os.path.join(self.cache_root, 'pid', self.exp_name, f'{self.test_name}.pid') + os.makedirs(os.path.dirname(fn), exist_ok=True) + return fn - Returns: - str: The path to the specified file. - """ - parent = self.root_branch.joinpath(*args) - return checkdir(parent).joinpath(filename).as_posix() + @property + def heartbeat_fn(self): + fn = os.path.join(self.cache_root, 'heartbeat', self.exp_name, f'{self.test_name}.hb') + os.makedirs(os.path.dirname(fn), exist_ok=True) + return fn - def blob_dir(self, *args): - """ - Gets the path to a directory in the blob directory. - Args: - *args: Any subdirectory names to include in the directory path. + @property + def exp_dir(self): + d = os.path.join(self.info_root, self.exp_name) + os.makedirs(d, exist_ok=True) + return d - Returns: - str: The path to the specified directory. - """ - parent = self.blob_branch.joinpath(*args) - return checkdir(parent).as_posix() + @property + def info_dir(self): + d = os.path.join(self.info_root, self.exp_name, self.test_name) + os.makedirs(d, exist_ok=True) + return d - def blob_file(self, filename, *args): - """ - Gets the path to a file in the blob directory. + @property + def cache_dir(self): + d = os.path.join(self.cache_root, self.exp_name, self.test_name) + os.makedirs(d, exist_ok=True) + return d - Args: - filename (str): The name of the file. - *args: Any additional subdirectory names to include in the file path. + @property + def blob_dir(self): + d = os.path.join(self.blob_root, self.exp_name, self.test_name) + os.makedirs(d, exist_ok=True) + return d + + def _mk_path(self, *path: str, is_dir) -> str: + path = os.path.join(*path) + if is_dir: + os.makedirs(path, exist_ok=True) + else: + os.makedirs(os.path.dirname(path), exist_ok=True) + return path - Returns: - str: The path to the specified file. - """ - parent = self.blob_branch.joinpath(*args) - return checkdir(parent).joinpath(filename).as_posix() + def mk_ipath(self, *path, is_dir=False): + return self._mk_path(self.info_dir, *path, is_dir=is_dir) - def progress_file(self, filename, *args): - """ - Gets the path to a file in the progress directory. + def mk_cpath(self, *path, is_dir=False): + return self._mk_path(self.cache_dir, *path, is_dir=is_dir) - Args: - filename (str): The name of the file. + def mk_bpath(self, *path, is_dir=False): + return self._mk_path(self.blob_dir, *path, is_dir=is_dir) - Returns: - str: The path to the specified file. - """ - parent = self.progress_branch.joinpath(*args) - return checkdir(parent).joinpath(filename).as_posix() + def mk_rpath(self, *path, is_dir=False): + return self._mk_path(libhome(), *path, is_dir=is_dir) @classmethod @property @@ -760,8 +602,8 @@ def Class(cls): def rerun(self, arg_list: List[str]): """rerun this test in another""" # self.properties[''] - new_test_name = self._create_test_name(self.exp_root) - new_exp = Experiment(self.exp_name, root=self._root, test_name=new_test_name) + new_test_name = self._create_test_name(self.exp_dir) + new_exp = Experiment(self.exp_name, test_name=new_test_name) self.dump_info('deprecated', {'rerun_at': {new_exp.test_name: True}}, append=True) old_rerun_info = self.properties.get('rerun', None) count = 1 @@ -780,8 +622,10 @@ def initial(self): """ Initializes the experiment by setting up progress, information, and PID tracking. """ - self.dump_info('progress', {'start': strftime(), 'finished': False}, append=True) - self.dump_progress(0) + self.dump_info('exp_name', self.exp_name) + self.dump_info('test_name', self.test_name) + self.dump_info('paths', self.paths) + self.dump_info('execute', { 'repo': self.project_root, 'cwd': os.getcwd(), @@ -795,16 +639,18 @@ def initial(self): 'obj': runtime_pid_obj(), }) + # register start + self.dump_info('progress', {'start': strftime(), 'finished': False}, append=True) + self.dump_progress(0) # register progress - # register this process - io.dump_text(self.test_root, self.progress_file(f'{os.getpid()}', 'pid')) + io.dump_text(self.info_dir, self.pid_fn) @call_on_main_process_wrap def start(self): """ Starts the experiment. """ - if self.properties.get('start', False): + if self.properties.get('progress', None) is not None: return self.initial() self.set_prop('start', True) @@ -824,12 +670,14 @@ def end(self, end_code=0, *args, **extra): """ if not self.is_alive: return - if not self.properties.get('start', False): + if not self.properties.get('progress', None) is None: return - if self.properties.get('end', False): + if self.properties['progress'].get('end', False): return + self.set_prop('end', True) - self.dump_progress(1) + if end_code == 0: + self.dump_progress(1) self.dump_info('progress', {'end': strftime(), 'finished': end_code == 0}, append=True) for hook in self._hooks.values(): # type: BaseExpHook @@ -873,6 +721,14 @@ def exp_func(): atexit.register(exp_func) + @classmethod + def from_cache(cls, dic: dict): + paths = dic.pop('paths', {}) + _ = dic.pop('metrics') + self = cls(exp_name=dic['exp_name'], test_name=dic['test_name'], paths=paths) + self._prop.update(dic) + return self + @classmethod def from_disk(cls, path): """ @@ -890,33 +746,34 @@ def from_disk(cls, path): from .finder import is_test_root if not is_test_root(path): raise ValueError(f'{path} is not a valid test_root') + path = os.path.abspath(path) + exp_dir = os.path.dirname(path) - test_root = Path(path) - root = test_root.parent.parent.parent.as_posix() - self = cls(test_root.parent.name, root=root, test_name=test_root.name) + paths_fn = os.path.join(path, 'info', f'paths.json') + paths = io.load_json(paths_fn) + self = cls(os.path.basename(exp_dir), test_name=os.path.basename(path), paths=paths) # load prop - for f in os.listdir(self.test_dir('info')): + for f in os.listdir(self.mk_ipath('info', is_dir=True)): key = os.path.splitext(f)[0] self.set_prop(key, self.load_info(key)) - for f in os.listdir(self.test_dir('text')): + for f in os.listdir(self.mk_ipath('text', is_dir=True)): key = os.path.splitext(f)[0] self.set_prop(key, self.load_string(key)) self.set_prop('note', self.load_note()) - # load metric - self._metric = Metric(self.metrics_fn) return self + def cache(self): + return { + **self.properties, + 'metrics': self.metric.value, + } + def dict(self): return { - 'path': { - 'test_root': self.test_root, - 'exp_root': self.exp_root, - 'blob_root': self.blob_root, - }, **self.properties, 'is_alive': self.is_alive, 'metrics': self.metric.value, diff --git a/src/lumo/exp/exphook.py b/src/lumo/exp/exphook.py index 86c1434..68500fb 100644 --- a/src/lumo/exp/exphook.py +++ b/src/lumo/exp/exphook.py @@ -73,8 +73,8 @@ class Diary(ExpHook): def on_start(self, exp: Experiment, *args, **kwargs): super().on_start(exp, *args, **kwargs) - with open(exp.root_file(f'{strftime("%y%m%d")}.log', 'diary'), 'a') as w: - w.write(f'{strftime("%H:%M:%S")}, {exp.test_root}\n') + # with open(exp.root_file(f'{strftime("%y%m%d")}.log', 'diary'), 'a') as w: + # w.write(f'{strftime("%H:%M:%S")}, {exp.test_root}\n') class RecordAbort(ExpHook): @@ -161,7 +161,7 @@ def on_start(self, exp: Experiment, *args, **kwargs): pass dep_hash = hash(dep_source) - commit_ = git_commit(key='lumo', info=exp.test_root, filter_files=filter_files) + commit_ = git_commit(key='lumo', info=exp.info_dir, filter_files=filter_files) if commit_ is None: exp.dump_info('git', { @@ -176,14 +176,13 @@ def on_start(self, exp: Experiment, *args, **kwargs): 'repo': exp.project_root, 'dep_hash': dep_hash, }) - - file = exp.root_file(hash(exp.project_root), 'repos') + file = exp.mk_rpath('repos', hash(exp.project_root)) exps = {} if os.path.exists(file): exps = io.load_json(file) res = exps.setdefault(exp.project_root, list()) - if exp.exp_root not in res: - res.append(exp.exp_root) + if exp.exp_dir not in res: + res.append(exp.exp_dir) io.dump_json(exps, file) @@ -226,10 +225,6 @@ def on_end(self, exp: Experiment, end_code=0, *args, **kwargs): print('Properties:') indent_print(pformat(exp.properties)) - print('Tags:') - indent_print(pformat(exp.tags)) - print('Use paths:') - indent_print(pformat(exp.paths)) print('Execute:') indent_print(' '.join(exp.exec_argv)) print('-----------------------------------') diff --git a/src/lumo/exp/watch.py b/src/lumo/exp/watch.py index a7ef62c..e2dc867 100644 --- a/src/lumo/exp/watch.py +++ b/src/lumo/exp/watch.py @@ -18,22 +18,145 @@ -> 每次只记录 """ +import numbers import os.path -from typing import List, Dict - +from typing import List, Dict, overload +from pprint import pformat import pandas as pd from dbrecord import PDict +from datetime import datetime +from operator import gt, ge, le, lt, eq, ne -from lumo.proc.path import progressroot, exproot, dbroot +from lumo.proc.path import progressroot, exproot, dbroot, cache_dir from .experiment import Experiment -from .finder import is_test_name from lumo.utils import safe_io as IO -from lumo.analyse.collect import collect_table_rows +from lumo.utils.fmt import format_timedelta, strptime, strftime PID_ROOT = os.path.join(progressroot(), 'pid') HB_ROOT = os.path.join(progressroot(), 'hb') EXP_ROOT = os.path.join(progressroot()) +styles = { + 'row-radio': """""", + 'widget-box': """ + + + """ +} + + +def in_(ser, value): + """pandas operation""" + return ser.apply(lambda x: x in value) + + +def not_in_(ser, value): + """pandas operation""" + return ser.apply(lambda x: x not in value) + + +# supported conditions +mapping = { + '>=': ge, + '<=': le, + '==': eq, + '!=': ne, + '>': gt, + '<': lt, + 'in': in_, + 'notin': not_in_, +} + + +class Condition: + def __init__(self, name: str = None, value=None, op=None): + self.name = name + self.value = value + self.op = op + + def __getattr__(self, item): + return Condition(item) + + def __getitem__(self, item): + return Condition(item) + + def __neg__(self): + self.drop = True + return self + + def __ge__(self, other): + if other is None: + raise AssertionError() + self.value = other + self.op = ">=" + return self + + def __le__(self, other): + if other is None: + raise AssertionError() + self.value = other + self.op = "<=" + return self + + def __eq__(self, other): + self.value = other + self.op = "==" + return self + + def __ne__(self, other): + self.value = other + self.op = "!=" + return self + + def __gt__(self, other): + if other is None: + raise AssertionError() + self.value = other + self.op = ">" + return self + + def __lt__(self, other): + assert other is not None + self.value = other + self.op = "<" + return self + + def __repr__(self): + return f'C({self.name} {self.op} {self.value})' + + def in_(self, lis): + """condition of `in` operation""" + self.op = 'in' + self.value = set(lis) + return self + + def not_in_(self, lis): + """condition of `.duplicated(value) == False` operation""" + self.op = 'notin' + self.value = set(lis) + return self + + def mask(self, df): + names = self.name.split('.') + value = df + for i in names: + if isinstance(value, pd.DataFrame): + value = value[i] + else: + value = df.apply(lambda x: x[i]) + return mapping[self.op](value, self.value) + + def apply(self, df): + return df[self.mask(df)] + + +C = Condition() + class Watcher: """List and watch experiments with time order @@ -47,10 +170,10 @@ def __init__(self, exp_root=None, hb_root=None, pid_root=None, db_root=None): exp_root = os.path.join(exproot(), 'hb') if hb_root is None: - hb_root = os.path.join(progressroot(), 'hb') + hb_root = os.path.join(cache_dir(), 'heartbeat') if pid_root is None: - pid_root = os.path.join(progressroot(), 'pid') + pid_root = os.path.join(cache_dir(), 'pid') if db_root is None: db_root = dbroot() @@ -62,16 +185,18 @@ def __init__(self, exp_root=None, hb_root=None, pid_root=None, db_root=None): def load(self): res = {} updates = {} + if not os.path.exists(self.hb_root): + return pd.DataFrame() for root, dirs, fs in os.walk(self.hb_root): if root == self.hb_root: continue for f in fs: - if f.endswith('heartbeat'): + if f.endswith('hb'): hb_file = os.path.join(root, f) test_root = IO.load_text(hb_file) try: exp = Experiment.from_disk(test_root) - updates.setdefault(exp.exp_name, []).append(exp.dict()) + updates.setdefault(exp.exp_name, []).append(exp.cache()) except KeyboardInterrupt as e: raise e except: @@ -88,18 +213,23 @@ def load(self): dic.flush() df = pd.DataFrame(res.values()) - return df + df = df.sort_values(['exp_name', 'test_name']) + return df.reset_index(drop=True) - def progress(self): + def progress(self, is_alive=True): """return the alive process""" res = [] - for pid in os.listdir(self.pid_root): - try: - test_root = IO.load_text(os.path.join(self.pid_root, pid)) - exp = Experiment.from_disk(test_root) - res.append(exp.dict()) - except: - continue + for root, dirs, fs in os.walk(self.pid_root): + for f in fs: + if not f.endswith('.pid'): + continue + try: + test_root = IO.load_text(os.path.join(root, f)) + exp = Experiment.from_disk(test_root) + if exp.is_alive == is_alive: + res.append(exp.dict()) + except: + continue return pd.DataFrame(res) def interactive(self): @@ -120,3 +250,287 @@ def list_all(self, exp_root=None, limit=100) -> Dict[str, List[Experiment]]: Returns: A dictionary of all experiments, where the keys are the names of the experiments and the values are lists of corresponding Experiment objects. """ + + def widget(self, + is_finished: bool = None, + is_alive: bool = None, + time_filter: list = None, + params_filter: list = None, + metric_filter: list = None + ): + assert params_filter is None or isinstance(params_filter, list) + assert metric_filter is None or isinstance(metric_filter, list) + + from ipywidgets import widgets, interact, Label + from IPython.display import display + + def make_row(dic: dict): + exp = Experiment.from_cache(dic.copy()) + + def on_note_update(sender): + exp.dump_note(sender['new']) + + def on_tag_update(sender): + exp.dump_tags(*sender['new']) + + note_ui = widgets.Textarea(dic['note']) + + note_ui.continuous_update = False + note_ui.observe(on_note_update, names='value', type='change') + + tags = dic.get('tags', []) + try: + tags = list(tags) + except: + tags = [] + tag_ui = widgets.TagsInput(value=tags) + tag_ui.observe(on_tag_update, names='value', type='change') + + now = datetime.now() + start = strptime(datestr=dic['progress']['start']) + end = strptime(datestr=dic['progress']['last_edit_time']) + + human = widgets.VBox([ + + ]) + return [ + widgets.Label(dic['exp_name']), + widgets.Label(dic['test_name']), + widgets.Label(f"""{strftime('%y-%m-%d %H:%M:%S', dateobj=start)}"""), + widgets.Label(f"""{strftime('%y-%m-%d %H:%M:%S', dateobj=end)}"""), + widgets.HTML('\n'.join([ + f'{k}: {v}' + for k, v in dic['metrics'].items() + if isinstance(v, numbers.Number) + ])), + widgets.HBox([note_ui, + tag_ui, ]) + + ] + + test_status = widgets.RadioButtons(options=['full', 'running', 'failed', 'succeed', 'finished']) + start_filter = widgets.DatetimePicker() + end_filter = widgets.DatetimePicker() + + def status_filter(sender): + print(sender) + make() + + test_status.observe(status_filter, names='value', type='change') + + # display() + + @interact + def make( + status=widgets.RadioButtons(options=['full', 'running', 'failed', 'succeed', 'finished']), + start=widgets.DatetimePicker(), + end=widgets.DatetimePicker(), + ): + if status == 'running': + df = self.progress() + elif status == 'finished': + df = self.progress(is_alive=False) + else: + df = self.load() + if status == 'succeed': + df = df[df['progress'].apply(lambda x: x['finished'])] + elif status == 'failed': + df = df[df['exception'].isna() == False] + + if start: + df = df.pipe( + lambda x: x[x['progress'].apply(lambda y: strptime(datestr=y['start'])) > start] + ) + if end: + df = df.pipe( + lambda x: x[x['progress'].apply(lambda y: strptime(datestr=y['end'])) < end] + ) + + if params_filter is not None: + df_params = df['params'] + masks = None + for condition in params_filter: + mask = condition.mask(df_params) + if masks is None: + masks = mask + else: + masks *= mask + df = df[masks] + + if metric_filter is not None: + df_params = df['metrics'] + masks = None + for condition in metric_filter: + mask = condition.mask(df_params) + if masks is None: + masks = mask + else: + masks *= mask + df = df[masks] + + exps = df.to_dict(orient='records') + # grid = widgets.GridspecLayout(len(exps) + 1, 7) + + children = [ + widgets.Label('exp_name'), + widgets.Label('test_name'), + widgets.Label('start'), + widgets.Label('end'), + widgets.Label('metrics'), + widgets.Label('note & tags'), + ] + # grid[0, 0] = widgets.Label('Meta') + # grid[0, 1] = widgets.Label('Metrics') + # grid[0, 2] = widgets.Label('Notes') + for i, exp in enumerate(exps, start=1): + row = make_row(exp) + children.extend(row) + # display(widgets.HBox(row)) + # for j, item in enumerate(row): + # grid[i, j] = item + + grid = widgets.GridBox(children=children, + + layout=widgets.Layout( + width='100%', + grid_template_columns=' '.join(['auto'] * 5) + ' auto', + # grid_template_rows='80px auto 80px', + grid_gap='5px 10px') + ) + display( + widgets.HTML(""" + + """), + grid, + + ) + + # return display( + # widgets.HTML(styles['row-radio']), + # widgets.HTML(""" + # + # """), + # grid, clear=True) + + +class ExperimentWidget: + @overload + def __init__(self, exp_name, test_name, + progress: dict, + params: dict, metrics: dict, note: str, tags: set, exp: Experiment): + pass + + def __init__(self, **kwargs): + from ipywidgets import widgets + self.wid = widgets + self.exp = kwargs.pop('exp') # type: Experiment + self._prop = kwargs + + self._widgets = { + 'exp_name': widgets.HTML(self._prop['exp_name']), + 'test_name': widgets.HTML(self._prop['test_name']), + 'metrics': widgets.VBox( + [widgets.HTML(f'{k}: {v}') for k, v in self._prop['metrics'].items() if + isinstance(v, numbers.Number)]), + } + + self._params_widgets = {} + + note_ui = widgets.Textarea(self._prop['note']) + + note_ui.continuous_update = False + note_ui.observe(self.on_note_update, names='value', type='change') + self._widgets['note'] = note_ui + + tag_ui = widgets.TagsInput(value=list(self._prop['tags'])) + self._widgets['tags'] = tag_ui + tag_ui.observe(self.on_tag_update, names='value', type='change') + + def on_note_update(self, sender): + self.exp.dump_note(sender['new']) + + def on_tag_update(self, sender): + self.exp.dump_tags(*sender['new']) + + def set_key_params(self, keys: list): + self._params_widgets.clear() + for key in keys: + self._params_widgets[key] = self.wid.HTML( + f"""{key}: {pformat(self._prop['params'][key], width=10, indent=2, compact=True)}""") + + def sep(self): + return self.wid.Output(layout={'border': '1px solid black'}) + + def id_flag(self): + return self.wid.VBox([ + self._widgets['exp_name'], + self._widgets['test_name'], + ]) + + def key_params(self): + return self.wid.VBox([ + *self._params_widgets.values() + ]) + + def editable(self): + return self.wid.VBox([ + self._widgets['note'], + self.sep(), + self._widgets['tags'], + ]) + + def time(self): + now = datetime.now() + start = strptime(datestr=self._prop['progress']['start']) + end = strptime(datestr=self._prop['progress']['start']) + return self.wid.VBox([ + self.wid.HTML(f"""Start at: {format_timedelta(now - start)}"""), + self.wid.HTML(f"""End at: {format_timedelta(now - end)}"""), + ]) + + def widget_dict(self): + return { + 'id_flag': self.id_flag(), + 'time': self.time(), + 'editable': self.editable(), + 'params': self.key_params(), + } + + def widget(self): + params = self.key_params() + params = [ + self.sep(), + params, + ] + + hbox = self.wid.HBox([ + self.id_flag(), + self.time(), + self._widgets['metrics'], + self.editable(), + self.key_params(), + ]) + + return hbox + + @classmethod + def from_experiment(cls, exp: Experiment): + tags = exp.properties.get('tags', []) + try: + tags = set(tags) + except: + tags = set() + return cls( + exp_name=exp.exp_name, + test_name=exp.test_name, + progress=exp.properties.get('progress', {}), + params=exp['params'], + metrics=exp.metric.value, + note=exp.properties.get('note', ''), + tags=tags, + exp=exp, + ) diff --git a/src/lumo/trainer/components.py b/src/lumo/trainer/components.py index 8422035..37d5bf7 100644 --- a/src/lumo/trainer/components.py +++ b/src/lumo/trainer/components.py @@ -10,11 +10,11 @@ class TrainerExperiment(SimpleExperiment): @property def log_dir(self): - return self.test_root + return self.info_dir @property def params_fn(self): - res = self.test_file('params.yaml') + res = self.mk_ipath('params.yaml') self.dump_string('params.yaml', res) return res @@ -24,7 +24,7 @@ def board_args(self): if self.has_prop(key): return self.get_prop(key) else: - log_dir = self.test_dir('board') + log_dir = self.mk_ipath('board', is_dir=True) res = { 'filename_suffix': '.bd', 'log_dir': log_dir, @@ -34,7 +34,7 @@ def board_args(self): @property def state_dict_dir(self): - res = self.blob_dir('state_dict') + res = self.mk_bpath('state_dict', is_dir=True) return res def dump_train_eidx(self, eidx, epoch: int): diff --git a/src/lumo/trainer/trainer.py b/src/lumo/trainer/trainer.py index 7ac6aed..9564882 100644 --- a/src/lumo/trainer/trainer.py +++ b/src/lumo/trainer/trainer.py @@ -13,7 +13,7 @@ from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader - +import json from lumo.contrib.accelerate import Accelerator from lumo.contrib.accelerate.utils import send_to_device from lumo.core import TrainStage, Record, MetricType, Meter @@ -85,8 +85,8 @@ def __init__(self, params: ParamsType, dm: DataModule = None): self.params.iparams() self.exp = TrainerExperiment(self.generate_exp_name()) - self._database = TableRow(self.exp.metrics_fn, persistent=self.is_main) - self.metric_board = Metrics(self.exp.test_root, persistent=self.is_main) + self._database = TableRow(self.exp.mk_ipath('metric.pkl'), persistent=self.is_main) + self.metric_board = Metrics(self.exp.mk_bpath('board.sqlite'), persistent=self.is_main) self.metric = self.exp.metric self.exp.dump_info('metric_board', self.metric_board.fpath) @@ -1009,64 +1009,29 @@ def wait_for_everyone(self): """ self.accelerate.wait_for_everyone() - def save_model(self, is_best=False, meta_info: Union[str, dict] = None): - """ - Saves the current model. - - Args: - is_best (bool, optional): Indicates whether the current model is the best one so far. Defaults to False. - meta_info (Union[str, dict], optional): Additional information to include in the saved model file. Can be a - string, a dictionary, or a Meter object. Defaults to None. + def save_best_model(self): + if self.is_main: + file = self.exp.mk_bpath('models', 'best_model.ckpt') + file_info = self.exp.mk_bpath('models', 'best_model.json') + else: + file = self.exp.mk_bpath('models', f'best_model-{self.local_rank}.ckpt') + file_info = self.exp.mk_bpath('models', f'best_model-{self.local_rank}.json') + torch.save(self.state_dict(), file) - Returns: - str: The path to the saved model file. - """ - info = self._build_trainer_meta_info(meta_info) - val = self.saver.save_model(self.eidx, self.model_state_dict(), - meta_info=info, - is_best=is_best) + with open(file_info, 'w') as w: + w.write(json.dumps({'global_steps': self.global_steps, 'metric': self.exp.metric.value})) + self.logger.info(f'saved best model at {file}') self.wait_for_everyone() - return val - - def _build_trainer_meta_info(self, meta_info: Union[str, dict] = None): - """ - Builds a dictionary containing metadata about the Trainer object. - - Args: - meta_info (Union[str, dict], optional): Additional metadata to include in the dictionary. Can be a string, a - dictionary, or a Meter object. Defaults to None. - Returns: - dict: A dictionary containing metadata about the Trainer object. - """ - info = dict() - info['eidx'] = self.eidx - if meta_info is not None: - if isinstance(meta_info, str): - info['msg'] = meta_info - if isinstance(meta_info, Meter): - meta_info = meta_info.serialize() - if isinstance(meta_info, dict): - info.update(meta_info) - return info - - def save_checkpoint(self, max_keep=10, is_best=False, meta_info: Union[str, dict, Meter] = None): - """ - Saves a checkpoint of the current state of the Trainer object. - - Args: - max_keep (int, optional): The maximum number of checkpoints to keep. Defaults to 10. - is_best (bool, optional): Indicates whether the current checkpoint is the best one so far. Defaults to False. - meta_info (Union[str, dict, Meter], optional): Additional information to include in the saved checkpoint file. - Can be a string, a dictionary, or a Meter object. Defaults to None. - - Returns: - str: The path to the saved checkpoint file. - """ - info = self._build_trainer_meta_info(meta_info) - val = self.saver.save_checkpoint(self.eidx, self.state_dict(), - meta_info=info, - max_keep=max_keep, - is_best=is_best) + def save_last_model(self): + if self.is_main: + file = self.exp.mk_bpath('models', 'last_model.ckpt') + file_info = self.exp.mk_bpath('models', 'last_model.json') + else: + file = self.exp.mk_bpath('models', f'last_model-{self.local_rank}.ckpt') + file_info = self.exp.mk_bpath('models', f'last_model-{self.local_rank}.json') + torch.save(self.state_dict(), file) + with open(file_info, 'w') as w: + w.write(json.dumps({'global_steps': self.global_steps, 'metric': self.exp.metric.value})) + self.logger.info(f'saved last model at {file}') self.wait_for_everyone() - return val diff --git a/src/lumo/utils/fmt.py b/src/lumo/utils/fmt.py index 5c8458c..7ba7897 100644 --- a/src/lumo/utils/fmt.py +++ b/src/lumo/utils/fmt.py @@ -2,7 +2,7 @@ Format date/filename, check array shape, convert item from torch.Tensor to ndarray or scalar. """ import textwrap -from datetime import datetime +from datetime import datetime, timedelta import numpy as np import torch @@ -45,7 +45,7 @@ def strftime(fmt='%y-%m-%d-%H%M%S', dateobj: datetime = None): return datetime.now().strftime(fmt) -def strptime(fmt='%y-%m-%d-%H%M%S', datestr: str = None): +def strptime(datestr: str = None, fmt='%y-%m-%d-%H%M%S', ): """Convert a string to a datetime object using the specified format.""" return datetime.strptime(datestr, fmt) @@ -78,9 +78,17 @@ def format_second(sec: int) -> str: min, sec = divmod(sec, 60) if min > 60: hour, min = divmod(min, 60) - fmt = "{}h{}m{}s".format(hour, min, int(sec)) + if hour > 24: + day, hour = divmod(hour, 24) + fmt = "{}d{}h{}m{}s".format(int(day), int(hour), int(min), int(sec)) + else: + fmt = "{}h{}m{}s".format(int(hour), int(min), int(sec)) else: - fmt = "{}m{}s".format(min, int(sec)) + fmt = "{}m{}s".format(int(min), int(sec)) else: fmt = "{}s".format(int(sec)) return fmt + + +def format_timedelta(td: timedelta): + return format_second(td.total_seconds()) diff --git a/src/lumo/utils/repository.py b/src/lumo/utils/repository.py index 73cd33d..6466ab7 100644 --- a/src/lumo/utils/repository.py +++ b/src/lumo/utils/repository.py @@ -202,7 +202,7 @@ def git_archive(repo=None, commit_hex=None, commit: Commit = None): old_path = os.getcwd() os.chdir(commit.tree.abspath) exp = Experiment('GitArchive') - fn = exp.blob_file(f'{commit.hexsha[:8]}.tar') + fn = exp.mk_bpath(f'{commit.hexsha[:8]}.tar') exp.dump_info('git_archive', {'file': fn, 'test_name': exp.test_name, diff --git a/tests/exp/test_watcher.py b/tests/exp/test_watcher.py index e33474e..35ad087 100644 --- a/tests/exp/test_watcher.py +++ b/tests/exp/test_watcher.py @@ -18,7 +18,7 @@ def test_exp(): for i in range(10): t = trainer() t.train() - print(t.exp.test_name) + print(t.exp.heartbeat_fn) w = Watcher() df = w.load() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c7df8fe..38c015c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -156,9 +156,9 @@ def test_trainer(): # test trainer experiment exp = trainer.exp - assert exp.exp_root == os.path.join(glob['exp_root'], trainer.generate_exp_name()) - assert exp.lib_root == glob['home'] - assert exp.blob_root == os.path.join(glob['blob_root'], trainer.generate_exp_name(), exp.test_name) + assert exp.exp_dir == os.path.join(glob['exp_root'], trainer.generate_exp_name()) + assert exp.info_dir == os.path.join(glob['exp_root'], trainer.generate_exp_name(), exp.test_name) + assert exp.blob_dir == os.path.join(glob['blob_root'], trainer.generate_exp_name(), exp.test_name) assert exp.project_root == git_dir() # how to test writer? _ = trainer.safe_writer From cd94b8bdcad46cd3eb24a4232e229fb87b000e6d Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 12 Mar 2023 20:58:34 +0800 Subject: [PATCH 94/99] Update function args --- src/lumo/trainer/callbacks.py | 50 ++++++----------------------------- 1 file changed, 8 insertions(+), 42 deletions(-) diff --git a/src/lumo/trainer/callbacks.py b/src/lumo/trainer/callbacks.py index 36c7096..2f71b22 100644 --- a/src/lumo/trainer/callbacks.py +++ b/src/lumo/trainer/callbacks.py @@ -522,7 +522,7 @@ class EMAUpdate(TrainCallback): - name is started with 'ema' """ - def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: MetricType, *args, **kwargs): + def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: MetricType = None, *args, **kwargs): super().on_train_step_end(trainer, func, params, metric, *args, **kwargs) for k, v in trainer.model_dict.items(): if k.lower().startswith('ema'): @@ -608,7 +608,7 @@ def on_hooked(self, source: Trainer, params: ParamsType): super().on_hooked(source, params) source.exp.dump_string('AutoRecord', self.__class__.__name__) - def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: MetricType, *args, **kwargs): + def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: MetricType = None, *args, **kwargs): super().on_train_step_end(trainer, func, params, metric, *args, **kwargs) self.log(metric, step=trainer.global_steps, namespace='train.step') @@ -616,50 +616,15 @@ def on_train_epoch_end(self, trainer: Trainer, func, params: ParamsType, record: super().on_train_epoch_end(trainer, func, params, record, *args, **kwargs) self.log(record.agg(), step=trainer.global_steps, namespace='train.epoch') - def on_test_end(self, trainer: Trainer, func, params: ParamsType, record: Record, *args, **kwargs): + def on_test_end(self, trainer: Trainer, func, params: ParamsType, record: Record = None, *args, **kwargs): super().on_test_end(trainer, func, params, record, *args, **kwargs) self.log(record.agg(), step=trainer.global_steps, namespace='test') - def on_eval_end(self, trainer: Trainer, func, params: ParamsType, record: Record, *args, **kwargs): + def on_eval_end(self, trainer: Trainer, func, params: ParamsType, record: Record = None, *args, **kwargs): super().on_eval_end(trainer, func, params, record, *args, **kwargs) self.log(record.agg(), step=trainer.global_steps, namespace='evaluate') -class WandbCallback(RecordCallback): - only_main_process = True - - def __init__(self, metric_step=500) -> None: - super().__init__() - self.metric_step = metric_step - self.c = 0 - - def log(self, metrics: MetricType, step, namespace): - self.c += 1 - if self.c % self.metric_step == 0: - metrics = { - f"{namespace}.{k}": v - for k, v in wrap_result(metrics).items()} - self._hooked.wandb.log(metrics, step=step) - - def log_text(self, metrics: Dict, step: int, namespace: str): - wandb = self._hooked.wandb - metrics = {k: v for k, v in metrics.items()} - wandb.log(metrics, step=step) - - def log_scalars(self, metrics: Dict, step: int, namespace: str): - wandb = self._hooked.wandb - metrics = {k: wandb.Html(v) for k, v in metrics.items()} - wandb.log(metrics, step=step) - - def log_matrix(self, metrics: Dict, step: int, namespace: str): - wandb = self._hooked.wandb - metrics = {k: wandb.Image(v) for k, v in metrics.items()} - wandb.log(metrics, step=step) - - def on_first_exception(self, source: Trainer, func, params: ParamsType, e: BaseException, *args, **kwargs): - super().on_first_exception(source, func, params, e, *args, **kwargs) - - class TensorBoardCallback(RecordCallback): only_main_process = True @@ -690,7 +655,8 @@ class StopByCode(TrainCallback): def __init__(self, step=100): self.step = step - def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: Meter, *args, **kwargs): + def on_train_step_end(self, trainer: Trainer, func, params: ParamsType, metric: MetricType = None, *args, **kwargs): + super().on_train_step_end(trainer, func, params, metric, *args, **kwargs) if trainer.global_steps % self.step == 0: if os.path.exists(trainer.exp.test_file('.stop')): trainer.exp.add_tag('lumo.early_stop') @@ -840,7 +806,7 @@ def on_hooked(self, source: Trainer, params: ParamsType): super().on_hooked(source, params) from dbrecord import PDict from lumo.exp.finder import is_test_root - self.fn = source.exp.exp_file('params_key.sqlite') + self.fn = source.exp.mk_rpath('contrib', 'params_key.sqlite') olds = PDict(self.fn) current = source.params.hash() @@ -859,6 +825,6 @@ def on_train_end(self, trainer: Trainer, func, params: ParamsType, record: Recor super().on_train_end(trainer, func, params, record, *args, **kwargs) from dbrecord import PDict olds = PDict(self.fn) - olds[params.hash()] = trainer.exp.test_root + olds[params.hash()] = trainer.exp.info_dir olds.flush() trainer.logger.info(f'Save current params ({params.hash()}) to {self.fn}') From c0bd292bf499b6d2ad600624c42cf4c2b17f7117 Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 12 Mar 2023 20:59:40 +0800 Subject: [PATCH 95/99] Weakly update, add minor version for the reconstruction. --- src/lumo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lumo/__init__.py b/src/lumo/__init__.py index 957eb07..5893105 100644 --- a/src/lumo/__init__.py +++ b/src/lumo/__init__.py @@ -1,7 +1,7 @@ """ """ -__version__ = "0.14.6" +__version__ = "0.15.0" from .core import Params, ParamsType, MetricType, Meter, Record, TrainStage, BaseParams From 042298d2d9460833bf699aba734acaca4a8d7d42 Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 12 Mar 2023 21:06:33 +0800 Subject: [PATCH 96/99] Ignore paths for old version --- src/lumo/exp/experiment.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index 2a904bf..cb4587d 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -750,7 +750,11 @@ def from_disk(cls, path): exp_dir = os.path.dirname(path) paths_fn = os.path.join(path, 'info', f'paths.json') - paths = io.load_json(paths_fn) + try: + paths = io.load_json(paths_fn) + except ValueError as e: + paths = {} + self = cls(os.path.basename(exp_dir), test_name=os.path.basename(path), paths=paths) # load prop From 8df5c35fef36a84fe15c8eedbb9bce0e68204e73 Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 12 Mar 2023 21:06:45 +0800 Subject: [PATCH 97/99] remove filelock for better choice --- src/lumo/sketch/filelock.py | 43 ------------------------------------- 1 file changed, 43 deletions(-) delete mode 100644 src/lumo/sketch/filelock.py diff --git a/src/lumo/sketch/filelock.py b/src/lumo/sketch/filelock.py deleted file mode 100644 index 3524463..0000000 --- a/src/lumo/sketch/filelock.py +++ /dev/null @@ -1,43 +0,0 @@ -import os -import warnings - -CAN_USE_LOCK = True -if os.name == 'nt': - import win32con, win32file, pywintypes - - LOCK_EX = win32con.LOCKFILE_EXCLUSIVE_LOCK - LOCK_SH = 0 # The default value - LOCK_NB = win32con.LOCKFILE_FAIL_IMMEDIATELY - __overlapped = pywintypes.OVERLAPPED() - - - def lock(file, flags): - hfile = win32file._get_osfhandle(file.fileno()) - win32file.LockFileEx(hfile, flags, 0, 0xffff0000, __overlapped) - - - def unlock(file): - hfile = win32file._get_osfhandle(file.fileno()) - win32file.UnlockFileEx(hfile, 0, 0xffff0000, __overlapped) -elif os.name == 'posix': - import fcntl - - - def lock(file, flags=fcntl.LOCK_EX): - fcntl.flock(file.fileno(), flags) - - - def unlock(file): - fcntl.flock(file.fileno(), fcntl.LOCK_UN) -else: - CAN_USE_LOCK = False - - warnings.warn(f'You are in an unknown platform {os.name}, filelock may cant be used.') - - - def lock(file, flat): - raise NotImplementedError(f'an UNKNOWN platform {os.name}') - - - def unlock(file): - fcntl.flock(file.fileno(), fcntl.LOCK_UN) From 1bbfe63a7bdc82242d2bc4cf2a6f3a0ad0368cb7 Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 12 Mar 2023 21:08:33 +0800 Subject: [PATCH 98/99] Update version --- src/lumo/exp/experiment.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/lumo/exp/experiment.py b/src/lumo/exp/experiment.py index cb4587d..358e9c3 100644 --- a/src/lumo/exp/experiment.py +++ b/src/lumo/exp/experiment.py @@ -750,9 +750,12 @@ def from_disk(cls, path): exp_dir = os.path.dirname(path) paths_fn = os.path.join(path, 'info', f'paths.json') - try: - paths = io.load_json(paths_fn) - except ValueError as e: + if os.path.exists(paths_fn): + try: + paths = io.load_json(paths_fn) + except ValueError as e: + paths = {} + else: paths = {} self = cls(os.path.basename(exp_dir), test_name=os.path.basename(path), paths=paths) From 6772eaf1d14589adc81036fab091cfd860ff352c Mon Sep 17 00:00:00 2001 From: sailist Date: Sun, 12 Mar 2023 22:37:17 +0800 Subject: [PATCH 99/99] Update version --- src/lumo/trainer/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lumo/trainer/trainer.py b/src/lumo/trainer/trainer.py index 9564882..27338f9 100644 --- a/src/lumo/trainer/trainer.py +++ b/src/lumo/trainer/trainer.py @@ -31,6 +31,8 @@ # overwrite send_to_device to resolve https://github.com/pytorch/pytorch/issues/83015 # from accelerate import Accelerator # from accelerate.utils import send_to_device +from ..utils.fmt import strftime + ParamsType = TrainerParams @@ -514,7 +516,7 @@ def to_device(self, item: Optional[Union[nn.Module, torch.Tensor, Sequence, Mapp def on_trainer_exception(self, func: Callable, exception: BaseException): """Updates database with error information when an exception occurs during training.""" - self.exp.dump_info('exception', dict(end=datetime.now(), + self.exp.dump_info('exception', dict(end=strftime(), finished=False, error=str(exception), trainer_frame=str(func)))