Skip to content

Commit

Permalink
Merge pull request #14 from mattwthompson/override-eq
Browse files Browse the repository at this point in the history
Define some class equivalences
  • Loading branch information
mattwthompson authored Mar 24, 2019
2 parents 8d1b26d + 474fd92 commit 807f3d4
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 3 deletions.
7 changes: 7 additions & 0 deletions topology/core/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def __init__(self, connection_members=[], connection_type=None):
super(Angle, self).__init__(connection_members=connection_members,
connection_type=connection_type)

def __eq__(self, other):
if not self.connection_members == other.connection_members:
return False
if not self.connection_type == other.connection_type:
return False
return True


def _validate_three_partners(connection_members):
"""Ensure 3 partners are involved in Angle"""
Expand Down
6 changes: 6 additions & 0 deletions topology/core/bond.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def __init__(self, connection_members=[], connection_type=None):
super(Bond, self).__init__(connection_members=connection_members,
connection_type=connection_type)

def __eq__(self, other):
if not self.connection_members == other.connection_members:
return False
if not self.connection_type == other.connection_type:
return False
return True

def _validate_two_partners(connection_members):
"""Ensure 2 partners are involved in Bond"""
Expand Down
18 changes: 18 additions & 0 deletions topology/core/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import unyt as u
from topology.testing.utils import allclose


def _validate_lengths(lengths):
Expand Down Expand Up @@ -140,3 +141,20 @@ def get_unit_vectors(self):
def __repr__(self):
return "Box(a={}, b={}, c={}, alpha={}, beta={}, gamma={})"\
.format(*self._lengths, *self._angles)

def __eq__(self, other):
"""Compare two boxes for equivalence."""

if self is other:
return True

if not isinstance(other, Box):
return False

if not allclose(self.lengths, other.lengths):
return False

if not allclose(self.angles, other.angles):
return False

return True
19 changes: 17 additions & 2 deletions topology/core/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@
import unyt as u

from topology.core.atom_type import AtomType
from topology.testing.utils import allclose


class Site(object):
"""A general site."""

def __init__(self,
name,
name='Site',
position=None,
charge=None,
mass=None,
element=None,
atom_type=None):
self.name = str(name)
if name is not None:
self.name = str(name)
if position is None:
self.position = u.nm * np.zeros(3)
else:
Expand Down Expand Up @@ -83,6 +85,19 @@ def atom_type(self, val):
val = _validate_atom_type(val)
self._atom_type = val

def __eq__(self, other):
if not allclose(self.position, other.position):
return False
if not allclose(self.charge, other.charge, atol=1e-22):
return False
if self.atom_type != other.atom_type:
return False

return True

def __hash__(self):
return id(self)

def __repr__(self):
return "<Site {}, id {}>".format(self.name, id(self))

Expand Down
29 changes: 29 additions & 0 deletions topology/core/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from topology.core.bond_type import BondType
from topology.core.angle_type import AngleType
from topology.exceptions import TopologyError
from topology.testing.utils import allclose


class Topology(object):
Expand Down Expand Up @@ -240,3 +241,31 @@ def __repr__(self):
descr.append('id: {}>'.format(id(self)))

return ''.join(descr)

def __eq__(self, other):
"""Compare a topology for equivalence."""

if self is other:
return True

if not isinstance(other, Topology):
return False

if self.name != other.name:
return False

if self.n_sites != other.n_sites:
return False

for (con1, con2) in zip(self.connections, other.connections):
if con1 != con2:
return False

for (site1, site2) in zip(self.sites, other.sites):
if site1 != site2:
return False

if self.box != other.box:
return False

return True
3 changes: 3 additions & 0 deletions topology/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

def allclose(a, b, rtol=None, atol=None):
"""Compare two unyt arrays."""
if a is None and b is None:
return True

if a.units != b.units:
common_unit = _infer_common_unit(a, b)
else:
Expand Down
4 changes: 4 additions & 0 deletions topology/tests/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def mass(self):
def box(self):
return Box(lengths=u.nm*np.ones(3))

@pytest.fixture
def top(self):
return Topology(name='mytop')

@pytest.fixture
def topology_site(self):
def _topology(sites=1):
Expand Down
15 changes: 15 additions & 0 deletions topology/tests/test_box.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from copy import deepcopy

import pytest
import numpy as np
import unyt as u
Expand Down Expand Up @@ -81,3 +83,16 @@ def test_scaled_vectors(self):
test_vectors = (test_vectors.T * box.lengths).T
assert allclose(vectors, test_vectors, atol=u.nm*1e-3)
assert vectors.units == u.nm

def test_eq(self, box):
assert box == box

def test_eq_bad_lengths(self, box):
diff_lengths = deepcopy(box)
diff_lengths.lengths = u.nm * [5.0, 5.0, 5.0]
assert box != diff_lengths

def test_eq_bad_angles(self, box):
diff_angles = deepcopy(box)
diff_angles.angles = u.degree * [90, 90, 120]
assert box != diff_angles
76 changes: 75 additions & 1 deletion topology/tests/test_topology.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from copy import deepcopy

import numpy as np
import pytest
import unyt as u
import parmed as pmd

from topology import *
from topology.external.convert_parmed import from_parmed

from topology.tests.base_test import BaseTest
from topology.testing.utils import allclose
from topology.tests.base_test import BaseTest
from topology.utils.io import get_fn


class TestTopology(BaseTest):
Expand Down Expand Up @@ -55,6 +61,75 @@ def test_positions_dtype(self):
assert top.positions().units == u.nm
assert isinstance(top.positions(), u.unyt_array)

def test_eq_types(self, top, box):
assert top != box

diff_name = deepcopy(top)
diff_name.name = 'othertop'
assert top != diff_name

def test_eq_sites(self, top, charge):
ref = deepcopy(top)
wrong_n_sites = deepcopy(top)
assert top == wrong_n_sites
ref.add_site(Site())
assert ref != wrong_n_sites

ref = deepcopy(top)
wrong_position = deepcopy(top)
ref.add_site(Site(position=u.nm*[0, 0, 0]))
wrong_position.add_site(Site(position=u.nm*[1, 1, 1]))
assert top != wrong_position

ref = deepcopy(top)
wrong_charge = deepcopy(top)
ref.add_site(Site(charge=charge))
wrong_charge.add_site(Site(charge=-1*charge))
assert ref != wrong_charge

ref = deepcopy(top)
wrong_atom_type = deepcopy(top)
ref.add_site(Site(atom_type=AtomType(expression='epsilon*sigma')))
wrong_atom_type.add_site(Site(atom_type=AtomType(expression='sigma')))
assert ref != wrong_atom_type

def test_eq_bonds(self):
ref = pmd.load_file(get_fn('ethane.top'),
xyz=get_fn('ethane.gro'))

missing_bond = deepcopy(ref)
missing_bond.bonds[0].delete()

assert ref != missing_bond

bad_bond_type = deepcopy(ref)
bad_bond_type.bond_types[0].k = 22

assert ref != bad_bond_type

def test_eq_angles(self):
ref = pmd.load_file(get_fn('ethane.top'),
xyz=get_fn('ethane.gro'))

missing_angle = deepcopy(ref)
missing_angle.angles[0].delete()

assert ref != missing_angle

bad_angle_type = deepcopy(ref)
bad_angle_type.angle_types[0].k = 22

assert ref != bad_angle_type

def test_eq_overall(self):
ref = pmd.load_file(get_fn('ethane.top'),
xyz=get_fn('ethane.gro'))

top1 = from_parmed(ref)
top2 = from_parmed(ref)

assert top1 == top2

def test_top_update(self):
top = Topology()
top.update_top()
Expand Down Expand Up @@ -198,4 +273,3 @@ def test_angle_angletype_update(self):
assert len(top.angle_types) == 1
assert len(top.angle_type_expressions) == 1
assert len(top.atom_type_expressions) == 2

0 comments on commit 807f3d4

Please sign in to comment.