Skip to content

Commit

Permalink
add tests for fragment detection and molecule generation; correct bug…
Browse files Browse the repository at this point in the history
… in distance check

Signed-off-by: Marcel Müller <[email protected]>
  • Loading branch information
marcelmbn committed Aug 13, 2024
1 parent dad8d7d commit e71f1fe
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/mlmgen/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = "0.1.1.dev16+g811738b.d20240813"
__version_tuple__ = version_tuple = (0, 1, 1, "dev16", "g811738b.d20240813")
__version__ = version = "0.1.1.dev19+gdad8d7d.d20240813"
__version_tuple__ = version_tuple = (0, 1, 1, "dev19", "gdad8d7d.d20240813")
14 changes: 12 additions & 2 deletions src/mlmgen/molecules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,24 @@
"""

from .molecule import Molecule
from .generate_molecule import generate_random_molecule, generate_coordinates
from .postprocess import postprocess
from .generate_molecule import (
generate_random_molecule,
generate_coordinates,
generate_atom_list,
get_metal_z,
check_distances,
)
from .postprocess import postprocess, detect_fragments
from .miscellaneous import set_random_charge

__all__ = [
"Molecule",
"generate_random_molecule",
"generate_coordinates",
"generate_atom_list",
"get_metal_z",
"postprocess",
"detect_fragments",
"set_random_charge",
"check_distances",
]
6 changes: 3 additions & 3 deletions src/mlmgen/molecules/generate_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ def check_distances(xyz: np.ndarray, threshold: float) -> bool:
Check if the distances between atoms are larger than a threshold.
"""
# go through the atoms dimension of the xyz array
for i in range(1, xyz.shape[1]):
for j in range(i - 1):
r = np.linalg.norm(xyz[:, i] - xyz[:, j])
for i in range(xyz.shape[0] - 1):
for j in range(i + 1, xyz.shape[0]):
r = np.linalg.norm(xyz[i, :] - xyz[j, :])
if r < threshold:
return False
return True
Expand Down
10 changes: 10 additions & 0 deletions test/test_molecules/fixtures/H2O2B2I1Os1.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
8
Generated by mlmgen-v0.1.1.dev19+gdad8d7d.d20240813
H 1.8881922 0.0434471 6.4036416
H 2.3614382 1.5848368 4.1324938
B 2.9665417 3.1071877 6.8367337
B 1.2903809 -0.0355866 5.2252287
O 2.8188111 2.2115904 7.6804091
O 2.9569763 3.4427774 5.6277026
I 0.1039807 -1.5020698 4.4901998
Os 2.3947779 1.4743942 5.6803414
17 changes: 17 additions & 0 deletions test/test_molecules/fixtures/H6O2B2Ne2I1Os1Tl1.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
15
Generated by mlmgen-v0.1.1.dev19+gdad8d7d.d20240813
H 1.8881922 0.0434471 6.4036416
H 2.3614382 1.5848368 4.1324938
H 4.4075523 -1.1311187 4.6335760
H -18.8463718 31.0433592 0.1833077
H 3.8825457 -0.5200946 4.7973741
H 15.1931043 11.9104232 -29.4429174
B 2.9665417 3.1071877 6.8367337
B 1.2903809 -0.0355866 5.2252287
O 2.8188111 2.2115904 7.6804091
O 2.9569763 3.4427774 5.6277026
Ne -12.5685006 -25.1495934 -12.9021224
Ne 5.1386813 1.5739968 5.9230793
I 0.1039807 -1.5020698 4.4901998
Os 2.3947779 1.4743942 5.6803414
Tl -15.2759053 -25.7803943 -13.4756278
84 changes: 83 additions & 1 deletion test/test_molecules/test_generate_molecules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@
"""

from __future__ import annotations
import pytest
import numpy as np
from mlmgen.molecules import generate_random_molecule, generate_coordinates # type: ignore
from mlmgen.molecules import ( # type: ignore
generate_random_molecule,
generate_coordinates,
generate_atom_list,
get_metal_z,
check_distances,
)
from mlmgen.molecules.molecule import Molecule # type: ignore


Expand Down Expand Up @@ -43,6 +50,81 @@ def test_generate_coordinates() -> None:
assert mol.num_atoms == len(mol.ati)


def test_generate_atom_list() -> None:
"""
Test the generation of an array of atomic numbers.
"""
atlist = generate_atom_list()
assert atlist.shape == (102,)
assert np.sum(atlist) > 0
# check that for the transition and lanthanide metals, the occurence is never greater than 1
for z in get_metal_z():
assert atlist[z] <= 1
alkmetals = (2, 3, 10, 11, 18, 19, 36, 37, 54, 55)
# check that the sum of alkali and alkaline earth metals is never greater than 3
assert np.sum([atlist[z] for z in alkmetals]) <= 3


@pytest.mark.parametrize(
"xyz, threshold, expected, description",
[
(
np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
0.5,
True,
"Two atoms with distance greater than threshold (1.0 > 0.5)",
),
(
np.array([[0.0, 0.0, 0.0], [0.4, 0.0, 0.0]]),
0.5,
False,
"Two atoms with distance less than threshold (0.4 < 0.5)",
),
(
np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]]),
0.5,
True,
"Three atoms in a line with distances greater than threshold",
),
(
np.array([[0.0, 0.0, 0.0], [0.4, 0.0, 0.0], [1.0, 0.0, 0.0]]),
0.5,
False,
"Three atoms with one pair close together: distance between first two is less than threshold",
),
(
np.array([[0.0, 0.0, 0.0]]),
0.5,
True,
"Single atom, no distances to compare",
),
(
np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]),
0.5,
False,
"Two atoms at identical positions: distance is zero, less than threshold",
),
(
np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]),
1.7320,
True,
"Two atoms with diagonal distance just above threshold (sqrt(3) ≈ 1.732)",
),
],
ids=[
"far_apart",
"close_together",
"three_in_line",
"three_with_one_close",
"single_atom",
"two_identical",
"diagonal_distance",
],
)
def test_check_distances(xyz, threshold, expected, description):
assert check_distances(xyz, threshold) == expected


def test_dummy() -> None:
"""
Test the dummy function.
Expand Down
23 changes: 23 additions & 0 deletions test/test_molecules/test_miscellaneous.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Test the miscellaneous functions in the molecules module.
"""

import pytest
import numpy as np
from mlmgen.molecules.miscellaneous import set_random_charge # type: ignore


# CAUTION: We use 0-based indexing for atoms and molecules!
@pytest.mark.parametrize(
"atom_types, expected_charge",
[
(np.array([5, 7, 0]), [-1, 1]),
(np.array([5, 7, 0, 0]), [-2, 0, 2]),
],
ids=["odd", "even"],
)
def test_set_random_charge(atom_types, expected_charge):
"""
Test the set_random_charge function.
"""
assert set_random_charge(atom_types) in expected_charge
51 changes: 51 additions & 0 deletions test/test_molecules/test_postprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Test the postprocessing functions in the molecules module.
"""

from __future__ import annotations

from pathlib import Path
import numpy as np
import pytest
from mlmgen.molecules import detect_fragments # type: ignore
from mlmgen.molecules.molecule import Molecule # type: ignore


@pytest.fixture
def mol_H6O2B2Ne2I1Os1Tl1() -> Molecule:
"""
Load the molecule H6O2B2Ne2I1Os1Tl1 from 'fixtures/H6O2B2Ne2I1Os1Tl1.xyz'.
"""
mol = Molecule("H6O2B2Ne2I1Os1Tl1")
# get the Path of this file
path = Path(__file__).resolve().parent
xyz_file = path / "fixtures/H6O2B2Ne2I1Os1Tl1.xyz"
mol.read_xyz_from_file(xyz_file)
return mol


@pytest.fixture
def mol_H2O2B2I1Os1() -> Molecule:
"""
Load the molecule H2O2B2I1Os1 from 'fixtures/H2O2B2I1Os1.xyz'.
"""
mol = Molecule("H2O2B2I1Os1")
# get the Path of this file
path = Path(__file__).resolve().parent
xyz_file = path / "fixtures/H2O2B2I1Os1.xyz"
mol.read_xyz_from_file(xyz_file)
return mol


def test_detect_fragments_H6O2B2Ne2I1Os1Tl1(
mol_H6O2B2Ne2I1Os1Tl1: Molecule,
mol_H2O2B2I1Os1: Molecule,
) -> None:
"""
Test the detection of fragments in the molecule H2O2B2I1Os1.
"""
fragmols = detect_fragments(mol_H6O2B2Ne2I1Os1Tl1, verbosity=0)
assert len(fragmols) == 7

# check that the first fragment is equal to the molecule H2O2B2I1Os1
assert np.allclose(fragmols[0].xyz, mol_H2O2B2I1Os1.xyz)

0 comments on commit e71f1fe

Please sign in to comment.