Skip to content

Commit

Permalink
add HybridDriver (deepmodeling#292)
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz authored May 23, 2022
1 parent 31292fd commit c0bb798
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
63 changes: 62 additions & 1 deletion dpdata/driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Driver plugin system."""
from typing import Callable
from typing import Callable, List, Union
from .plugin import Plugin
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -78,3 +78,64 @@ def label(self, data: dict) -> dict:
labeled data with energies and forces
"""
return NotImplemented


@Driver.register("hybrid")
class HybridDriver(Driver):
"""Hybrid driver, with mixed drivers.
Parameters
----------
drivers : list[dict, Driver]
list of drivers or drivers dict. For a dict, it should
contain `type` as the name of the driver, and others
are arguments of the driver.
Raises
------
TypeError
The value of `drivers` is not a dict or `Driver`.
Examples
--------
>>> driver = HybridDriver([
... {"type": "sqm", "qm_theory": "DFTB3"},
... {"type": "dp", "dp": "frozen_model.pb"},
... ])
This driver is the hybrid of SQM and DP.
"""
def __init__(self, drivers: List[Union[dict, Driver]]) -> None:
self.drivers = []
for driver in drivers:
if isinstance(driver, Driver):
self.drivers.append(driver)
elif isinstance(driver, dict):
type = driver["type"]
del driver["type"]
self.drivers.append(Driver.get_driver(type)(**driver))
else:
raise TypeError("driver should be Driver or dict")

def label(self, data: dict) -> dict:
"""Label a system data.
Energies and forces are the sum of those of each driver.
Parameters
----------
data : dict
data with coordinates and atom types
Returns
-------
dict
labeled data with energies and forces
"""
for ii, driver in enumerate(self.drivers):
lb_data = driver.label(data.copy())
if ii == 0:
labeled_data = lb_data.copy()
else:
labeled_data['energies'] += lb_data ['energies']
labeled_data['forces'] += lb_data ['forces']
return labeled_data
37 changes: 37 additions & 0 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ def label(self, data):
return data


@dpdata.driver.Driver.register("one")
class ZeroDriver(dpdata.driver.Driver):
def label(self, data):
nframes = data['coords'].shape[0]
natoms = data['coords'].shape[1]
data['energies'] = np.ones((nframes,))
data['forces'] = np.ones((nframes, natoms, 3))
data['virials'] = np.ones((nframes, 3, 3))
return data


class TestPredict(unittest.TestCase, CompLabeledSys):
def setUp (self) :
ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
Expand All @@ -32,3 +43,29 @@ def setUp (self) :
self.e_places = 6
self.f_places = 6
self.v_places = 6


class TestHybridDriver(unittest.TestCase, CompLabeledSys):
"""Test HybridDriver."""
def setUp(self) :
ori_sys = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
fmt = 'deepmd/raw',
type_map = ['O', 'H'])
self.system_1 = ori_sys.predict([
{"type": "one"},
{"type": "one"},
{"type": "one"},
{"type": "zero"},
],
driver="hybrid")
# sum is 3
self.system_2 = dpdata.LabeledSystem('poscars/deepmd.h2o.md',
fmt = 'deepmd/raw',
type_map = ['O', 'H'])
for pp in ('energies', 'forces'):
self.system_2.data[pp][:] = 3.

self.places = 6
self.e_places = 6
self.f_places = 6
self.v_places = 6

0 comments on commit c0bb798

Please sign in to comment.