From faeae97b0a3978290952f36ee3e0068ecd6104c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=BCller?= Date: Wed, 14 Aug 2024 19:00:25 +0200 Subject: [PATCH] introduce TOML configuration file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marcel Müller --- .pre-commit-config.yaml | 1 + mlmgen.toml | 27 ++++++++++++ pyproject.toml | 3 +- src/mlmgen/__version__.py | 4 +- src/mlmgen/cli/cli_parser.py | 32 +++++++++++--- src/mlmgen/cli/entrypoint.py | 84 ++++++++++++++++++++++++++++++++++-- src/mlmgen/generator/main.py | 19 ++++---- 7 files changed, 147 insertions(+), 23 deletions(-) create mode 100644 mlmgen.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b99f7a2..e80708b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,5 +31,6 @@ repos: rev: v1.11.1 hooks: - id: mypy + additional_dependencies: [types-toml] default_language_version: python: python3.12 diff --git a/mlmgen.toml b/mlmgen.toml new file mode 100644 index 0000000..33ae8f0 --- /dev/null +++ b/mlmgen.toml @@ -0,0 +1,27 @@ +# Default configuration for the 'MindLess Molecule GENerator' (MLMGen) +# Following file locations are searched for in the following order: +# 1. Location specified by the `--config < str | Path >` command-line argument +# 2. Current working directory (`Path.cwd()`) +# 3. User's home directory (`Path.home()`) + +[general] +# Verbosity level defining the printout: Options: 0 = silent, 1 = default, 2 = verbose +verbosity = 1 + +# Quantum Mechanics (QM) engine to use. Options: 'xtb', 'orca' +engine = "xtb" + +# Maximum number of optimization cycles. Options: +max_cycles = 100 + +[xtb] +# TODO +# Specific configurations for the XTB engine (if needed) +# xtb_option_1 = "value1" +# xtb_option_2 = "value2" + +[orca] +# TODO +# Specific configurations for the ORCA engine (if needed) +# orca_option_1 = "value1" +# orca_option_2 = "value2" diff --git a/pyproject.toml b/pyproject.toml index 10de0a3..b9e3790 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ "Topic :: Scientific/Engineering", "Typing :: Typed", ] -dependencies = ["numpy", "networkx"] +dependencies = ["numpy", "networkx", "toml"] dynamic = ["version"] [project.optional-dependencies] @@ -33,6 +33,7 @@ dev = [ "pytest", "tox", "setuptools_scm>=8", + "types-toml", ] [project.scripts] diff --git a/src/mlmgen/__version__.py b/src/mlmgen/__version__.py index cfd9df7..2d857bf 100644 --- a/src/mlmgen/__version__.py +++ b/src/mlmgen/__version__.py @@ -13,5 +13,5 @@ __version_tuple__: VERSION_TUPLE version_tuple: VERSION_TUPLE -__version__ = version = "0.1.1.dev21+g848a76c.d20240814" -__version_tuple__ = version_tuple = (0, 1, 1, "dev21", "g848a76c.d20240814") +__version__ = version = "0.1.1.dev22+g4fb8a78.d20240814" +__version_tuple__ = version_tuple = (0, 1, 1, "dev22", "g4fb8a78.d20240814") diff --git a/src/mlmgen/cli/cli_parser.py b/src/mlmgen/cli/cli_parser.py index 963f2db..efa6ed0 100644 --- a/src/mlmgen/cli/cli_parser.py +++ b/src/mlmgen/cli/cli_parser.py @@ -5,14 +5,19 @@ from ..__version__ import __version__ -def cli_parser(argv: Sequence[str] | None = None) -> argparse.Namespace: +def cli_parser(argv: Sequence[str] | None = None) -> dict: """ Parse command line arguments. """ # get command line argument parser = argparse.ArgumentParser() + # General arguments parser.add_argument( - "-i", "--input", type=argparse.FileType("r"), help="Input file.", required=False + "-c", + "--config", + type=str, + help="Input file.", + required=False, ) parser.add_argument( "-v", "--version", action="version", version=f"%(prog)s {__version__}" @@ -21,7 +26,6 @@ def cli_parser(argv: Sequence[str] | None = None) -> argparse.Namespace: "--verbosity", type=int, choices=[0, 1, 2], - default=1, help="Verbosity level (0, 1, or 2).", ) parser.add_argument( @@ -29,17 +33,33 @@ def cli_parser(argv: Sequence[str] | None = None) -> argparse.Namespace: "--engine", type=str, choices=["xtb", "orca"], - default="xtb", help="QM engine to use.", ) parser.add_argument( "-mc", "--max-cycles", type=int, - default=100, required=False, help="Maximum number of optimization cycles.", ) + # XTB specific arguments + # TODO: Add XTB specific arguments + # ORCA specific arguments + # TODO: Add ORCA specific arguments args = parser.parse_args(argv) + args_dict = vars(args) + + # General arguments + rev_args_dict = {} + rev_args_dict["general"] = { + "config": args_dict["config"], + "verbosity": args_dict["verbosity"], + "engine": args_dict["engine"], + "max_cycles": args_dict["max_cycles"], + } + # XTB specific arguments + rev_args_dict["xtb"] = {} + # ORCA specific arguments + rev_args_dict["orca"] = {} - return args + return rev_args_dict diff --git a/src/mlmgen/cli/entrypoint.py b/src/mlmgen/cli/entrypoint.py index 217d500..8c973c1 100644 --- a/src/mlmgen/cli/entrypoint.py +++ b/src/mlmgen/cli/entrypoint.py @@ -5,6 +5,10 @@ from __future__ import annotations from collections.abc import Sequence +from pathlib import Path +import warnings + +import toml from ..generator import generator from .cli_parser import cli_parser as cl @@ -14,7 +18,81 @@ def console_entry_point(argv: Sequence[str] | None = None) -> int: """ Entrypoint for command line interface. """ + # Step 1: Parse CLI arguments args = cl(argv) - # convert args to dictionary - kwargs = vars(args) - raise SystemExit(generator(kwargs)) + + # Generate a default config that corresponds to this configuration file + # after parsing the toml file into a dictionary + DEFAULT_CONFIG = { + "general": { + "verbosity": 1, + "engine": "xtb", + "max_cycles": 100, + }, + "xtb": {}, + "orca": {}, + } + # Step 2: Find the configuration file (CLI provided or default search) + config_file = find_config_file(args["general"]["config"]) + + # Step 3: Load the configuration + if config_file: + print(f"Reading configuration from file: '{config_file}'") + config = load_config(config_file) + else: + config = DEFAULT_CONFIG + + # Step 4: Merge with CLI arguments, giving precedence to CLI + merged_config = merge_config_with_cli(config, args) + + # Use `final_config` in your program + if merged_config["general"]["verbosity"] > 1: + print(merged_config) + raise SystemExit(generator(merged_config)) + + +def find_config_file(cli_config_path: str | Path | None = None) -> Path | None: + """ + Finds the configuration file. If a path is provided via CLI, use it. + Otherwise, search in predefined locations. + """ + # CLI provided config file + if cli_config_path: + config_path = Path(cli_config_path).resolve() + if config_path.is_file(): + return config_path + raise FileNotFoundError(f"Configuration file not found at {cli_config_path}") + + # Search paths + search_paths = [ + Path.home() / "mlmgen.toml", # $USER/mlmgen.toml + Path.cwd() / "mlmgen.toml", # Current directory + ] + + # Find the config file + for path in search_paths: + if path.is_file(): + return path + + # If no config file is found, raise a warning + warnings.warn("No configuration file found. Using default configuration.") + return None + + +def load_config(config_file): + """ + Load the configuration from the provided TOML file. + """ + return toml.load(config_file) + + +def merge_config_with_cli(config, cli_args_dict): + """ + Merge CLI arguments with the configuration, giving precedence to CLI. + """ + + for subcommand in cli_args_dict.keys(): + for key, value in cli_args_dict[subcommand].items(): + if value is not None: + config[subcommand][key] = value + return config diff --git a/src/mlmgen/generator/main.py b/src/mlmgen/generator/main.py index 813aa95..a509199 100644 --- a/src/mlmgen/generator/main.py +++ b/src/mlmgen/generator/main.py @@ -8,7 +8,7 @@ from ..molecules import postprocess -def generator(inputdict: dict) -> int: +def generator(config: dict) -> int: """ Generate a molecule. """ @@ -22,18 +22,19 @@ def generator(inputdict: dict) -> int: # __/ | # |___/ - if inputdict["engine"] == "xtb": + if config["general"]["engine"] == "xtb": try: xtb_path = get_xtb_path(["xtb_dev", "xtb"]) if not xtb_path: raise ImportError("xtb not found.") except ImportError as e: raise ImportError("xtb not found.") from e - engine = XTB(xtb_path, inputdict["verbosity"]) + engine = XTB(xtb_path, config["general"]["verbosity"]) else: raise NotImplementedError("Engine not implemented.") - for cycle in range(inputdict["max_cycles"]): + print(f"Config: {config}") + for cycle in range(config["general"]["max_cycles"]): print(f"Cycle {cycle + 1}...") # _____ _ # / ____| | | @@ -42,11 +43,7 @@ def generator(inputdict: dict) -> int: # | |__| | __/ | | | __/ | | (_| | || (_) | | # \_____|\___|_| |_|\___|_| \__,_|\__\___/|_| - if inputdict["input"]: - print(f"Input file: {input}") - raise NotImplementedError("Input file not implemented.") - else: - mol = generate_random_molecule(inputdict["verbosity"]) + mol = generate_random_molecule(config["general"]["verbosity"]) try: # ____ _ _ _ @@ -58,7 +55,7 @@ def generator(inputdict: dict) -> int: # | | # |_| optimized_molecule = postprocess( - mol=mol, engine=engine, verbosity=inputdict["verbosity"] + mol=mol, engine=engine, verbosity=config["general"]["verbosity"] ) print("Postprocessing successful. Optimized molecule:") print(optimized_molecule) @@ -66,7 +63,7 @@ def generator(inputdict: dict) -> int: return 0 except RuntimeError as e: print(f"Postprocessing failed for cycle {cycle + 1}.\n") - if inputdict["verbosity"] > 1: + if config["general"]["verbosity"] > 1: print(e) continue raise RuntimeError("Postprocessing failed for all cycles.")