Skip to content

Commit

Permalink
add function to identify connections in a snapshot
Browse files Browse the repository at this point in the history
  • Loading branch information
marjanalbooyeh committed Jan 18, 2024
1 parent 13c6fdd commit 8adc1c9
Showing 1 changed file with 235 additions and 0 deletions.
235 changes: 235 additions & 0 deletions cmeutils/gsd_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import warnings
from tempfile import NamedTemporaryFile

import freud
import gsd.hoomd
import hoomd
import networkx as nx
import numpy as np
from boltons.setutils import IndexedSet

from cmeutils.geometry import moit

Expand Down Expand Up @@ -374,3 +377,235 @@ def xml_to_gsd(xmlfile, gsdfile):
snap.bonds.group = bonds
newt.append(snap)
print(f"XML data written to {gsdfile}")


def identify_snapshot_connections(snapshot):
"""Identify angle and dihedral connections in a snapshot from bonds.
Parameters
----------
snapshot : gsd.hoomd.Frame
The snapshot to read in.
Returns
-------
gsd.hoomd.Frame
The snapshot with angle and dihedral information added.
"""
if snapshot.bonds.N == 0:
warnings.warn(

Check warning on line 396 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L395-L396

Added lines #L395 - L396 were not covered by tests
"No bonds found in snapshot, hence, no angles or "
"dihedrals will be identified."
)
return snapshot
bond_groups = snapshot.bonds.group
connection_matches = _find_connections(bond_groups)

Check warning on line 402 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L400-L402

Added lines #L400 - L402 were not covered by tests

if connection_matches["angles"]:
_fill_connection_info(

Check warning on line 405 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L404-L405

Added lines #L404 - L405 were not covered by tests
snapshot=snapshot,
connections=connection_matches["angles"],
type_="angles",
)
if connection_matches["dihedrals"]:
_fill_connection_info(

Check warning on line 411 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L410-L411

Added lines #L410 - L411 were not covered by tests
snapshot=snapshot,
connections=connection_matches["dihedrals"],
type_="dihedrals",
)
return snapshot

Check warning on line 416 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L416

Added line #L416 was not covered by tests


def _fill_connection_info(snapshot, connections, type_):
p_types = snapshot.particles.types
p_typeid = snapshot.particles.typeid
_connection_types = []
_connection_typeid = []
for conn in connections:
type = "-".join([p_types[p_typeid[i]] for i in conn])
types_inv = type[::-1]

Check warning on line 426 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L420-L426

Added lines #L420 - L426 were not covered by tests
# check if type not in angle_types and types_inv not in angle_types:
if type not in _connection_types or types_inv not in _connection_types:
_connection_types.append(type)
_connection_typeid.append(

Check warning on line 430 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L428-L430

Added lines #L428 - L430 were not covered by tests
max(_connection_typeid) + 1 if _connection_typeid else 0
)
else:
_connection_typeid.append(_connection_types.index(type))

Check warning on line 434 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L434

Added line #L434 was not covered by tests

if type_ == "angles":
snapshot.angles.N = len(connections)
snapshot.angles.M = len(_connection_types)
snapshot.angles.group = connections
snapshot.angles.types = _connection_types
snapshot.angles.typeid = _connection_typeid
elif type_ == "dihedrals":
snapshot.dihedrals.N = len(connections)
snapshot.dihedrals.M = len(_connection_types)
snapshot.dihedrals.group = connections
snapshot.dihedrals.types = _connection_types
snapshot.dihedrals.typeid = _connection_typeid

Check warning on line 447 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L436-L447

Added lines #L436 - L447 were not covered by tests


# The following functions are obtained from gmso/utils/connectivity.py with
# minor modifications.
def _find_connections(bonds):
"""Identify all possible connections within a topology."""
compound = nx.Graph()

Check warning on line 454 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L454

Added line #L454 was not covered by tests

for b in bonds:
compound.add_edge(b[0], b[1])

Check warning on line 457 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L456-L457

Added lines #L456 - L457 were not covered by tests

compound_line_graph = nx.line_graph(compound)

Check warning on line 459 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L459

Added line #L459 was not covered by tests

angle_matches = _detect_connections(compound_line_graph, type_="angle")
dihedral_matches = _detect_connections(

Check warning on line 462 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L461-L462

Added lines #L461 - L462 were not covered by tests
compound_line_graph, type_="dihedral"
)

return {

Check warning on line 466 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L466

Added line #L466 was not covered by tests
"angles": angle_matches,
"dihedrals": dihedral_matches,
}


def _detect_connections(compound_line_graph, type_="angle"):
EDGES = {

Check warning on line 473 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L473

Added line #L473 was not covered by tests
"angle": ((0, 1),),
"dihedral": ((0, 1), (1, 2)),
}

connection = nx.Graph()
for edge in EDGES[type_]:
assert len(edge) == 2, "Edges should be of length 2"
connection.add_edge(edge[0], edge[1])

Check warning on line 481 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L478-L481

Added lines #L478 - L481 were not covered by tests

matcher = nx.algorithms.isomorphism.GraphMatcher(

Check warning on line 483 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L483

Added line #L483 was not covered by tests
compound_line_graph, connection
)

formatter_fns = {

Check warning on line 487 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L487

Added line #L487 was not covered by tests
"angle": _format_subgraph_angle,
"dihedral": _format_subgraph_dihedral,
}

conn_matches = IndexedSet()
for m in matcher.subgraph_isomorphisms_iter():
new_connection = formatter_fns[type_](m)
conn_matches.add(new_connection)
if conn_matches:
conn_matches = _trim_duplicates(conn_matches)

Check warning on line 497 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L492-L497

Added lines #L492 - L497 were not covered by tests

# Do more sorting of individual connection
sorted_conn_matches = list()
for match in conn_matches:
if match[0] < match[-1]:
sorted_conn = match

Check warning on line 503 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L500-L503

Added lines #L500 - L503 were not covered by tests
else:
sorted_conn = match[::-1]
sorted_conn_matches.append(sorted_conn)

Check warning on line 506 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L505-L506

Added lines #L505 - L506 were not covered by tests

# Final sorting the whole list
if type_ == "angle":
return sorted(

Check warning on line 510 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L509-L510

Added lines #L509 - L510 were not covered by tests
sorted_conn_matches,
key=lambda angle: (
angle[1],
angle[0],
angle[2],
),
)
elif type_ == "dihedral":
return sorted(

Check warning on line 519 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L518-L519

Added lines #L518 - L519 were not covered by tests
sorted_conn_matches,
key=lambda dihedral: (
dihedral[1],
dihedral[2],
dihedral[0],
dihedral[3],
),
)


def _get_sorted_by_n_connections(m):
"""Return sorted by n connections for the matching graph."""
small = nx.Graph()
for k, v in m.items():
small.add_edge(k[0], k[1])
return sorted(small.adj, key=lambda x: len(small[x])), small

Check warning on line 535 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L532-L535

Added lines #L532 - L535 were not covered by tests


def _format_subgraph_angle(m):
"""Format the angle subgraph.
Since we are matching compound line graphs,
back out the actual nodes, not just the edges
Parameters
----------
m : dict
keys are the compound line graph nodes
Values are the sub-graph matches (to the angle, dihedral, or improper)
Returns
-------
connection : list of nodes, in order of bonding
(start, middle, end)
"""
(sort_by_n_connections, _) = _get_sorted_by_n_connections(m)
ends = sorted([sort_by_n_connections[0], sort_by_n_connections[1]])
middle = sort_by_n_connections[2]
return [

Check warning on line 558 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L555-L558

Added lines #L555 - L558 were not covered by tests
ends[0],
middle,
ends[1],
]


def _format_subgraph_dihedral(m):
"""Format the dihedral subgraph.
Since we are matching compound line graphs,
back out the actual nodes, not just the edges
Parameters
----------
m : dict
keys are the compound line graph nodes
Values are the sub-graph matches (to the angle, dihedral, or improper)
top : gmso.Topology
The original Topology
Returns
-------
connection : list of nodes, in order of bonding
(start, mid1, mid2, end)
"""
(sort_by_n_connections, small) = _get_sorted_by_n_connections(m)
start = sort_by_n_connections[0]
if sort_by_n_connections[2] in small.neighbors(start):
mid1 = sort_by_n_connections[2]
mid2 = sort_by_n_connections[3]

Check warning on line 588 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L584-L588

Added lines #L584 - L588 were not covered by tests
else:
mid1 = sort_by_n_connections[3]
mid2 = sort_by_n_connections[2]

Check warning on line 591 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L590-L591

Added lines #L590 - L591 were not covered by tests

end = sort_by_n_connections[1]
return [start, mid1, mid2, end]

Check warning on line 594 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L593-L594

Added lines #L593 - L594 were not covered by tests


def _trim_duplicates(all_matches):
"""Remove redundant sub-graph matches.
Is there a better way to do this? Like when we format the subgraphs,
can we impose an ordering so it's easier to eliminate redundant matches?
"""
trimmed_list = IndexedSet()
for match in all_matches:
if (

Check warning on line 605 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L603-L605

Added lines #L603 - L605 were not covered by tests
match
and match not in trimmed_list
and match[::-1] not in trimmed_list
):
trimmed_list.add(match)
return trimmed_list

Check warning on line 611 in cmeutils/gsd_utils.py

View check run for this annotation

Codecov / codecov/patch

cmeutils/gsd_utils.py#L610-L611

Added lines #L610 - L611 were not covered by tests

0 comments on commit 8adc1c9

Please sign in to comment.