From e71f1fef8f3e5aa6afbff4e88eb8a9ce01ef5266 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=BCller?= Date: Tue, 13 Aug 2024 18:49:38 +0200 Subject: [PATCH] add tests for fragment detection and molecule generation; correct bug in distance check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marcel Müller --- src/mlmgen/__version__.py | 4 +- src/mlmgen/molecules/__init__.py | 14 +++- src/mlmgen/molecules/generate_molecule.py | 6 +- test/test_molecules/fixtures/H2O2B2I1Os1.xyz | 10 +++ .../fixtures/H6O2B2Ne2I1Os1Tl1.xyz | 17 ++++ .../test_molecules/test_generate_molecules.py | 84 ++++++++++++++++++- test/test_molecules/test_miscellaneous.py | 23 +++++ test/test_molecules/test_postprocessing.py | 51 +++++++++++ 8 files changed, 201 insertions(+), 8 deletions(-) create mode 100644 test/test_molecules/fixtures/H2O2B2I1Os1.xyz create mode 100644 test/test_molecules/fixtures/H6O2B2Ne2I1Os1Tl1.xyz create mode 100644 test/test_molecules/test_miscellaneous.py create mode 100644 test/test_molecules/test_postprocessing.py diff --git a/src/mlmgen/__version__.py b/src/mlmgen/__version__.py index 739017c..a197331 100644 --- a/src/mlmgen/__version__.py +++ b/src/mlmgen/__version__.py @@ -13,5 +13,5 @@ __version_tuple__: VERSION_TUPLE version_tuple: VERSION_TUPLE -__version__ = version = "0.1.1.dev16+g811738b.d20240813" -__version_tuple__ = version_tuple = (0, 1, 1, "dev16", "g811738b.d20240813") +__version__ = version = "0.1.1.dev19+gdad8d7d.d20240813" +__version_tuple__ = version_tuple = (0, 1, 1, "dev19", "gdad8d7d.d20240813") diff --git a/src/mlmgen/molecules/__init__.py b/src/mlmgen/molecules/__init__.py index d09460f..ec11c71 100644 --- a/src/mlmgen/molecules/__init__.py +++ b/src/mlmgen/molecules/__init__.py @@ -3,14 +3,24 @@ """ from .molecule import Molecule -from .generate_molecule import generate_random_molecule, generate_coordinates -from .postprocess import postprocess +from .generate_molecule import ( + generate_random_molecule, + generate_coordinates, + generate_atom_list, + get_metal_z, + check_distances, +) +from .postprocess import postprocess, detect_fragments from .miscellaneous import set_random_charge __all__ = [ "Molecule", "generate_random_molecule", "generate_coordinates", + "generate_atom_list", + "get_metal_z", "postprocess", + "detect_fragments", "set_random_charge", + "check_distances", ] diff --git a/src/mlmgen/molecules/generate_molecule.py b/src/mlmgen/molecules/generate_molecule.py index 66a93a8..b31bd7f 100644 --- a/src/mlmgen/molecules/generate_molecule.py +++ b/src/mlmgen/molecules/generate_molecule.py @@ -196,9 +196,9 @@ def check_distances(xyz: np.ndarray, threshold: float) -> bool: Check if the distances between atoms are larger than a threshold. """ # go through the atoms dimension of the xyz array - for i in range(1, xyz.shape[1]): - for j in range(i - 1): - r = np.linalg.norm(xyz[:, i] - xyz[:, j]) + for i in range(xyz.shape[0] - 1): + for j in range(i + 1, xyz.shape[0]): + r = np.linalg.norm(xyz[i, :] - xyz[j, :]) if r < threshold: return False return True diff --git a/test/test_molecules/fixtures/H2O2B2I1Os1.xyz b/test/test_molecules/fixtures/H2O2B2I1Os1.xyz new file mode 100644 index 0000000..d17eb09 --- /dev/null +++ b/test/test_molecules/fixtures/H2O2B2I1Os1.xyz @@ -0,0 +1,10 @@ +8 +Generated by mlmgen-v0.1.1.dev19+gdad8d7d.d20240813 +H 1.8881922 0.0434471 6.4036416 +H 2.3614382 1.5848368 4.1324938 +B 2.9665417 3.1071877 6.8367337 +B 1.2903809 -0.0355866 5.2252287 +O 2.8188111 2.2115904 7.6804091 +O 2.9569763 3.4427774 5.6277026 +I 0.1039807 -1.5020698 4.4901998 +Os 2.3947779 1.4743942 5.6803414 diff --git a/test/test_molecules/fixtures/H6O2B2Ne2I1Os1Tl1.xyz b/test/test_molecules/fixtures/H6O2B2Ne2I1Os1Tl1.xyz new file mode 100644 index 0000000..380274c --- /dev/null +++ b/test/test_molecules/fixtures/H6O2B2Ne2I1Os1Tl1.xyz @@ -0,0 +1,17 @@ +15 +Generated by mlmgen-v0.1.1.dev19+gdad8d7d.d20240813 +H 1.8881922 0.0434471 6.4036416 +H 2.3614382 1.5848368 4.1324938 +H 4.4075523 -1.1311187 4.6335760 +H -18.8463718 31.0433592 0.1833077 +H 3.8825457 -0.5200946 4.7973741 +H 15.1931043 11.9104232 -29.4429174 +B 2.9665417 3.1071877 6.8367337 +B 1.2903809 -0.0355866 5.2252287 +O 2.8188111 2.2115904 7.6804091 +O 2.9569763 3.4427774 5.6277026 +Ne -12.5685006 -25.1495934 -12.9021224 +Ne 5.1386813 1.5739968 5.9230793 +I 0.1039807 -1.5020698 4.4901998 +Os 2.3947779 1.4743942 5.6803414 +Tl -15.2759053 -25.7803943 -13.4756278 diff --git a/test/test_molecules/test_generate_molecules.py b/test/test_molecules/test_generate_molecules.py index 4dc79f4..4b49f9e 100644 --- a/test/test_molecules/test_generate_molecules.py +++ b/test/test_molecules/test_generate_molecules.py @@ -3,8 +3,15 @@ """ from __future__ import annotations +import pytest import numpy as np -from mlmgen.molecules import generate_random_molecule, generate_coordinates # type: ignore +from mlmgen.molecules import ( # type: ignore + generate_random_molecule, + generate_coordinates, + generate_atom_list, + get_metal_z, + check_distances, +) from mlmgen.molecules.molecule import Molecule # type: ignore @@ -43,6 +50,81 @@ def test_generate_coordinates() -> None: assert mol.num_atoms == len(mol.ati) +def test_generate_atom_list() -> None: + """ + Test the generation of an array of atomic numbers. + """ + atlist = generate_atom_list() + assert atlist.shape == (102,) + assert np.sum(atlist) > 0 + # check that for the transition and lanthanide metals, the occurence is never greater than 1 + for z in get_metal_z(): + assert atlist[z] <= 1 + alkmetals = (2, 3, 10, 11, 18, 19, 36, 37, 54, 55) + # check that the sum of alkali and alkaline earth metals is never greater than 3 + assert np.sum([atlist[z] for z in alkmetals]) <= 3 + + +@pytest.mark.parametrize( + "xyz, threshold, expected, description", + [ + ( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]), + 0.5, + True, + "Two atoms with distance greater than threshold (1.0 > 0.5)", + ), + ( + np.array([[0.0, 0.0, 0.0], [0.4, 0.0, 0.0]]), + 0.5, + False, + "Two atoms with distance less than threshold (0.4 < 0.5)", + ), + ( + np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]]), + 0.5, + True, + "Three atoms in a line with distances greater than threshold", + ), + ( + np.array([[0.0, 0.0, 0.0], [0.4, 0.0, 0.0], [1.0, 0.0, 0.0]]), + 0.5, + False, + "Three atoms with one pair close together: distance between first two is less than threshold", + ), + ( + np.array([[0.0, 0.0, 0.0]]), + 0.5, + True, + "Single atom, no distances to compare", + ), + ( + np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), + 0.5, + False, + "Two atoms at identical positions: distance is zero, less than threshold", + ), + ( + np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), + 1.7320, + True, + "Two atoms with diagonal distance just above threshold (sqrt(3) ≈ 1.732)", + ), + ], + ids=[ + "far_apart", + "close_together", + "three_in_line", + "three_with_one_close", + "single_atom", + "two_identical", + "diagonal_distance", + ], +) +def test_check_distances(xyz, threshold, expected, description): + assert check_distances(xyz, threshold) == expected + + def test_dummy() -> None: """ Test the dummy function. diff --git a/test/test_molecules/test_miscellaneous.py b/test/test_molecules/test_miscellaneous.py new file mode 100644 index 0000000..bb7cff1 --- /dev/null +++ b/test/test_molecules/test_miscellaneous.py @@ -0,0 +1,23 @@ +""" +Test the miscellaneous functions in the molecules module. +""" + +import pytest +import numpy as np +from mlmgen.molecules.miscellaneous import set_random_charge # type: ignore + + +# CAUTION: We use 0-based indexing for atoms and molecules! +@pytest.mark.parametrize( + "atom_types, expected_charge", + [ + (np.array([5, 7, 0]), [-1, 1]), + (np.array([5, 7, 0, 0]), [-2, 0, 2]), + ], + ids=["odd", "even"], +) +def test_set_random_charge(atom_types, expected_charge): + """ + Test the set_random_charge function. + """ + assert set_random_charge(atom_types) in expected_charge diff --git a/test/test_molecules/test_postprocessing.py b/test/test_molecules/test_postprocessing.py new file mode 100644 index 0000000..4b42479 --- /dev/null +++ b/test/test_molecules/test_postprocessing.py @@ -0,0 +1,51 @@ +""" +Test the postprocessing functions in the molecules module. +""" + +from __future__ import annotations + +from pathlib import Path +import numpy as np +import pytest +from mlmgen.molecules import detect_fragments # type: ignore +from mlmgen.molecules.molecule import Molecule # type: ignore + + +@pytest.fixture +def mol_H6O2B2Ne2I1Os1Tl1() -> Molecule: + """ + Load the molecule H6O2B2Ne2I1Os1Tl1 from 'fixtures/H6O2B2Ne2I1Os1Tl1.xyz'. + """ + mol = Molecule("H6O2B2Ne2I1Os1Tl1") + # get the Path of this file + path = Path(__file__).resolve().parent + xyz_file = path / "fixtures/H6O2B2Ne2I1Os1Tl1.xyz" + mol.read_xyz_from_file(xyz_file) + return mol + + +@pytest.fixture +def mol_H2O2B2I1Os1() -> Molecule: + """ + Load the molecule H2O2B2I1Os1 from 'fixtures/H2O2B2I1Os1.xyz'. + """ + mol = Molecule("H2O2B2I1Os1") + # get the Path of this file + path = Path(__file__).resolve().parent + xyz_file = path / "fixtures/H2O2B2I1Os1.xyz" + mol.read_xyz_from_file(xyz_file) + return mol + + +def test_detect_fragments_H6O2B2Ne2I1Os1Tl1( + mol_H6O2B2Ne2I1Os1Tl1: Molecule, + mol_H2O2B2I1Os1: Molecule, +) -> None: + """ + Test the detection of fragments in the molecule H2O2B2I1Os1. + """ + fragmols = detect_fragments(mol_H6O2B2Ne2I1Os1Tl1, verbosity=0) + assert len(fragmols) == 7 + + # check that the first fragment is equal to the molecule H2O2B2I1Os1 + assert np.allclose(fragmols[0].xyz, mol_H2O2B2I1Os1.xyz)