From 44d3fd1f118fc229a7ea8caae0e6bce9df218796 Mon Sep 17 00:00:00 2001 From: Xinzijian Liu Date: Wed, 21 Aug 2024 11:28:54 +0800 Subject: [PATCH] add distance conf filter (#250) ## Summary by CodeRabbit - **New Features** - Introduced a `filters` argument for configuration customization in existing functions. - Implemented new filtering classes for validating atomic configurations based on distance and geometric criteria, enhancing configuration selection options. - **Tests** - Added unit tests for the new filtering classes to ensure robust functionality and validation of atomic configurations. --------- Signed-off-by: zjgemi Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- dpgen2/entrypoint/args.py | 42 +++ dpgen2/entrypoint/submit.py | 20 ++ .../exploration/render/traj_render_lammps.py | 8 +- dpgen2/exploration/selector/__init__.py | 11 + dpgen2/exploration/selector/conf_filter.py | 27 +- .../selector/distance_conf_filter.py | 329 ++++++++++++++++++ tests/exploration/test_conf_filter.py | 7 +- .../exploration/test_distance_conf_filter.py | 174 +++++++++ 8 files changed, 589 insertions(+), 29 deletions(-) create mode 100644 dpgen2/exploration/selector/distance_conf_filter.py create mode 100644 tests/exploration/test_distance_conf_filter.py diff --git a/dpgen2/entrypoint/args.py b/dpgen2/entrypoint/args.py index e64beeb7..7acb4c33 100644 --- a/dpgen2/entrypoint/args.py +++ b/dpgen2/entrypoint/args.py @@ -19,6 +19,9 @@ from dpgen2.exploration.report import ( conv_styles, ) +from dpgen2.exploration.selector import ( + conf_filter_styles, +) from dpgen2.fp import ( fp_styles, ) @@ -174,6 +177,25 @@ def variant_conf(): ) +def variant_filter(): + doc = "the type of the configuration filter." + var_list = [] + for kk in conf_filter_styles.keys(): + var_list.append( + Argument( + kk, + dict, + conf_filter_styles[kk].args(), + doc="Configuration filter of type %s" % kk, + ) + ) + return Variant( + "type", + var_list, + doc=doc, + ) + + def lmp_args(): doc_config = "Configuration of lmp exploration" doc_max_numb_iter = "Maximum number of iterations per stage" @@ -189,6 +211,7 @@ def lmp_args(): "Then each stage is defined by a list of exploration task groups. " "Each task group is described in :ref:`the task group definition` " ) + doc_filters = "A list of configuration filters" return [ Argument( @@ -227,6 +250,15 @@ def lmp_args(): alias=["configuration"], ), Argument("stages", List[List[dict]], optional=False, doc=doc_stages), + Argument( + "filters", + list, + [], + [variant_filter()], + optional=True, + default=[], + doc=doc_filters, + ), ] @@ -272,6 +304,7 @@ def caly_args(): "Then each stage is defined by a list of exploration task groups. " "Each task group is described in :ref:`the task group definition` " ) + doc_filters = "A list of configuration filters" return [ Argument( @@ -310,6 +343,15 @@ def caly_args(): alias=["configuration"], ), Argument("stages", List[List[dict]], optional=False, doc=doc_stages), + Argument( + "filters", + list, + [], + [variant_filter()], + optional=True, + default=[], + doc=doc_filters, + ), ] diff --git a/dpgen2/entrypoint/submit.py b/dpgen2/entrypoint/submit.py index 5526f501..7e07659a 100644 --- a/dpgen2/entrypoint/submit.py +++ b/dpgen2/entrypoint/submit.py @@ -5,6 +5,9 @@ import os import pickle import re +from copy import ( + deepcopy, +) from pathlib import ( Path, ) @@ -70,7 +73,9 @@ ExplorationScheduler, ) from dpgen2.exploration.selector import ( + ConfFilters, ConfSelectorFrames, + conf_filter_styles, ) from dpgen2.exploration.task import ( CustomizedLmpTemplateTaskGroup, @@ -272,6 +277,17 @@ def make_naive_exploration_scheduler( ) +def get_conf_filters(config): + conf_filters = None + if len(config) > 0: + conf_filters = ConfFilters() + for c in config: + c = deepcopy(c) + conf_filter = conf_filter_styles[c.pop("type")](**c) + conf_filters.add(conf_filter) + return conf_filters + + def make_calypso_naive_exploration_scheduler(config): model_devi_jobs = config["explore"]["stages"] fp_task_max = config["fp"]["task_max"] @@ -279,6 +295,7 @@ def make_calypso_naive_exploration_scheduler(config): fatal_at_max = config["explore"]["fatal_at_max"] convergence = config["explore"]["convergence"] output_nopbc = config["explore"]["output_nopbc"] + conf_filters = get_conf_filters(config["explore"]["filters"]) scheduler = ExplorationScheduler() # report conv_style = convergence.pop("type") @@ -289,6 +306,7 @@ def make_calypso_naive_exploration_scheduler(config): render, report, fp_task_max, + conf_filters, ) for job_ in model_devi_jobs: @@ -329,6 +347,7 @@ def make_lmp_naive_exploration_scheduler(config): fatal_at_max = config["explore"]["fatal_at_max"] convergence = config["explore"]["convergence"] output_nopbc = config["explore"]["output_nopbc"] + conf_filters = get_conf_filters(config["explore"]["filters"]) scheduler = ExplorationScheduler() # report conv_style = convergence.pop("type") @@ -339,6 +358,7 @@ def make_lmp_naive_exploration_scheduler(config): render, report, fp_task_max, + conf_filters, ) sys_configs_lmp = [] diff --git a/dpgen2/exploration/render/traj_render_lammps.py b/dpgen2/exploration/render/traj_render_lammps.py index 5d4cb56b..89a0f7da 100644 --- a/dpgen2/exploration/render/traj_render_lammps.py +++ b/dpgen2/exploration/render/traj_render_lammps.py @@ -64,7 +64,6 @@ def get_confs( type_map: Optional[List[str]] = None, conf_filters: Optional["ConfFilters"] = None, ) -> dpdata.MultiSystems: - del conf_filters # by far does not support conf filters ntraj = len(trajs) traj_fmt = "lammps/dump" ms = dpdata.MultiSystems(type_map=type_map) @@ -74,4 +73,11 @@ def get_confs( ss.nopbc = self.nopbc ss = ss.sub_system(id_selected[ii]) ms.append(ss) + if conf_filters is not None: + ms2 = dpdata.MultiSystems(type_map=type_map) + for s in ms: + s2 = conf_filters.check(s) + if len(s2) > 0: + ms2.append(s2) + ms = ms2 return ms diff --git a/dpgen2/exploration/selector/__init__.py b/dpgen2/exploration/selector/__init__.py index 4ee996fa..cfed094f 100644 --- a/dpgen2/exploration/selector/__init__.py +++ b/dpgen2/exploration/selector/__init__.py @@ -8,3 +8,14 @@ from .conf_selector_frame import ( ConfSelectorFrames, ) +from .distance_conf_filter import ( + BoxLengthFilter, + BoxSkewnessConfFilter, + DistanceConfFilter, +) + +conf_filter_styles = { + "distance": DistanceConfFilter, + "box_skewness": BoxSkewnessConfFilter, + "box_length": BoxLengthFilter, +} diff --git a/dpgen2/exploration/selector/conf_filter.py b/dpgen2/exploration/selector/conf_filter.py index d0a7b9fb..f9fa170e 100644 --- a/dpgen2/exploration/selector/conf_filter.py +++ b/dpgen2/exploration/selector/conf_filter.py @@ -15,23 +15,14 @@ class ConfFilter(ABC): @abstractmethod def check( self, - coords: np.ndarray, - cell: np.ndarray, - atom_types: np.ndarray, - nopbc: bool, + frame: dpdata.System, ) -> bool: """Check if the configuration is valid. Parameters ---------- - coords : numpy.array - The coordinates, numpy array of shape natoms x 3 - cell : numpy.array - The cell tensor. numpy array of shape 3 x 3 - atom_types : numpy.array - The atom types. numpy array of shape natoms - nopbc : bool - If no periodic boundary condition. + frame : dpdata.System + A dpdata.System containing a single frame Returns ------- @@ -62,16 +53,6 @@ def check( natoms = sum(conf["atom_numbs"]) # type: ignore selected_idx = np.arange(conf.get_nframes()) for ff in self._filters: - fsel = np.where( - [ - ff.check( - conf["coords"][ii], - conf["cells"][ii], - conf["atom_types"], - conf.nopbc, - ) - for ii in range(conf.get_nframes()) - ] - )[0] + fsel = np.where([ff.check(conf[ii]) for ii in range(conf.get_nframes())])[0] selected_idx = np.intersect1d(selected_idx, fsel) return conf.sub_system(selected_idx) diff --git a/dpgen2/exploration/selector/distance_conf_filter.py b/dpgen2/exploration/selector/distance_conf_filter.py new file mode 100644 index 00000000..4d5a8c33 --- /dev/null +++ b/dpgen2/exploration/selector/distance_conf_filter.py @@ -0,0 +1,329 @@ +import logging +from copy import ( + deepcopy, +) +from typing import ( + List, +) + +import dargs +import dpdata +import numpy as np +from dargs import ( + Argument, +) + +from . import ( + ConfFilter, +) + +safe_dist_dict = { + "H": 1.2255, + "He": 0.936, + "Li": 1.8, + "Be": 1.56, + "B": 1.32, + "C": 1.32, + "N": 1.32, + "O": 1.32, + "F": 1.26, + "Ne": 1.92, + "Na": 1.595, + "Mg": 1.87, + "Al": 1.87, + "Si": 1.76, + "P": 1.65, + "S": 1.65, + "Cl": 1.65, + "Ar": 2.09, + "K": 2.3, + "Ca": 2.3, + "Sc": 2.0, + "Ti": 2.0, + "V": 2.0, + "Cr": 1.9, + "Mn": 1.95, + "Fe": 1.9, + "Co": 1.9, + "Ni": 1.9, + "Cu": 1.9, + "Zn": 1.9, + "Ga": 2.0, + "Ge": 2.0, + "As": 2.0, + "Se": 2.1, + "Br": 2.1, + "Kr": 2.3, + "Rb": 2.5, + "Sr": 2.5, + "Y": 2.1, + "Zr": 2.1, + "Nb": 2.1, + "Mo": 2.1, + "Tc": 2.1, + "Ru": 2.1, + "Rh": 2.1, + "Pd": 2.1, + "Ag": 2.1, + "Cd": 2.1, + "In": 2.0, + "Sn": 2.0, + "Sb": 2.0, + "Te": 2.0, + "I": 2.0, + "Xe": 2.0, + "Cs": 2.5, + "Ba": 2.8, + "La": 2.5, + "Ce": 2.55, + "Pr": 2.7, + "Nd": 2.8, + "Pm": 2.8, + "Sm": 2.8, + "Eu": 2.8, + "Gd": 2.8, + "Tb": 2.8, + "Dy": 2.8, + "Ho": 2.8, + "Er": 2.6, + "Tm": 2.8, + "Yb": 2.8, + "Lu": 2.8, + "Hf": 2.4, + "Ta": 2.5, + "W": 2.3, + "Re": 2.3, + "Os": 2.3, + "Ir": 2.3, + "Pt": 2.3, + "Au": 2.3, + "Hg": 2.3, + "Tl": 2.3, + "Pb": 2.3, + "Bi": 2.3, + "Po": 2.3, + "At": 2.3, + "Rn": 2.3, + "Fr": 2.9, + "Ra": 2.9, + "Ac": 2.9, + "Th": 2.8, + "Pa": 2.8, + "U": 2.8, + "Np": 2.8, + "Pu": 2.8, + "Am": 2.8, + "Cm": 2.8, + "Cf": 2.3, +} + + +def check_multiples(a, b, c, multiple): + values = [a, b, c] + + for i in range(len(values)): + for j in range(len(values)): + if i != j: + if values[i] > multiple * values[j]: + logging.warning( + f"Value {values[i]} is {multiple} times greater than {values[j]}" + ) + return True + return False + + +class DistanceConfFilter(ConfFilter): + def __init__(self, custom_safe_dist=None, safe_dist_ratio=1.0): + self.custom_safe_dist = custom_safe_dist if custom_safe_dist is not None else {} + self.safe_dist_ratio = safe_dist_ratio + + def check( + self, + frame: dpdata.System, + ): + from ase import ( + Atoms, + ) + from ase.build import ( + make_supercell, + ) + + safe_dist = deepcopy(safe_dist_dict) + safe_dist.update(self.custom_safe_dist) + for k in safe_dist: + # bohr -> ang and multiply by a relaxation ratio + safe_dist[k] *= 0.529 / 1.2 * self.safe_dist_ratio + + atom_names = list(safe_dist) + structure = Atoms( + positions=frame["coords"][0], + numbers=[ + atom_names.index(frame["atom_names"][t]) + 1 + for t in frame["atom_types"] + ], + cell=frame["cells"][0], + pbc=(not frame.nopbc), + ) + + P = [[2, 0, 0], [0, 2, 0], [0, 0, 2]] + extended_structure = make_supercell(structure, P) + + coords = extended_structure.positions + symbols = extended_structure.get_chemical_symbols() + + num_atoms = len(coords) + for i in range(num_atoms): + for j in range(i + 1, num_atoms): + dist = extended_structure.get_distance(i, j, mic=True) + type_i = symbols[i] + type_j = symbols[j] + dr = safe_dist[type_i] + safe_dist[type_j] + + if dist < dr: + logging.warning( + f"Dangerous close for {type_i} - {type_j}, {dist:.5f} less than {dr:.5f}" + ) + return False + + return True + + @staticmethod + def args() -> List[dargs.Argument]: + r"""The argument definition of the `ConfFilter`. + + Returns + ------- + arguments: List[dargs.Argument] + List of dargs.Argument defines the arguments of the `ConfFilter`. + """ + + doc_custom_safe_dist = "Custom safe distance (in unit of bohr) for each element" + doc_safe_dist_ratio = "The ratio multiplied to the safe distance" + return [ + Argument( + "custom_safe_dist", + dict, + optional=True, + default={}, + doc=doc_custom_safe_dist, + ), + Argument( + "safe_dist_ratio", + float, + optional=True, + default=1.0, + doc=doc_safe_dist_ratio, + ), + ] + + +class BoxSkewnessConfFilter(ConfFilter): + def __init__(self, theta=60.0): + self.theta = theta + + def check( + self, + frame: dpdata.System, + ): + from ase import ( + Atoms, + ) + + atom_names = list(safe_dist_dict) + structure = Atoms( + positions=frame["coords"][0], + numbers=[ + atom_names.index(frame["atom_names"][t]) + 1 + for t in frame["atom_types"] + ], + cell=frame["cells"][0], + pbc=(not frame.nopbc), + ) + + cell, _ = structure.get_cell().standard_form() + + if ( + cell[1][0] > np.tan(self.theta / 180.0 * np.pi) * cell[1][1] # type: ignore + or cell[2][0] > np.tan(self.theta / 180.0 * np.pi) * cell[2][2] # type: ignore + or cell[2][1] > np.tan(self.theta / 180.0 * np.pi) * cell[2][2] # type: ignore + ): + logging.warning("Inclined box") + return False + return True + + @staticmethod + def args() -> List[dargs.Argument]: + r"""The argument definition of the `ConfFilter`. + + Returns + ------- + arguments: List[dargs.Argument] + List of dargs.Argument defines the arguments of the `ConfFilter`. + """ + + doc_theta = "The threshold for angles between the edges of the cell. If all angles are larger than this value the check is passed" + return [ + Argument( + "theta", + float, + optional=True, + default=60.0, + doc=doc_theta, + ), + ] + + +class BoxLengthFilter(ConfFilter): + def __init__(self, length_ratio=5.0): + self.length_ratio = length_ratio + + def check( + self, + frame: dpdata.System, + ): + from ase import ( + Atoms, + ) + + atom_names = list(safe_dist_dict) + structure = Atoms( + positions=frame["coords"][0], + numbers=[ + atom_names.index(frame["atom_names"][t]) + 1 + for t in frame["atom_types"] + ], + cell=frame["cells"][0], + pbc=(not frame.nopbc), + ) + + cell, _ = structure.get_cell().standard_form() + + a = cell[0][0] # type: ignore + b = cell[1][1] # type: ignore + c = cell[2][2] # type: ignore + + if check_multiples(a, b, c, self.length_ratio): + logging.warning("One side is %s larger than another" % self.length_ratio) + return False + return True + + @staticmethod + def args() -> List[dargs.Argument]: + r"""The argument definition of the `ConfFilter`. + + Returns + ------- + arguments: List[dargs.Argument] + List of dargs.Argument defines the arguments of the `ConfFilter`. + """ + + doc_length_ratio = "The threshold for the length ratio between the edges of the cell. If all length ratios are smaller than this value the check is passed" + return [ + Argument( + "length_ratio", + float, + optional=True, + default=5.0, + doc=doc_length_ratio, + ), + ] diff --git a/tests/exploration/test_conf_filter.py b/tests/exploration/test_conf_filter.py index 89d096d0..0022e63b 100644 --- a/tests/exploration/test_conf_filter.py +++ b/tests/exploration/test_conf_filter.py @@ -25,10 +25,7 @@ class FooFilter(ConfFilter): def check( self, - coords: np.array, - cell: np.array, - atom_types: np.array, - nopbc: bool, + frame: dpdata.System, ) -> bool: return True @@ -38,7 +35,7 @@ class faked_filter: myret = [True] @classmethod - def faked_check(cls, cc, ce, at, np): + def faked_check(cls, frame): cls.myiter += 1 cls.myiter = cls.myiter % len(cls.myret) return cls.myret[cls.myiter] diff --git a/tests/exploration/test_distance_conf_filter.py b/tests/exploration/test_distance_conf_filter.py new file mode 100644 index 00000000..98b7ba4e --- /dev/null +++ b/tests/exploration/test_distance_conf_filter.py @@ -0,0 +1,174 @@ +import os +import unittest + +import dpdata +import numpy as np + +from dpgen2.exploration.selector import ( + BoxLengthFilter, + BoxSkewnessConfFilter, + DistanceConfFilter, +) + +from .context import ( + dpgen2, +) + +POSCAR_valid = """ Er +1.0 + 7.00390434172054 0.000000000000000E+000 0.000000000000000E+000 + -3.50195193887670 6.06555921954188 0.000000000000000E+000 + 4.695904554609645E-007 8.133544916878595E-007 6.21991417332993 + Er Th Pd + 3 3 5 +Direct + 0.404315576593774 0.000000000000000E+000 0.916328931175151 + 0.000000000000000E+000 0.404315576593774 0.916328931175151 + 0.595684423406226 0.595684423406226 0.916328931175151 + 0.308657693786501 0.000000000000000E+000 0.431543321200265 + 0.000000000000000E+000 0.308657693786501 0.431543321200265 + 0.691342306213499 0.691342306213499 0.431543321200265 + 0.333299994468689 0.666700005531311 0.181639126706554 + 0.666700005531311 0.333299994468689 0.181639126706554 + 0.333299994468689 0.666700005531311 0.653715146968972 + 0.666700005531311 0.333299994468689 0.653715146968972 + 0.000000000000000E+000 0.000000000000000E+000 0.767627989523288 +""" + + +POSCAR_tilt = """POSCAR file written by OVITO Basic 3.10.6 +1 +9.9156076829 0.0 0.0 +7.9377192882e-07 10.5138279814 0.0 +11.9805 1.9108776119e-08 6.3054328214 +Re B Au +3 2 4 +Direct +0.4552610161 0.1437969637 0.9105503417 +0.4552610161 0.8562030363 0.9105503417 +0.9406309225 4.4367368569e-14 0.8634531812 +0.1522944187 0.5 0.9114412858 +0.5212391304 -3.0871058636e-12 0.1011293077 +0.8742903123 0.6505092137 0.9508321387 +0.8742903123 0.3494907863 0.9508321387 +0.5509009668 0.6521996574 0.004767873 +0.5509009668 0.3478003426 0.004767873 +""" + + +POSCAR_long = """POSCAR file written by OVITO Basic 3.10.6 +1 +11.8987292195 0.0 0.0 +9.5252631458e-07 12.6165935777 0.0 +0.0 2.2930531343e-08 2.0177385028 +Re B Au +3 2 4 +Direct +0.4552610161 0.1437969637 0.9105503417 +0.4552610161 0.8562030363 0.9105503417 +0.9406309225 4.4367368568e-14 0.8634531812 +0.1522944187 0.5 0.9114412858 +0.5212391304 -3.0871058636e-12 0.1011293077 +0.8742903123 0.6505092137 0.9508321387 +0.8742903123 0.3494907863 0.9508321387 +0.5509009668 0.6521996574 0.004767873 +0.5509009668 0.3478003426 0.004767873 +""" + + +POSCAR_close = """POSCAR file written by OVITO Basic 3.10.6 +1 +9.9156076829 0.0 0.0 +7.9377192882e-07 10.5138279814 0.0 +3.0151272179 1.9108776119e-08 6.3054328214 +Re B Au +3 2 4 +Direct +0.4552610161 0.1437969637 0.9105503417 +0.5078872031 0.9988722905 1.069143737 +0.9406309225 4.4367368569e-14 0.8634531812 +0.1522944187 0.5 0.9114412858 +0.5212391304 -3.0871058636e-12 0.1011293077 +0.8742903123 0.6505092137 0.9508321387 +0.8742903123 0.3494907863 0.9508321387 +0.5509009668 0.6521996574 0.004767873 +0.5509009668 0.3478003426 0.004767873 +""" + + +class TestBoxSkewnessConfFilter(unittest.TestCase): + def setUp(self): + with open("POSCAR_valid", "w") as f: + f.write(POSCAR_valid) + with open("POSCAR_tilt", "w") as f: + f.write(POSCAR_tilt) + + def test_valid(self): + system = dpdata.System("POSCAR_valid", fmt="poscar") + distance_conf_filter = BoxSkewnessConfFilter() + valid = distance_conf_filter.check(system) + self.assertTrue(valid) + + def test_invalid(self): + system = dpdata.System("POSCAR_tilt", fmt="poscar") + distance_conf_filter = BoxSkewnessConfFilter() + valid = distance_conf_filter.check(system) + self.assertFalse(valid) + + def tearDown(self): + if os.path.isfile("POSCAR_valid"): + os.remove("POSCAR_valid") + if os.path.isfile("POSCAR_tilt"): + os.remove("POSCAR_tilt") + + +class TestBoxLengthConfFilter(unittest.TestCase): + def setUp(self): + with open("POSCAR_valid", "w") as f: + f.write(POSCAR_valid) + with open("POSCAR_long", "w") as f: + f.write(POSCAR_long) + + def test_valid(self): + system = dpdata.System("POSCAR_valid", fmt="poscar") + distance_conf_filter = BoxLengthFilter() + valid = distance_conf_filter.check(system) + self.assertTrue(valid) + + def test_invalid(self): + system = dpdata.System("POSCAR_long", fmt="poscar") + distance_conf_filter = BoxLengthFilter() + valid = distance_conf_filter.check(system) + self.assertFalse(valid) + + def tearDown(self): + if os.path.isfile("POSCAR_valid"): + os.remove("POSCAR_valid") + if os.path.isfile("POSCAR_long"): + os.remove("POSCAR_long") + + +class TestDistanceConfFilter(unittest.TestCase): + def setUp(self): + with open("POSCAR_valid", "w") as f: + f.write(POSCAR_valid) + with open("POSCAR_close", "w") as f: + f.write(POSCAR_close) + + def test_valid(self): + system = dpdata.System("POSCAR_valid", fmt="poscar") + distance_conf_filter = DistanceConfFilter() + valid = distance_conf_filter.check(system) + self.assertTrue(valid) + + def test_invalid(self): + system = dpdata.System("POSCAR_close", fmt="poscar") + distance_conf_filter = DistanceConfFilter() + valid = distance_conf_filter.check(system) + self.assertFalse(valid) + + def tearDown(self): + if os.path.isfile("POSCAR_valid"): + os.remove("POSCAR_valid") + if os.path.isfile("POSCAR_close"): + os.remove("POSCAR_close")