Skip to content

Commit

Permalink
introduce TOML configuration file
Browse files Browse the repository at this point in the history
Signed-off-by: Marcel Müller <[email protected]>
  • Loading branch information
marcelmbn committed Aug 14, 2024
1 parent 4fb8a78 commit faeae97
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 23 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ repos:
rev: v1.11.1
hooks:
- id: mypy
additional_dependencies: [types-toml]
default_language_version:
python: python3.12
27 changes: 27 additions & 0 deletions mlmgen.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Default configuration for the 'MindLess Molecule GENerator' (MLMGen)
# Following file locations are searched for in the following order:
# 1. Location specified by the `--config < str | Path >` command-line argument
# 2. Current working directory (`Path.cwd()`)
# 3. User's home directory (`Path.home()`)

[general]
# Verbosity level defining the printout: Options: 0 = silent, 1 = default, 2 = verbose
verbosity = 1

# Quantum Mechanics (QM) engine to use. Options: 'xtb', 'orca'
engine = "xtb"

# Maximum number of optimization cycles. Options: <int>
max_cycles = 100

[xtb]
# TODO
# Specific configurations for the XTB engine (if needed)
# xtb_option_1 = "value1"
# xtb_option_2 = "value2"

[orca]
# TODO
# Specific configurations for the ORCA engine (if needed)
# orca_option_1 = "value1"
# orca_option_2 = "value2"
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ classifiers = [
"Topic :: Scientific/Engineering",
"Typing :: Typed",
]
dependencies = ["numpy", "networkx"]
dependencies = ["numpy", "networkx", "toml"]
dynamic = ["version"]

[project.optional-dependencies]
Expand All @@ -33,6 +33,7 @@ dev = [
"pytest",
"tox",
"setuptools_scm>=8",
"types-toml",
]

[project.scripts]
Expand Down
4 changes: 2 additions & 2 deletions src/mlmgen/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = "0.1.1.dev21+g848a76c.d20240814"
__version_tuple__ = version_tuple = (0, 1, 1, "dev21", "g848a76c.d20240814")
__version__ = version = "0.1.1.dev22+g4fb8a78.d20240814"
__version_tuple__ = version_tuple = (0, 1, 1, "dev22", "g4fb8a78.d20240814")
32 changes: 26 additions & 6 deletions src/mlmgen/cli/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,19 @@
from ..__version__ import __version__


def cli_parser(argv: Sequence[str] | None = None) -> argparse.Namespace:
def cli_parser(argv: Sequence[str] | None = None) -> dict:
"""
Parse command line arguments.
"""
# get command line argument
parser = argparse.ArgumentParser()
# General arguments
parser.add_argument(
"-i", "--input", type=argparse.FileType("r"), help="Input file.", required=False
"-c",
"--config",
type=str,
help="Input file.",
required=False,
)
parser.add_argument(
"-v", "--version", action="version", version=f"%(prog)s {__version__}"
Expand All @@ -21,25 +26,40 @@ def cli_parser(argv: Sequence[str] | None = None) -> argparse.Namespace:
"--verbosity",
type=int,
choices=[0, 1, 2],
default=1,
help="Verbosity level (0, 1, or 2).",
)
parser.add_argument(
"-e",
"--engine",
type=str,
choices=["xtb", "orca"],
default="xtb",
help="QM engine to use.",
)
parser.add_argument(
"-mc",
"--max-cycles",
type=int,
default=100,
required=False,
help="Maximum number of optimization cycles.",
)
# XTB specific arguments
# TODO: Add XTB specific arguments
# ORCA specific arguments
# TODO: Add ORCA specific arguments
args = parser.parse_args(argv)
args_dict = vars(args)

# General arguments
rev_args_dict = {}
rev_args_dict["general"] = {
"config": args_dict["config"],
"verbosity": args_dict["verbosity"],
"engine": args_dict["engine"],
"max_cycles": args_dict["max_cycles"],
}
# XTB specific arguments
rev_args_dict["xtb"] = {}
# ORCA specific arguments
rev_args_dict["orca"] = {}

return args
return rev_args_dict
84 changes: 81 additions & 3 deletions src/mlmgen/cli/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
import warnings

import toml

from ..generator import generator
from .cli_parser import cli_parser as cl
Expand All @@ -14,7 +18,81 @@ def console_entry_point(argv: Sequence[str] | None = None) -> int:
"""
Entrypoint for command line interface.
"""
# Step 1: Parse CLI arguments
args = cl(argv)
# convert args to dictionary
kwargs = vars(args)
raise SystemExit(generator(kwargs))

# Generate a default config that corresponds to this configuration file
# after parsing the toml file into a dictionary
DEFAULT_CONFIG = {
"general": {
"verbosity": 1,
"engine": "xtb",
"max_cycles": 100,
},
"xtb": {},
"orca": {},
}
# Step 2: Find the configuration file (CLI provided or default search)
config_file = find_config_file(args["general"]["config"])

# Step 3: Load the configuration
if config_file:
print(f"Reading configuration from file: '{config_file}'")
config = load_config(config_file)
else:
config = DEFAULT_CONFIG

# Step 4: Merge with CLI arguments, giving precedence to CLI
merged_config = merge_config_with_cli(config, args)

# Use `final_config` in your program
if merged_config["general"]["verbosity"] > 1:
print(merged_config)
raise SystemExit(generator(merged_config))


def find_config_file(cli_config_path: str | Path | None = None) -> Path | None:
"""
Finds the configuration file. If a path is provided via CLI, use it.
Otherwise, search in predefined locations.
"""
# CLI provided config file
if cli_config_path:
config_path = Path(cli_config_path).resolve()
if config_path.is_file():
return config_path
raise FileNotFoundError(f"Configuration file not found at {cli_config_path}")

# Search paths
search_paths = [
Path.home() / "mlmgen.toml", # $USER/mlmgen.toml
Path.cwd() / "mlmgen.toml", # Current directory
]

# Find the config file
for path in search_paths:
if path.is_file():
return path

# If no config file is found, raise a warning
warnings.warn("No configuration file found. Using default configuration.")
return None


def load_config(config_file):
"""
Load the configuration from the provided TOML file.
"""
return toml.load(config_file)


def merge_config_with_cli(config, cli_args_dict):
"""
Merge CLI arguments with the configuration, giving precedence to CLI.
"""

for subcommand in cli_args_dict.keys():
for key, value in cli_args_dict[subcommand].items():
if value is not None:
config[subcommand][key] = value
return config
19 changes: 8 additions & 11 deletions src/mlmgen/generator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..molecules import postprocess


def generator(inputdict: dict) -> int:
def generator(config: dict) -> int:
"""
Generate a molecule.
"""
Expand All @@ -22,18 +22,19 @@ def generator(inputdict: dict) -> int:
# __/ |
# |___/

if inputdict["engine"] == "xtb":
if config["general"]["engine"] == "xtb":
try:
xtb_path = get_xtb_path(["xtb_dev", "xtb"])
if not xtb_path:
raise ImportError("xtb not found.")
except ImportError as e:
raise ImportError("xtb not found.") from e
engine = XTB(xtb_path, inputdict["verbosity"])
engine = XTB(xtb_path, config["general"]["verbosity"])
else:
raise NotImplementedError("Engine not implemented.")

for cycle in range(inputdict["max_cycles"]):
print(f"Config: {config}")
for cycle in range(config["general"]["max_cycles"]):
print(f"Cycle {cycle + 1}...")
# _____ _
# / ____| | |
Expand All @@ -42,11 +43,7 @@ def generator(inputdict: dict) -> int:
# | |__| | __/ | | | __/ | | (_| | || (_) | |
# \_____|\___|_| |_|\___|_| \__,_|\__\___/|_|

if inputdict["input"]:
print(f"Input file: {input}")
raise NotImplementedError("Input file not implemented.")
else:
mol = generate_random_molecule(inputdict["verbosity"])
mol = generate_random_molecule(config["general"]["verbosity"])

try:
# ____ _ _ _
Expand All @@ -58,15 +55,15 @@ def generator(inputdict: dict) -> int:
# | |
# |_|
optimized_molecule = postprocess(
mol=mol, engine=engine, verbosity=inputdict["verbosity"]
mol=mol, engine=engine, verbosity=config["general"]["verbosity"]
)
print("Postprocessing successful. Optimized molecule:")
print(optimized_molecule)
optimized_molecule.write_xyz_to_file("optimized_molecule.xyz")
return 0
except RuntimeError as e:
print(f"Postprocessing failed for cycle {cycle + 1}.\n")
if inputdict["verbosity"] > 1:
if config["general"]["verbosity"] > 1:
print(e)
continue
raise RuntimeError("Postprocessing failed for all cycles.")

0 comments on commit faeae97

Please sign in to comment.