diff --git a/dpdata/driver.py b/dpdata/driver.py index 84419cfe..29ae2703 100644 --- a/dpdata/driver.py +++ b/dpdata/driver.py @@ -1,5 +1,5 @@ """Driver plugin system.""" -from typing import Callable +from typing import Callable, List, Union from .plugin import Plugin from abc import ABC, abstractmethod @@ -78,3 +78,64 @@ def label(self, data: dict) -> dict: labeled data with energies and forces """ return NotImplemented + + +@Driver.register("hybrid") +class HybridDriver(Driver): + """Hybrid driver, with mixed drivers. + + Parameters + ---------- + drivers : list[dict, Driver] + list of drivers or drivers dict. For a dict, it should + contain `type` as the name of the driver, and others + are arguments of the driver. + + Raises + ------ + TypeError + The value of `drivers` is not a dict or `Driver`. + + Examples + -------- + >>> driver = HybridDriver([ + ... {"type": "sqm", "qm_theory": "DFTB3"}, + ... {"type": "dp", "dp": "frozen_model.pb"}, + ... ]) + This driver is the hybrid of SQM and DP. + """ + def __init__(self, drivers: List[Union[dict, Driver]]) -> None: + self.drivers = [] + for driver in drivers: + if isinstance(driver, Driver): + self.drivers.append(driver) + elif isinstance(driver, dict): + type = driver["type"] + del driver["type"] + self.drivers.append(Driver.get_driver(type)(**driver)) + else: + raise TypeError("driver should be Driver or dict") + + def label(self, data: dict) -> dict: + """Label a system data. + + Energies and forces are the sum of those of each driver. + + Parameters + ---------- + data : dict + data with coordinates and atom types + + Returns + ------- + dict + labeled data with energies and forces + """ + for ii, driver in enumerate(self.drivers): + lb_data = driver.label(data.copy()) + if ii == 0: + labeled_data = lb_data.copy() + else: + labeled_data['energies'] += lb_data ['energies'] + labeled_data['forces'] += lb_data ['forces'] + return labeled_data diff --git a/tests/test_predict.py b/tests/test_predict.py index 1cb816cc..c5fe5d76 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -16,6 +16,17 @@ def label(self, data): return data +@dpdata.driver.Driver.register("one") +class ZeroDriver(dpdata.driver.Driver): + def label(self, data): + nframes = data['coords'].shape[0] + natoms = data['coords'].shape[1] + data['energies'] = np.ones((nframes,)) + data['forces'] = np.ones((nframes, natoms, 3)) + data['virials'] = np.ones((nframes, 3, 3)) + return data + + class TestPredict(unittest.TestCase, CompLabeledSys): def setUp (self) : ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md', @@ -32,3 +43,29 @@ def setUp (self) : self.e_places = 6 self.f_places = 6 self.v_places = 6 + + +class TestHybridDriver(unittest.TestCase, CompLabeledSys): + """Test HybridDriver.""" + def setUp(self) : + ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md', + fmt = 'deepmd/raw', + type_map = ['O', 'H']) + self.system_1 = ori_sys.predict([ + {"type": "one"}, + {"type": "one"}, + {"type": "one"}, + {"type": "zero"}, + ], + driver="hybrid") + # sum is 3 + self.system_2 = dpdata.LabeledSystem('poscars/deepmd.h2o.md', + fmt = 'deepmd/raw', + type_map = ['O', 'H']) + for pp in ('energies', 'forces'): + self.system_2.data[pp][:] = 3. + + self.places = 6 + self.e_places = 6 + self.f_places = 6 + self.v_places = 6