diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..e92b319 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,11 @@ +name: Linters + +on: [pull_request] + +jobs: + black: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + - uses: psf/black@stable diff --git a/setup.py b/setup.py index a68a380..9b9c893 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,14 @@ from setuptools import setup -setup(name='smart_settings', - version='1.1', - description='Smart JSON setting files', - url='https://github.com/mrolinek', - author='Michal Rolinek, MPI-IS Tuebingen, Autonomous Learning', - author_email='michalrolinek@gmail.com', - license='MIT', - packages=['smart_settings'], - install_requires=['pyyaml'], - zip_safe=False) +setup( + name="smart_settings", + version="1.1", + description="Smart JSON setting files", + url="https://github.com/mrolinek", + author="Michal Rolinek, MPI-IS Tuebingen, Autonomous Learning", + author_email="michalrolinek@gmail.com", + license="MIT", + packages=["smart_settings"], + install_requires=["pyyaml"], + zip_safe=False, +) diff --git a/smart_settings/__init__.py b/smart_settings/__init__.py index c978933..96fe172 100644 --- a/smart_settings/__init__.py +++ b/smart_settings/__init__.py @@ -1,4 +1,4 @@ from .smart_settings import load, loads from .file_editing import add_key, change_key_name -__all__ = ['load', 'loads'] +__all__ = ["load", "loads"] diff --git a/smart_settings/add_key.py b/smart_settings/add_key.py index e9a32ce..28669cd 100644 --- a/smart_settings/add_key.py +++ b/smart_settings/add_key.py @@ -2,14 +2,16 @@ import argparse import glob -parser = argparse.ArgumentParser(description='Adding a key to many JSON files.') -parser.add_argument(dest='files', type=str, help='Input file(s)') -parser.add_argument(dest='key_name', type=str, help='Name of the key to add') -parser.add_argument(dest='default_value', type=str, help='Default value to assign') -parser.add_argument('--override', action='store_true', help='Override value if key present') +parser = argparse.ArgumentParser(description="Adding a key to many JSON files.") +parser.add_argument(dest="files", type=str, help="Input file(s)") +parser.add_argument(dest="key_name", type=str, help="Name of the key to add") +parser.add_argument(dest="default_value", type=str, help="Default value to assign") +parser.add_argument( + "--override", action="store_true", help="Override value if key present" +) -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() matching_files = glob.glob(args.files, recursive=True) @@ -17,9 +19,13 @@ print(f"No files matching {args.files}") exit(2) - *prefixes, new_key = args.key_name.split('.') + *prefixes, new_key = args.key_name.split(".") for setting_file in matching_files: - add_key(setting_file, prefixes=prefixes, new_key=new_key, - default_value=eval(args.default_value), override=args.override, conditions=None) - - + add_key( + setting_file, + prefixes=prefixes, + new_key=new_key, + default_value=eval(args.default_value), + override=args.override, + conditions=None, + ) diff --git a/smart_settings/change_key.py b/smart_settings/change_key.py index 26aa448..4e8cfb9 100644 --- a/smart_settings/change_key.py +++ b/smart_settings/change_key.py @@ -2,14 +2,24 @@ import argparse import glob -parser = argparse.ArgumentParser(description='Changing a key name in JSON file(s).') -parser.add_argument(dest='files', type=str, help='Input file(s)') -parser.add_argument(dest='key_name', type=str, help='Name of the key to change (with \'.\' syntax for nesting)') -parser.add_argument(dest='new_name', type=str, help='New name of the key (only the suffix -- no \'.\' expected)') -parser.add_argument('--override', action='store_true', help='Override value if key present') +parser = argparse.ArgumentParser(description="Changing a key name in JSON file(s).") +parser.add_argument(dest="files", type=str, help="Input file(s)") +parser.add_argument( + dest="key_name", + type=str, + help="Name of the key to change (with '.' syntax for nesting)", +) +parser.add_argument( + dest="new_name", + type=str, + help="New name of the key (only the suffix -- no '.' expected)", +) +parser.add_argument( + "--override", action="store_true", help="Override value if key present" +) -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() matching_files = glob.glob(args.files, recursive=True) @@ -17,8 +27,12 @@ print(f"No files matching {args.files}") exit(2) - *prefixes, old_key = args.key_name.split('.') + *prefixes, old_key = args.key_name.split(".") for setting_file in matching_files: - change_key_name(setting_file, prefixes=prefixes, old_name=old_key, new_name=args.new_name, conditions=None) - - + change_key_name( + setting_file, + prefixes=prefixes, + old_name=old_key, + new_name=args.new_name, + conditions=None, + ) diff --git a/smart_settings/dynamic.py b/smart_settings/dynamic.py index 1ae6656..5868350 100644 --- a/smart_settings/dynamic.py +++ b/smart_settings/dynamic.py @@ -3,34 +3,32 @@ def _replace_inside_brackets(string, replace_from, replace_to): - index = string.find('{') + index = string.find("{") if index == -1: yield string return - yield string[:index + 1] - new_index = string[index + 1:].find('}') + yield string[: index + 1] + new_index = string[index + 1 :].find("}") if new_index == -1: yield from string[index:] return new_index += index + 1 - yield string[index + 1:new_index].replace(replace_from, replace_to) + yield string[index + 1 : new_index].replace(replace_from, replace_to) yield from _replace_inside_brackets(string[new_index:], replace_from, replace_to) def replace_inside_brackets(string, replace_from, replace_to): - return ''.join(list(_replace_inside_brackets( - string, replace_from, replace_to))) + return "".join(list(_replace_inside_brackets(string, replace_from, replace_to))) def fstring_in_json(format_string, namespace): if not isinstance(format_string, str): return format_string - replaced_dollar_signs = replace_inside_brackets(format_string, '$', 'ENV_') - env_dict = {'ENV_' + key: value for key, value in os.environ.items()} + replaced_dollar_signs = replace_inside_brackets(format_string, "$", "ENV_") + env_dict = {"ENV_" + key: value for key, value in os.environ.items()} try: - formatted = eval('f\"' + replaced_dollar_signs + - '\"', {**env_dict, **namespace}) + formatted = eval('f"' + replaced_dollar_signs + '"', {**env_dict, **namespace}) except BaseException as e: return format_string diff --git a/smart_settings/file_editing.py b/smart_settings/file_editing.py index d7bb32a..bf717df 100644 --- a/smart_settings/file_editing.py +++ b/smart_settings/file_editing.py @@ -29,7 +29,9 @@ def change_key_name(setting_file, prefixes, old_name, new_name, conditions=None) if old_name in as_dict: old_val = as_dict[old_name] as_dict[new_name] = old_val - print(f"{setting_file}: Key {'.'.join(prefixes+[old_name])} renamed to {'.'.join(prefixes+[new_name])}") + print( + f"{setting_file}: Key {'.'.join(prefixes+[old_name])} renamed to {'.'.join(prefixes+[new_name])}" + ) del as_dict[old_name] with open(setting_file, "w") as f: json.dump(orig_dict, f, indent=4) @@ -40,13 +42,16 @@ def change_key_name(setting_file, prefixes, old_name, new_name, conditions=None) def check_correct_conditions(dct, conditions): for lhs, rhs in conditions.items(): - if not rhs == dct.get(lhs, float('nan')): # using that 'nan' is not equal to ANYTHING in Python - + if not rhs == dct.get( + lhs, float("nan") + ): # using that 'nan' is not equal to ANYTHING in Python return False return True -def add_key(setting_file, prefixes, new_key, default_value, override=False, conditions=None): +def add_key( + setting_file, prefixes, new_key, default_value, override=False, conditions=None +): conditions = conditions or {} with open(setting_file) as f: @@ -68,7 +73,9 @@ def add_key(setting_file, prefixes, new_key, default_value, override=False, cond if new_key in as_dict: if override: - print(f"{setting_file}: Overwrote {'.'.join(prefixes+[new_key])}={default_value} (from {as_dict[new_key]})") + print( + f"{setting_file}: Overwrote {'.'.join(prefixes+[new_key])}={default_value} (from {as_dict[new_key]})" + ) as_dict[new_key] = default_value else: print(f"{setting_file}: Key {new_key} already present, not overwriting...") @@ -80,4 +87,3 @@ def add_key(setting_file, prefixes, new_key, default_value, override=False, cond with open(setting_file, "w") as f: json.dump(orig_dict, f, indent=4) return True - diff --git a/smart_settings/param_classes.py b/smart_settings/param_classes.py index da2c33e..c9059c2 100644 --- a/smart_settings/param_classes.py +++ b/smart_settings/param_classes.py @@ -5,22 +5,21 @@ class NoDuplicateDict(dict): - """ A dict with prohibiting init from a list of pairs containing duplicates""" + """A dict with prohibiting init from a list of pairs containing duplicates""" def __init__(self, *args, **kwargs): if args and args[0] and not isinstance(args[0], dict): keys, _ = zip(*args[0]) duplicates = [ - item for item, - count in collections.Counter(keys).items() if count > 1] + item for item, count in collections.Counter(keys).items() if count > 1 + ] if duplicates: - raise TypeError( - "Keys {} repeated in json parsing".format(duplicates)) + raise TypeError("Keys {} repeated in json parsing".format(duplicates)) super().__init__(*args, **kwargs) class AttributeDict(dict): - """ A dict which allows attribute access to its keys.""" + """A dict which allows attribute access to its keys.""" def __getattr__(self, *args, **kwargs): try: @@ -29,9 +28,10 @@ def __getattr__(self, *args, **kwargs): raise AttributeError(e) def __deepcopy__(self, memo): - """ In order to support deepcopy""" + """In order to support deepcopy""" return self.__class__( - [(deepcopy(k, memo=memo), deepcopy(v, memo=memo)) for k, v in self.items()]) + [(deepcopy(k, memo=memo), deepcopy(v, memo=memo)) for k, v in self.items()] + ) def __setattr__(self, key, value): self.__setitem__(key, value) @@ -44,7 +44,7 @@ def __repr__(self): class ImmutableAttributeDict(AttributeDict): - """ A dict which allows attribute access to its keys. Forced immutable.""" + """A dict which allows attribute access to its keys. Forced immutable.""" def __delattr__(self, item): raise TypeError("Setting object not mutable after settings are fixed!") @@ -89,13 +89,13 @@ def update_recursive(d, u, overwrite=False): if isinstance(v, collections.abc.Mapping): d[k] = update_recursive(d.get(k, {}), v, overwrite) if isinstance(v, collections.abc.Sequence): - raw_key = removesuffix(k, '*') + raw_key = removesuffix(k, "*") if raw_key + "*" in d: # append - d[raw_key + "*"] = deepcopy(v + d[raw_key + '*']) + d[raw_key + "*"] = deepcopy(v + d[raw_key + "*"]) elif raw_key in d: # keep original list pass else: # key does not exist yet, append - d[k] = v + d[k] = v elif k not in d or overwrite: d[k] = v return d diff --git a/smart_settings/port_cluster_utils_3.py b/smart_settings/port_cluster_utils_3.py index d1b9525..f0eeb47 100644 --- a/smart_settings/port_cluster_utils_3.py +++ b/smart_settings/port_cluster_utils_3.py @@ -2,16 +2,25 @@ import argparse import glob -parser = argparse.ArgumentParser(description='Porting JSON file(s) to Cluster utils >=3.0') -parser.add_argument(dest='files', type=str, help='Input file(s)', nargs='+') +parser = argparse.ArgumentParser( + description="Porting JSON file(s) to Cluster utils >=3.0" +) +parser.add_argument(dest="files", type=str, help="Input file(s)", nargs="+") -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() for setting_file in args.files: - change_key_name(setting_file, prefixes=[], old_name='model_dir', new_name='working_dir') - change_key_name(setting_file, prefixes=[], old_name='default_json', new_name='__import__') - change_key_name(setting_file, prefixes=['fixed_params'], old_name='default_json', new_name='__import_promise__') - - + change_key_name( + setting_file, prefixes=[], old_name="model_dir", new_name="working_dir" + ) + change_key_name( + setting_file, prefixes=[], old_name="default_json", new_name="__import__" + ) + change_key_name( + setting_file, + prefixes=["fixed_params"], + old_name="default_json", + new_name="__import_promise__", + ) diff --git a/smart_settings/smart_settings.py b/smart_settings/smart_settings.py index a9b6840..ec5727f 100644 --- a/smart_settings/smart_settings.py +++ b/smart_settings/smart_settings.py @@ -6,25 +6,30 @@ import yaml from .utils import removesuffix -IMPORT_KEY = '__import__' +IMPORT_KEY = "__import__" def load_raw_dict_from_file(filename): - """ Safe load of a json file (doubled entries raise exception)""" - if filename.endswith('.json'): - with open(filename, 'r') as f: + """Safe load of a json file (doubled entries raise exception)""" + if filename.endswith(".json"): + with open(filename, "r") as f: data = json.load(f, object_pairs_hook=NoDuplicateDict) return data - elif filename.endswith('.yaml'): - with open(filename, 'r') as f: + elif filename.endswith(".yaml"): + with open(filename, "r") as f: data = yaml.safe_load(f) return data - -def load(filename, dynamic=True, make_immutable=True, recursive_imports=True, - pre_unpack_hooks=None, post_unpack_hooks=None): - """ Read from a bytestream and deserialize to a settings object""" +def load( + filename, + dynamic=True, + make_immutable=True, + recursive_imports=True, + pre_unpack_hooks=None, + post_unpack_hooks=None, +): + """Read from a bytestream and deserialize to a settings object""" pre_unpack_hooks = pre_unpack_hooks or [] post_unpack_hooks = post_unpack_hooks or [] orig_json = load_raw_dict_from_file(filename) @@ -34,16 +39,21 @@ def load(filename, dynamic=True, make_immutable=True, recursive_imports=True, if recursive_imports: unpack_imports_full( - orig_json, - import_string=IMPORT_KEY, - used_filenames=[filename]) - return _post_load(orig_json, dynamic, make_immutable, - post_unpack_hooks) - - -def loads(s, *, dynamic=True, make_immutable=False, recursive_imports=True, - pre_unpack_hooks=None, post_unpack_hooks=None): - """ Deserialize string to a settings object""" + orig_json, import_string=IMPORT_KEY, used_filenames=[filename] + ) + return _post_load(orig_json, dynamic, make_immutable, post_unpack_hooks) + + +def loads( + s, + *, + dynamic=True, + make_immutable=False, + recursive_imports=True, + pre_unpack_hooks=None, + post_unpack_hooks=None, +): + """Deserialize string to a settings object""" pre_unpack_hooks = pre_unpack_hooks or [] post_unpack_hooks = post_unpack_hooks or [] @@ -59,23 +69,24 @@ def loads(s, *, dynamic=True, make_immutable=False, recursive_imports=True, hook(orig_dict) if recursive_imports: - unpack_imports_full( - orig_dict, - import_string=IMPORT_KEY, - used_filenames=[]) + unpack_imports_full(orig_dict, import_string=IMPORT_KEY, used_filenames=[]) return _post_load(orig_dict, dynamic, make_immutable, post_unpack_hooks) def _post_load(current_dict, dynamic, make_immutable, post_unpack_hooks): - keys = list(current_dict.keys()) # to avoid that list of keys gets updated during loop + keys = list( + current_dict.keys() + ) # to avoid that list of keys gets updated during loop for key in keys: - if key.endswith("*") and isinstance(current_dict[key], collections.abc.Sequence): + if key.endswith("*") and isinstance( + current_dict[key], collections.abc.Sequence + ): raw_key = removesuffix(key, "*") current_dict[raw_key] = current_dict.pop(key) if dynamic: objectified = recursive_objectify(current_dict, make_immutable=False) - timestamp = datetime.now().strftime('%H:%M:%S-%d%h%y') + timestamp = datetime.now().strftime("%H:%M:%S-%d%h%y") namespace = dict(__timestamp__=timestamp, **objectified) recursive_dynamic_json(current_dict, namespace) @@ -107,29 +118,31 @@ def unpack_imports_fixed_level(orig_dict, import_string, used_filenames): """ if import_string in orig_dict: - new_files = orig_dict[import_string] # type(orig_dict[import_string]) in [str, list] + new_files = orig_dict[ + import_string + ] # type(orig_dict[import_string]) in [str, list] if isinstance(new_files, str): new_files = [new_files] del orig_dict[import_string] for new_file in reversed(new_files): if new_file in used_filenames: raise ValueError( - f"Cyclic dependency of JSONs, {new_file} already unpacked") + f"Cyclic dependency of JSONs, {new_file} already unpacked" + ) loaded_dict = load_raw_dict_from_file(new_file) - unpack_imports_full( - loaded_dict, - import_string, - used_filenames + - [new_file]) + unpack_imports_full(loaded_dict, import_string, used_filenames + [new_file]) update_recursive(orig_dict, loaded_dict, overwrite=False) -if __name__ == '__main__': +if __name__ == "__main__": import sys def check_import_in_fixed_params(setting_dict): if "fixed_params" in setting_dict: - if "__import__" in setting_dict['fixed_params']: - raise ImportError("Cannot import inside fixed params. Did you mean __import_promise__?") + if "__import__" in setting_dict["fixed_params"]: + raise ImportError( + "Cannot import inside fixed params. Did you mean __import_promise__?" + ) + params = load(sys.argv[1], pre_unpack_hooks=[check_import_in_fixed_params]) print(params) diff --git a/smart_settings/utils.py b/smart_settings/utils.py index fc1d232..4f80040 100644 --- a/smart_settings/utils.py +++ b/smart_settings/utils.py @@ -1,6 +1,6 @@ def removesuffix(self: str, suffix: str) -> str: # suffix='' should not call self[:-0]. if suffix and self.endswith(suffix): - return self[:-len(suffix)] + return self[: -len(suffix)] else: return self[:]