Skip to content

Commit

Permalink
self-consistent generation, better metal check, better printout
Browse files Browse the repository at this point in the history
Signed-off-by: Marcel Müller <[email protected]>
marcelmbn committed Aug 13, 2024
1 parent 379b57c commit dad8d7d
Showing 6 changed files with 86 additions and 48 deletions.
8 changes: 8 additions & 0 deletions src/mlmgen/cli/cli_parser.py
Original file line number Diff line number Diff line change
@@ -32,6 +32,14 @@ def cli_parser(argv: Sequence[str] | None = None) -> argparse.Namespace:
default="xtb",
help="QM engine to use.",
)
parser.add_argument(
"-mc",
"--max-cycles",
type=int,
default=100,
required=False,
help="Maximum number of optimization cycles.",
)
args = parser.parse_args(argv)

return args
3 changes: 1 addition & 2 deletions src/mlmgen/cli/entrypoint.py
Original file line number Diff line number Diff line change
@@ -17,5 +17,4 @@ def console_entry_point(argv: Sequence[str] | None = None) -> int:
args = cl(argv)
# convert args to dictionary
kwargs = vars(args)
exitcode = generator(kwargs)
return exitcode
raise SystemExit(generator(kwargs))
59 changes: 34 additions & 25 deletions src/mlmgen/generator/main.py
Original file line number Diff line number Diff line change
@@ -33,28 +33,37 @@ def generator(inputdict: dict) -> int:
else:
raise NotImplementedError("Engine not implemented.")

# _____ _
# / ____| | |
# | | __ ___ _ __ ___ _ __ __ _| |_ ___ _ __
# | | |_ |/ _ \ '_ \ / _ \ '__/ _` | __/ _ \| '__|
# | |__| | __/ | | | __/ | | (_| | || (_) | |
# \_____|\___|_| |_|\___|_| \__,_|\__\___/|_|

if inputdict["input"]:
print(f"Input file: {input}")
else:
mol = generate_random_molecule(inputdict["verbosity"])

# ____ _ _ _
# / __ \ | | (_) (_)
# | | | |_ __ | |_ _ _ __ ___ _ _______
# | | | | '_ \| __| | '_ ` _ \| |_ / _ \
# | |__| | |_) | |_| | | | | | | |/ / __/
# \____/| .__/ \__|_|_| |_| |_|_/___\___|
# | |
# |_|
optimized_molecule = postprocess(
mol=mol, engine=engine, verbosity=inputdict["verbosity"]
)
optimized_molecule.write_xyz_to_file("optimized_molecule.xyz")
return 0
for cycle in range(inputdict["max_cycles"]):
print(f"Cycle {cycle + 1}...")
# _____ _
# / ____| | |
# | | __ ___ _ __ ___ _ __ __ _| |_ ___ _ __
# | | |_ |/ _ \ '_ \ / _ \ '__/ _` | __/ _ \| '__|
# | |__| | __/ | | | __/ | | (_| | || (_) | |
# \_____|\___|_| |_|\___|_| \__,_|\__\___/|_|

if inputdict["input"]:
print(f"Input file: {input}")
else:
mol = generate_random_molecule(inputdict["verbosity"])

try:
# ____ _ _ _
# / __ \ | | (_) (_)
# | | | |_ __ | |_ _ _ __ ___ _ _______
# | | | | '_ \| __| | '_ ` _ \| |_ / _ \
# | |__| | |_) | |_| | | | | | | |/ / __/
# \____/| .__/ \__|_|_| |_| |_|_/___\___|
# | |
# |_|
optimized_molecule = postprocess(
mol=mol, engine=engine, verbosity=inputdict["verbosity"]
)
print("Postprocessing successful. Optimized molecule:")
print(optimized_molecule)
optimized_molecule.write_xyz_to_file("optimized_molecule.xyz")
return 0
except RuntimeError:
print(f"Postprocessing failed for cycle {cycle + 1}.\n")
continue
raise RuntimeError("Postprocessing failed for all cycles.")
58 changes: 37 additions & 21 deletions src/mlmgen/molecules/generate_molecule.py
Original file line number Diff line number Diff line change
@@ -16,13 +16,14 @@ def generate_random_molecule(verbosity: int = 1) -> Molecule:
mol = Molecule()
mol.atlist = generate_atom_list(verbosity)
mol.num_atoms = np.sum(mol.atlist)
mol.xyz, mol.ati = generate_coordinates(mol.atlist, 3.0, 1.2)
mol.xyz, mol.ati = generate_coordinates(
at=mol.atlist, scaling=3.0, dist_threshold=1.2, verbosity=verbosity
)
mol.charge = set_random_charge(mol.ati, verbosity)

# if verbosity > 0, print the molecule and its sum formula
if verbosity > 0:
# if verbosity > 1, print the molecule
if verbosity > 1:
print(mol)
print(mol.sum_formula())

return mol

@@ -109,6 +110,26 @@ def generate_atom_list(verbosity: int = 1) -> np.ndarray:
# Add a random number of atoms of the defined type
natoms[ati] = natoms[ati] + np.random.randint(0, 3)

# > If too many alkaline and alkine earth metals are included, restart generation
metals = (2, 3, 10, 11, 18, 19, 36, 37, 54, 55)
nmetals = 0
for i in metals:
nmetals += natoms[i]
if nmetals > 3:
# reduce number of metals starting from 2, going to 55
while nmetals > 3:
for i in metals:
if natoms[i] > 0:
natoms[i] = natoms[i] - 1
nmetals -= 1
if nmetals <= 3:
break

# If too many transition or lanthanide metals are included, reduce their number
for i in get_metal_z():
if natoms[i] > 1:
natoms[i] = 1

# Add Elements between B and F (5-9)
for _ in range(5):
i = np.random.randint(4, 10)
@@ -122,27 +143,11 @@ def generate_atom_list(verbosity: int = 1) -> np.ndarray:
j = 1 + int(randint * minnat * 1.2)
natoms[0] = natoms[0] + j

# > If too many metals are included, restart generation
metals = (3, 4, 11, 12, 19, 20, 37, 38, 55, 56)
nmetals = 0
for i in metals:
if natoms[i] > 0:
nmetals += 1
if nmetals > 3:
# reduce number of metals starting from 3, going to 56
while nmetals > 3:
for i in metals:
if natoms[i] > 0:
natoms[i] = natoms[i] - 1
nmetals -= 1
if nmetals == 3:
break

return natoms


def generate_coordinates(
at: np.ndarray, scaling: float, dist_threshold: float
at: np.ndarray, scaling: float, dist_threshold: float, verbosity: int = 1
) -> tuple[np.ndarray, np.ndarray]:
"""
Generate random coordinates for a molecule.
@@ -154,6 +159,8 @@ def generate_coordinates(
xyz = xyz * eff_scaling
# do while check_distances is False
while not check_distances(xyz, dist_threshold):
if verbosity > 1:
print("Distance check failed. Regenerating coordinates...")
xyz, ati = generate_random_coordinates(at)
eff_scaling = eff_scaling * 1.3
xyz = xyz * eff_scaling
@@ -195,3 +202,12 @@ def check_distances(xyz: np.ndarray, threshold: float) -> bool:
if r < threshold:
return False
return True


def get_metal_z() -> list[int]:
"""
Get the atomic numbers of transition metals and lanthanides, for which different rules apply.
"""
metals = list(range(20, 30)) + list(range(38, 48)) + list(range(56, 80))

return metals
1 change: 1 addition & 0 deletions src/mlmgen/molecules/molecule.py
Original file line number Diff line number Diff line change
@@ -170,6 +170,7 @@ def __str__(self) -> str:
returnstr += f"# unpaired electrons: {self.uhf}\n"
if self._atlist.size:
returnstr += f"atomic numbers: {self.atlist}\n"
returnstr += f"sum formula: {self.sum_formula()}\n"
if self._xyz.size:
returnstr += f"atomic coordinates:\n{self.xyz}\n"
if self._ati.size:
5 changes: 5 additions & 0 deletions src/mlmgen/molecules/postprocess.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,9 @@ def postprocess(mol: Molecule, engine: QMMethod, verbosity: int = 1) -> Molecule
"""
Postprocess the molecule.
"""
# Optimize the initial random molecule
optmol = engine.optimize(mol)
# Get all fragments
fragmols = detect_fragments(optmol, verbosity)

if verbosity > 1:
@@ -21,7 +23,10 @@ def postprocess(mol: Molecule, engine: QMMethod, verbosity: int = 1) -> Molecule
Path(f"fragment_{i}").mkdir(exist_ok=True)
fragmol.write_xyz_to_file(f"fragment_{i}/fragment_{i}.xyz")

# Optimize the first fragment
optfragmol = engine.optimize(fragmols[0])
# Differentiate again in fragments and take only the first one
optfragmol = detect_fragments(optfragmol, verbosity)[0]

return optfragmol

0 comments on commit dad8d7d

Please sign in to comment.