diff --git a/src/mlmgen/cli/cli_parser.py b/src/mlmgen/cli/cli_parser.py index 57927cf..963f2db 100644 --- a/src/mlmgen/cli/cli_parser.py +++ b/src/mlmgen/cli/cli_parser.py @@ -32,6 +32,14 @@ def cli_parser(argv: Sequence[str] | None = None) -> argparse.Namespace: default="xtb", help="QM engine to use.", ) + parser.add_argument( + "-mc", + "--max-cycles", + type=int, + default=100, + required=False, + help="Maximum number of optimization cycles.", + ) args = parser.parse_args(argv) return args diff --git a/src/mlmgen/cli/entrypoint.py b/src/mlmgen/cli/entrypoint.py index 10a8661..217d500 100644 --- a/src/mlmgen/cli/entrypoint.py +++ b/src/mlmgen/cli/entrypoint.py @@ -17,5 +17,4 @@ def console_entry_point(argv: Sequence[str] | None = None) -> int: args = cl(argv) # convert args to dictionary kwargs = vars(args) - exitcode = generator(kwargs) - return exitcode + raise SystemExit(generator(kwargs)) diff --git a/src/mlmgen/generator/main.py b/src/mlmgen/generator/main.py index dc4da72..1eaaff1 100644 --- a/src/mlmgen/generator/main.py +++ b/src/mlmgen/generator/main.py @@ -33,28 +33,37 @@ def generator(inputdict: dict) -> int: else: raise NotImplementedError("Engine not implemented.") - # _____ _ - # / ____| | | - # | | __ ___ _ __ ___ _ __ __ _| |_ ___ _ __ - # | | |_ |/ _ \ '_ \ / _ \ '__/ _` | __/ _ \| '__| - # | |__| | __/ | | | __/ | | (_| | || (_) | | - # \_____|\___|_| |_|\___|_| \__,_|\__\___/|_| - - if inputdict["input"]: - print(f"Input file: {input}") - else: - mol = generate_random_molecule(inputdict["verbosity"]) - - # ____ _ _ _ - # / __ \ | | (_) (_) - # | | | |_ __ | |_ _ _ __ ___ _ _______ - # | | | | '_ \| __| | '_ ` _ \| |_ / _ \ - # | |__| | |_) | |_| | | | | | | |/ / __/ - # \____/| .__/ \__|_|_| |_| |_|_/___\___| - # | | - # |_| - optimized_molecule = postprocess( - mol=mol, engine=engine, verbosity=inputdict["verbosity"] - ) - optimized_molecule.write_xyz_to_file("optimized_molecule.xyz") - return 0 + for cycle in range(inputdict["max_cycles"]): 
+ print(f"Cycle {cycle + 1}...") + # _____ _ + # / ____| | | + # | | __ ___ _ __ ___ _ __ __ _| |_ ___ _ __ + # | | |_ |/ _ \ '_ \ / _ \ '__/ _` | __/ _ \| '__| + # | |__| | __/ | | | __/ | | (_| | || (_) | | + # \_____|\___|_| |_|\___|_| \__,_|\__\___/|_| + + if inputdict["input"]: + print(f"Input file: {input}") + else: + mol = generate_random_molecule(inputdict["verbosity"]) + + try: + # ____ _ _ _ + # / __ \ | | (_) (_) + # | | | |_ __ | |_ _ _ __ ___ _ _______ + # | | | | '_ \| __| | '_ ` _ \| |_ / _ \ + # | |__| | |_) | |_| | | | | | | |/ / __/ + # \____/| .__/ \__|_|_| |_| |_|_/___\___| + # | | + # |_| + optimized_molecule = postprocess( + mol=mol, engine=engine, verbosity=inputdict["verbosity"] + ) + print("Postprocessing successful. Optimized molecule:") + print(optimized_molecule) + optimized_molecule.write_xyz_to_file("optimized_molecule.xyz") + return 0 + except RuntimeError: + print(f"Postprocessing failed for cycle {cycle + 1}.\n") + continue + raise RuntimeError("Postprocessing failed for all cycles.") diff --git a/src/mlmgen/molecules/generate_molecule.py b/src/mlmgen/molecules/generate_molecule.py index 7b4a0b7..66a93a8 100644 --- a/src/mlmgen/molecules/generate_molecule.py +++ b/src/mlmgen/molecules/generate_molecule.py @@ -16,13 +16,14 @@ def generate_random_molecule(verbosity: int = 1) -> Molecule: mol = Molecule() mol.atlist = generate_atom_list(verbosity) mol.num_atoms = np.sum(mol.atlist) - mol.xyz, mol.ati = generate_coordinates(mol.atlist, 3.0, 1.2) + mol.xyz, mol.ati = generate_coordinates( + at=mol.atlist, scaling=3.0, dist_threshold=1.2, verbosity=verbosity + ) mol.charge = set_random_charge(mol.ati, verbosity) - # if verbosity > 0, print the molecule and its sum formula - if verbosity > 0: + # if verbosity > 1, print the molecule + if verbosity > 1: print(mol) - print(mol.sum_formula()) return mol @@ -109,6 +110,26 @@ def generate_atom_list(verbosity: int = 1) -> np.ndarray: # Add a random number of atoms of the defined type natoms[ati] 
= natoms[ati] + np.random.randint(0, 3) + # > If too many alkali and alkaline earth metals are included, reduce their number + metals = (2, 3, 10, 11, 18, 19, 36, 37, 54, 55) + nmetals = 0 + for i in metals: + nmetals += natoms[i] + if nmetals > 3: + # reduce number of metals starting from 2, going to 55 + while nmetals > 3: + for i in metals: + if natoms[i] > 0: + natoms[i] = natoms[i] - 1 + nmetals -= 1 + if nmetals <= 3: + break + + # If too many transition or lanthanide metals are included, reduce their number + for i in get_metal_z(): + if natoms[i] > 1: + natoms[i] = 1 + # Add Elements between B and F (5-9) for _ in range(5): i = np.random.randint(4, 10) @@ -122,27 +143,11 @@ def generate_atom_list(verbosity: int = 1) -> np.ndarray: j = 1 + int(randint * minnat * 1.2) natoms[0] = natoms[0] + j - # > If too many metals are included, restart generation - metals = (3, 4, 11, 12, 19, 20, 37, 38, 55, 56) - nmetals = 0 - for i in metals: - if natoms[i] > 0: - nmetals += 1 - if nmetals > 3: - # reduce number of metals starting from 3, going to 56 - while nmetals > 3: - for i in metals: - if natoms[i] > 0: - natoms[i] = natoms[i] - 1 - nmetals -= 1 - if nmetals == 3: - break - return natoms def generate_coordinates( - at: np.ndarray, scaling: float, dist_threshold: float + at: np.ndarray, scaling: float, dist_threshold: float, verbosity: int = 1 ) -> tuple[np.ndarray, np.ndarray]: """ Generate random coordinates for a molecule. @@ -154,6 +159,8 @@ def generate_coordinates( xyz = xyz * eff_scaling # do while check_distances is False while not check_distances(xyz, dist_threshold): + if verbosity > 1: + print("Distance check failed. 
Regenerating coordinates...") xyz, ati = generate_random_coordinates(at) eff_scaling = eff_scaling * 1.3 xyz = xyz * eff_scaling @@ -195,3 +202,12 @@ def check_distances(xyz: np.ndarray, threshold: float) -> bool: if r < threshold: return False return True + + +def get_metal_z() -> list[int]: + """ + Get the zero-based element indices (atomic number minus one, as used to index natoms) of transition metals and lanthanides, for which different rules apply. + """ + metals = list(range(20, 30)) + list(range(38, 48)) + list(range(56, 80)) + + return metals diff --git a/src/mlmgen/molecules/molecule.py b/src/mlmgen/molecules/molecule.py index 2ac2108..e9fb782 100644 --- a/src/mlmgen/molecules/molecule.py +++ b/src/mlmgen/molecules/molecule.py @@ -170,6 +170,7 @@ def __str__(self) -> str: returnstr += f"# unpaired electrons: {self.uhf}\n" if self._atlist.size: returnstr += f"atomic numbers: {self.atlist}\n" + returnstr += f"sum formula: {self.sum_formula()}\n" if self._xyz.size: returnstr += f"atomic coordinates:\n{self.xyz}\n" if self._ati.size: diff --git a/src/mlmgen/molecules/postprocess.py b/src/mlmgen/molecules/postprocess.py index 0038b0a..66aaf6c 100644 --- a/src/mlmgen/molecules/postprocess.py +++ b/src/mlmgen/molecules/postprocess.py @@ -12,7 +12,9 @@ def postprocess(mol: Molecule, engine: QMMethod, verbosity: int = 1) -> Molecule """ Postprocess the molecule. """ + # Optimize the initial random molecule optmol = engine.optimize(mol) + # Get all fragments fragmols = detect_fragments(optmol, verbosity) if verbosity > 1: @@ -21,7 +23,10 @@ def postprocess(mol: Molecule, engine: QMMethod, verbosity: int = 1) -> Molecule Path(f"fragment_{i}").mkdir(exist_ok=True) fragmol.write_xyz_to_file(f"fragment_{i}/fragment_{i}.xyz") + # Optimize the first fragment optfragmol = engine.optimize(fragmols[0]) + # Differentiate again in fragments and take only the first one + optfragmol = detect_fragments(optfragmol, verbosity)[0] return optfragmol