Skip to content

Commit

Permalink
Merge pull request #183 from umesh-timalsina/175-core-equality-checks
Browse files Browse the repository at this point in the history
Rework equality checks and representation of core types
  • Loading branch information
mattwthompson authored Nov 19, 2019
2 parents 9977066 + 9f900c0 commit 7a4083b
Show file tree
Hide file tree
Showing 30 changed files with 477 additions and 357 deletions.
35 changes: 35 additions & 0 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
trigger:
- 175-core-equality-checks
pr:
- master

jobs:
- job: TestsForTopology
pool:
vmImage: 'ubuntu-latest'
steps:
- bash: echo "##vso[task.prependpath]$CONDA/bin"
displayName: Add conta to path

- bash: |
conda config --set always_yes yes --set changeps1 no
conda config --add channels omnia
conda config --add channels janschulz
conda config --add channels conda-forge
conda config --add channels mosdef
displayName: Add relavent channels
- bash: |
conda create -n test-environment
source activate test-environment
conda install --yes --file requirements.txt
conda install pytest
pip install -e .
displayName: Install requirements, Install branch
- bash: |
source activate test-environment
python -m pytest -v topology
displayName: Run Tests
16 changes: 0 additions & 16 deletions topology/core/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,6 @@ def __init__(self, connection_members=[], connection_type=None, name="Angle"):
super(Angle, self).__init__(connection_members=connection_members,
connection_type=connection_type, name=name)

def __eq__(self, other):
return hash(self) == hash(other)

def __hash__(self):
if self.connection_type:
return hash(
tuple(
(
self.name,
self.connection_type,
tuple(self.connection_members),
)
)
)
return hash(tuple(self.connection_members))


def _validate_three_partners(connection_members):
"""Ensure 3 partners are involved in Angle"""
Expand Down
30 changes: 20 additions & 10 deletions topology/core/angle_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from topology.core.potential import Potential
from topology.exceptions import TopologyError
from topology.utils.decorators import confirm_dict_existence
from topology.utils._constants import ANGLE_TYPE_DICT


class AngleType(Potential):
Expand All @@ -18,6 +20,8 @@ class AngleType(Potential):
independent vars : set of str
See `Potential` documentation for more information
member_types : list of topology.AtomType.name (str)
topology: topology.core.Topology, the topology of which this angle_type is a part of, default=None
set_ref: (str), the string name of the bookkeeping set in topology class.
Notes
----
Expand All @@ -30,45 +34,51 @@ def __init__(self,
expression='0.5 * k * (theta-theta_eq)**2',
parameters=None,
independent_variables=None,
member_types=None):
member_types=None,
topology=None):
if parameters is None:
parameters = {
'k': 1000 * u.Unit('kJ / (deg**2)'),
'theta_eq': 180 * u.deg
}
'k': 1000 * u.Unit('kJ / (deg**2)'),
'theta_eq': 180 * u.deg
}
if independent_variables is None:
independent_variables = {'theta'}

if member_types is None:
member_types = list()

super(AngleType, self).__init__(name=name, expression=expression,
parameters=parameters, independent_variables=independent_variables)

parameters=parameters, independent_variables=independent_variables,
topology=topology)
self._member_types = _validate_three_member_type_names(member_types)
self._set_ref = ANGLE_TYPE_DICT

@property
def set_ref(self):
return self._set_ref

@property
def member_types(self):
return self._member_types

@member_types.setter
@confirm_dict_existence
def member_types(self, val):
if self.member_types != val:
warnings.warn("Changing an AngleType's constituent "
"member types: {} to {}".format(self.member_types, val))
"member types: {} to {}".format(self.member_types, val))
self._member_types = _validate_three_member_type_names(val)

def __repr__(self):
return "<AngleType {}, id {}>".format(self.name, id(self))


def _validate_three_member_type_names(types):
"""Ensure 3 partners are involved in AngleType"""
if len(types) != 3 and len(types) != 0:
raise TopologyError("Trying to create an AngleType "
"with {} constituent types". format(len(types)))
"with {} constituent types".format(len(types)))
if not all([isinstance(t, str) for t in types]):
raise TopologyError("Types passed to AngleType "
"need to be strings corresponding to AtomType names")

return types

26 changes: 22 additions & 4 deletions topology/core/atom_type.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import warnings
import unyt as u

from topology.utils.testing import allclose
from topology.core.potential import Potential
from topology.utils.misc import unyt_to_hashable
from topology.utils.decorators import confirm_dict_existence
from topology.utils._constants import ATOM_TYPE_DICT


class AtomType(Potential):
Expand Down Expand Up @@ -45,6 +46,8 @@ class AtomType(Potential):
Set of other atom types that this atom type overrides
definition : str
SMARTS string defining this atom type
topology: topology.core.Topology, the topology of which this atom_type is a part of, default=None
set_ref: (str), the string name of the bookkeeping set in topology class.
"""

Expand All @@ -56,7 +59,8 @@ def __init__(self,
parameters=None,
independent_variables=None,
atomclass='', doi='', overrides=None, definition='',
description=''):
description='',
topology=None):
if parameters is None:
parameters = {'sigma': 0.3 * u.nm,
'epsilon': 0.3 * u.Unit('kJ')}
Expand All @@ -70,22 +74,28 @@ def __init__(self,
name=name,
expression=expression,
parameters=parameters,
independent_variables=independent_variables)
independent_variables=independent_variables,
topology=topology)
self._mass = _validate_mass(mass)
self._charge = _validate_charge(charge)
self._atomclass = _validate_str(atomclass)
self._doi = _validate_str(doi)
self._overrides = _validate_set(overrides)
self._description = _validate_str(description)
self._definition = _validate_str(definition)

self._set_ref = ATOM_TYPE_DICT
self._validate_expression_parameters()

@property
def set_ref(self):
return self._set_ref

@property
def charge(self):
return self._charge

@charge.setter
@confirm_dict_existence
def charge(self, val):
self._charge = _validate_charge(val)

Expand All @@ -94,6 +104,7 @@ def mass(self):
return self._mass

@mass.setter
@confirm_dict_existence
def mass(self, val):
self._mass = _validate_mass(val)

Expand All @@ -102,6 +113,7 @@ def atomclass(self):
return self._atomclass

@atomclass.setter
@confirm_dict_existence
def atomclass(self, val):
self._atomclass = val

Expand All @@ -110,6 +122,7 @@ def doi(self):
return self._doi

@doi.setter
@confirm_dict_existence
def doi(self, doi):
self._doi = _validate_str(doi)

Expand All @@ -118,6 +131,7 @@ def overrides(self):
return self._overrides

@overrides.setter
@confirm_dict_existence
def overrides(self, overrides):
self._overrides = _validate_set(overrides)

Expand All @@ -126,6 +140,7 @@ def description(self):
return self._description

@description.setter
@confirm_dict_existence
def description(self, description):
self._description = _validate_str(description)

Expand All @@ -134,6 +149,7 @@ def definition(self):
return self._definition

@definition.setter
@confirm_dict_existence
def definition(self, definition):
self._definition = _validate_str(definition)

Expand Down Expand Up @@ -185,11 +201,13 @@ def _validate_mass(mass):

return mass


def _validate_str(val):
if not isinstance(val, str):
raise ValueError("Passed value {} is not a string".format(val))
return val


def _validate_set(val):
if not isinstance(val, set):
raise ValueError("Passed value {} is not a set".format(val))
Expand Down
18 changes: 1 addition & 17 deletions topology/core/bond.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,19 @@ class Bond(Connection):

def __init__(self, connection_members=None, connection_type=None, name="Bond"):
if connection_members is None:
connection_members = list()
connection_members = tuple()
connection_members = _validate_two_partners(connection_members)
connection_type = _validate_bondtype(connection_type)

super(Bond, self).__init__(connection_members=connection_members,
connection_type=connection_type, name=name)

def __eq__(self, other):
return hash(self) == hash(other)

def __hash__(self):
if self.connection_type:
return hash(
tuple(
(
self.name,
self.connection_type,
tuple(self.connection_members),
)
)
)
return hash(tuple(self.connection_members))

def _validate_two_partners(connection_members):
"""Ensure 2 partners are involved in Bond"""
if len(connection_members) != 2:
raise TopologyError("Trying to create a Bond "
"with {} connection members". format(len(connection_members)))

return connection_members


Expand Down
29 changes: 20 additions & 9 deletions topology/core/bond_type.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import unyt as u
import warnings

from topology.core.potential import Potential
from topology.utils.decorators import confirm_dict_existence
from topology.exceptions import TopologyError
from topology.utils._constants import BOND_TYPE_DICT


class BondType(Potential):
Expand Down Expand Up @@ -29,45 +32,53 @@ def __init__(self,
expression='0.5 * k * (r-r_eq)**2',
parameters=None,
independent_variables=None,
member_types=None):
member_types=None,
topology=None,
set_ref='bond_type_set'):
if parameters is None:
parameters = {
'k': 1000 * u.Unit('kJ / (nm**2)'),
'r_eq': 0.14 * u.nm
}
'k': 1000 * u.Unit('kJ / (nm**2)'),
'r_eq': 0.14 * u.nm
}
if independent_variables is None:
independent_variables = {'r'}

if member_types is None:
member_types = list()

super(BondType, self).__init__(name=name, expression=expression,
parameters=parameters, independent_variables=independent_variables)

parameters=parameters, independent_variables=independent_variables,
topology=topology)
self._set_ref = BOND_TYPE_DICT
self._member_types = _validate_two_member_type_names(member_types)

@property
def set_ref(self):
return self._set_ref

@property
def member_types(self):
return self._member_types

@member_types.setter
@confirm_dict_existence
def member_types(self, val):
if self.member_types != val:
warnings.warn("Changing a BondType's constituent "
"member types: {} to {}".format(self.member_types, val))
"member types: {} to {}".format(self.member_types, val))
self._member_types = _validate_two_member_type_names(val)

def __repr__(self):
return "<BondType {}, id {}>".format(self.name, id(self))


def _validate_two_member_type_names(types):
"""Ensure 2 partners are involved in BondType"""
if len(types) != 2 and len(types) != 0:
raise TopologyError("Trying to create a BondType "
"with {} constituent types". format(len(types)))
"with {} constituent types".format(len(types)))
if not all([isinstance(t, str) for t in types]):
raise TopologyError("Types passed to BondType "
"need to be strings corresponding to AtomType names")

return types

Loading

0 comments on commit 7a4083b

Please sign in to comment.