Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to identify connections in a snapshot #76

Merged
merged 6 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 248 additions & 0 deletions cmeutils/gsd_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import warnings
from tempfile import NamedTemporaryFile

import freud
import gsd.hoomd
import hoomd
import networkx as nx
import numpy as np
from boltons.setutils import IndexedSet

from cmeutils.geometry import moit

Expand Down Expand Up @@ -374,3 +377,248 @@
snap.bonds.group = bonds
newt.append(snap)
print(f"XML data written to {gsdfile}")


def identify_snapshot_connections(snapshot):
"""Identify angle and dihedral connections in a snapshot from bonds.

Parameters
----------
snapshot : gsd.hoomd.Frame
The snapshot to read in.

Returns
-------
gsd.hoomd.Frame
The snapshot with angle and dihedral information added.
"""
if snapshot.bonds.N == 0:
warnings.warn(
"No bonds found in snapshot, hence, no angles or "
"dihedrals will be identified."
)
return snapshot
bond_groups = snapshot.bonds.group
connection_matches = _find_connections(bond_groups)

if connection_matches["angles"]:
_fill_connection_info(
snapshot=snapshot,
connections=connection_matches["angles"],
type_="angles",
)
if connection_matches["dihedrals"]:
_fill_connection_info(
snapshot=snapshot,
connections=connection_matches["dihedrals"],
type_="dihedrals",
)
return snapshot


def _fill_connection_info(snapshot, connections, type_):
p_types = snapshot.particles.types
p_typeid = snapshot.particles.typeid
_connection_types = []
_connection_typeid = []
for conn in connections:
conn_sites = [p_types[p_typeid[i]] for i in conn]
sorted_conn_sites = _sort_connection_by_name(conn_sites, type_)
type = "-".join(sorted_conn_sites)
# check if type not in angle_types and types_inv not in angle_types:
if type not in _connection_types:
_connection_types.append(type)
_connection_typeid.append(
max(_connection_typeid) + 1 if _connection_typeid else 0
)
else:
_connection_typeid.append(_connection_types.index(type))

if type_ == "angles":
snapshot.angles.N = len(connections)
snapshot.angles.M = 3
snapshot.angles.group = connections
snapshot.angles.types = _connection_types
snapshot.angles.typeid = _connection_typeid
elif type_ == "dihedrals":
snapshot.dihedrals.N = len(connections)
snapshot.dihedrals.M = 4
snapshot.dihedrals.group = connections
snapshot.dihedrals.types = _connection_types
snapshot.dihedrals.typeid = _connection_typeid


# The following functions are obtained from gmso/utils/connectivity.py with
# minor modifications.
def _sort_connection_by_name(conn_sites, type_):
if type_ == "angles":
site1, site3 = sorted([conn_sites[0], conn_sites[2]])
return [site1, conn_sites[1], site3]
elif type_ == "dihedrals":
site1, site2, site3, site4 = conn_sites
if site2 > site3 or (site2 == site3 and site1 > site4):
return [site4, site3, site2, site1]
else:
return [site1, site2, site3, site4]


def _find_connections(bonds):
"""Identify all possible connections within a topology."""
compound = nx.Graph()

for b in bonds:
compound.add_edge(b[0], b[1])

compound_line_graph = nx.line_graph(compound)

angle_matches = _detect_connections(compound_line_graph, type_="angle")
dihedral_matches = _detect_connections(
compound_line_graph, type_="dihedral"
)

return {
"angles": angle_matches,
"dihedrals": dihedral_matches,
}


def _detect_connections(compound_line_graph, type_="angle"):
EDGES = {
"angle": ((0, 1),),
"dihedral": ((0, 1), (1, 2)),
}

connection = nx.Graph()
for edge in EDGES[type_]:
assert len(edge) == 2, "Edges should be of length 2"
connection.add_edge(edge[0], edge[1])

matcher = nx.algorithms.isomorphism.GraphMatcher(
compound_line_graph, connection
)

formatter_fns = {
"angle": _format_subgraph_angle,
"dihedral": _format_subgraph_dihedral,
}

conn_matches = IndexedSet()
for m in matcher.subgraph_isomorphisms_iter():
new_connection = formatter_fns[type_](m)
conn_matches.add(new_connection)
if conn_matches:
conn_matches = _trim_duplicates(conn_matches)

# Do more sorting of individual connection
sorted_conn_matches = list()
for match in conn_matches:
if match[0] < match[-1]:
sorted_conn = match
else:
sorted_conn = match[::-1]
sorted_conn_matches.append(list(sorted_conn))

# Final sorting the whole list
if type_ == "angle":
return sorted(
sorted_conn_matches,
key=lambda angle: (
angle[1],
angle[0],
angle[2],
),
)
elif type_ == "dihedral":
return sorted(
sorted_conn_matches,
key=lambda dihedral: (
dihedral[1],
dihedral[2],
dihedral[0],
dihedral[3],
),
)


def _get_sorted_by_n_connections(m):
"""Return sorted by n connections for the matching graph."""
small = nx.Graph()
for k, v in m.items():
small.add_edge(k[0], k[1])
return sorted(small.adj, key=lambda x: len(small[x])), small


def _format_subgraph_angle(m):
"""Format the angle subgraph.

Since we are matching compound line graphs,
back out the actual nodes, not just the edges

Parameters
----------
m : dict
keys are the compound line graph nodes
Values are the sub-graph matches (to the angle, dihedral, or improper)

Returns
-------
connection : list of nodes, in order of bonding
(start, middle, end)
"""
(sort_by_n_connections, _) = _get_sorted_by_n_connections(m)
ends = sorted([sort_by_n_connections[0], sort_by_n_connections[1]])
middle = sort_by_n_connections[2]
return (
ends[0],
middle,
ends[1],
)


def _format_subgraph_dihedral(m):
"""Format the dihedral subgraph.

Since we are matching compound line graphs,
back out the actual nodes, not just the edges

Parameters
----------
m : dict
keys are the compound line graph nodes
Values are the sub-graph matches (to the angle, dihedral, or improper)
top : gmso.Topology
The original Topology

Returns
-------
connection : list of nodes, in order of bonding
(start, mid1, mid2, end)
"""
(sort_by_n_connections, small) = _get_sorted_by_n_connections(m)
start = sort_by_n_connections[0]
if sort_by_n_connections[2] in small.neighbors(start):
mid1 = sort_by_n_connections[2]
mid2 = sort_by_n_connections[3]
else:
mid1 = sort_by_n_connections[3]
mid2 = sort_by_n_connections[2]

Check warning on line 604 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L603-L604

Added lines #L603 - L604 were not covered by tests

end = sort_by_n_connections[1]
return (start, mid1, mid2, end)


def _trim_duplicates(all_matches):
"""Remove redundant sub-graph matches.

Is there a better way to do this? Like when we format the subgraphs,
can we impose an ordering so it's easier to eliminate redundant matches?
"""
trimmed_list = IndexedSet()
for match in all_matches:
if (
match
and match not in trimmed_list
and match[::-1] not in trimmed_list
):
trimmed_list.add(match)
return trimmed_list
Binary file added cmeutils/tests/assets/pekk-cg.gsd
Binary file not shown.
4 changes: 4 additions & 0 deletions cmeutils/tests/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def p3ht_gsd(self):
def p3ht_cg_gsd(self):
return path.join(asset_dir, "p3ht-cg.gsd")

@pytest.fixture
def pekk_cg_gsd(self):
return path.join(asset_dir, "pekk-cg.gsd")

@pytest.fixture
def mapping(self):
return np.loadtxt(path.join(asset_dir, "mapping.txt"), dtype=int)
Expand Down
Loading
Loading