# Reconstructed from the git patch "add function to identify connections in a
# snapshot" (commit 8adc1c99, marjanalbouye, 2024-01-18). The patch appends the
# definitions below to cmeutils/gsd_utils.py and adds `warnings`,
# `networkx as nx` and `boltons.setutils.IndexedSet` to that module's imports.
import warnings

# Number of particles in a single group of each supported connection kind.
_CONNECTION_GROUP_SIZE = {"angles": 3, "dihedrals": 4}


def identify_snapshot_connections(snapshot):
    """Identify angle and dihedral connections in a snapshot from bonds.

    The snapshot is modified in place and also returned, so the call can be
    used either as a statement or as an expression.

    Parameters
    ----------
    snapshot : gsd.hoomd.Frame
        The snapshot to read in. Only ``snapshot.bonds`` and
        ``snapshot.particles`` are read.

    Returns
    -------
    gsd.hoomd.Frame
        The same snapshot object, with angle and dihedral information added.

    Warns
    -----
    UserWarning
        If the snapshot has no bonds; it is then returned untouched.
    """
    if snapshot.bonds.N == 0:
        warnings.warn(
            "No bonds found in snapshot, hence, no angles or "
            "dihedrals will be identified."
        )
        return snapshot

    connection_matches = _find_connections(snapshot.bonds.group)
    for type_ in ("angles", "dihedrals"):
        connections = connection_matches[type_]
        # Leave the snapshot section untouched when nothing was found.
        if connections:
            _fill_connection_info(
                snapshot=snapshot,
                connections=connections,
                type_=type_,
            )
    return snapshot


def _fill_connection_info(snapshot, connections, type_):
    """Write connection groups and their types into a snapshot section.

    A connection type is named by joining the particle type names of its
    members with ``-``. A connection and its mirror image (same particles in
    reverse order) share one type; the name stored is the orientation that
    was encountered first.

    Parameters
    ----------
    snapshot : gsd.hoomd.Frame
        Snapshot to update in place.
    connections : sequence of sequences of int
        Particle indices of every connection.
    type_ : {"angles", "dihedrals"}
        Which section of the snapshot to fill.

    Raises
    ------
    ValueError
        If ``type_`` is not ``"angles"`` or ``"dihedrals"``.
    """
    if type_ not in _CONNECTION_GROUP_SIZE:
        raise ValueError(
            f"Unsupported connection type {type_!r}; expected one of "
            f"{sorted(_CONNECTION_GROUP_SIZE)}."
        )

    particle_types = snapshot.particles.types
    particle_typeid = snapshot.particles.typeid

    # Maps a type name to its id; insertion order is the id order.
    type_ids = {}
    connection_typeid = []
    for conn in connections:
        names = [particle_types[particle_typeid[i]] for i in conn]
        forward = "-".join(names)
        # Reverse the particle order, not the characters of the joined name,
        # so multi-character particle types are handled correctly.
        backward = "-".join(reversed(names))
        if forward in type_ids:
            type_id = type_ids[forward]
        elif backward in type_ids:
            type_id = type_ids[backward]
        else:
            type_id = len(type_ids)
            type_ids[forward] = type_id
        connection_typeid.append(type_id)

    section = getattr(snapshot, type_)
    section.N = len(connections)
    # M is the number of particles per group, not the number of types.
    section.M = _CONNECTION_GROUP_SIZE[type_]
    section.group = connections
    section.types = list(type_ids)
    section.typeid = connection_typeid


# The following functions are obtained from gmso/utils/connectivity.py with
# minor modifications.
# Third-party dependencies of the connection-detection helpers below.
import networkx as nx
from boltons.setutils import IndexedSet


def _find_connections(bonds):
    """Identify all possible connections within a topology.

    Parameters
    ----------
    bonds : iterable of pairs
        Each item holds the two particle indices of one bond.

    Returns
    -------
    dict
        ``{"angles": [...], "dihedrals": [...]}``; every entry is a sorted
        list of particle-index tuples.
    """
    compound = nx.Graph()
    for bond in bonds:
        compound.add_edge(bond[0], bond[1])

    # In the line graph every bond is a node and two bonds are adjacent when
    # they share a particle, so angles and dihedrals become short paths.
    compound_line_graph = nx.line_graph(compound)

    return {
        "angles": _detect_connections(compound_line_graph, type_="angle"),
        "dihedrals": _detect_connections(
            compound_line_graph, type_="dihedral"
        ),
    }


def _detect_connections(compound_line_graph, type_="angle"):
    """Find every angle or dihedral in a line graph of bonds.

    Parameters
    ----------
    compound_line_graph : networkx.Graph
        Line graph of the bond graph.
    type_ : {"angle", "dihedral"}
        Kind of connection to look for.

    Returns
    -------
    list of tuple
        Unique connections, each oriented so that its first index is smaller
        than its last, sorted by the central particle(s) first.

    Raises
    ------
    KeyError
        If ``type_`` is not ``"angle"`` or ``"dihedral"``.
    """
    # Pattern to look for in the line graph: a path of two bonds for an
    # angle, a path of three bonds for a dihedral.
    pattern_edges = {
        "angle": ((0, 1),),
        "dihedral": ((0, 1), (1, 2)),
    }
    formatter_fns = {
        "angle": _format_subgraph_angle,
        "dihedral": _format_subgraph_dihedral,
    }
    sort_keys = {
        "angle": lambda angle: (angle[1], angle[0], angle[2]),
        "dihedral": lambda dihedral: (
            dihedral[1],
            dihedral[2],
            dihedral[0],
            dihedral[3],
        ),
    }

    connection = nx.Graph()
    for first, second in pattern_edges[type_]:
        connection.add_edge(first, second)

    matcher = nx.algorithms.isomorphism.GraphMatcher(
        compound_line_graph, connection
    )

    format_match = formatter_fns[type_]
    conn_matches = IndexedSet()
    for mapping in matcher.subgraph_isomorphisms_iter():
        # The formatters return tuples; IndexedSet needs hashable items.
        conn_matches.add(format_match(mapping))
    if conn_matches:
        conn_matches = _trim_duplicates(conn_matches)

    # Orient every connection so the smaller end index comes first.
    oriented = [
        match if match[0] < match[-1] else match[::-1]
        for match in conn_matches
    ]
    return sorted(oriented, key=sort_keys[type_])


def _get_sorted_by_n_connections(m):
    """Return sorted by n connections for the matching graph.

    Parameters
    ----------
    m : dict
        Keys are the compound line graph nodes, i.e. bonds given as pairs of
        particle indices. The values are not used.

    Returns
    -------
    tuple
        ``(nodes, small)`` where ``small`` is the graph built from the
        matched bonds and ``nodes`` are its particles ordered by increasing
        number of neighbors.
    """
    small = nx.Graph()
    for bond in m:
        small.add_edge(bond[0], bond[1])
    return sorted(small.adj, key=lambda node: len(small[node])), small


def _format_subgraph_angle(m):
    """Format the angle subgraph.

    Since we are matching compound line graphs,
    back out the actual nodes, not just the edges

    Parameters
    ----------
    m : dict
        keys are the compound line graph nodes
        Values are the sub-graph matches (to the angle, dihedral, or improper)

    Returns
    -------
    connection : tuple of nodes, in order of bonding
        (start, middle, end)
    """
    (sort_by_n_connections, _) = _get_sorted_by_n_connections(m)
    # The two singly-connected particles are the ends; the remaining one,
    # bonded to both, is the middle.
    ends = sorted([sort_by_n_connections[0], sort_by_n_connections[1]])
    middle = sort_by_n_connections[2]
    # A tuple (not a list) so the result can be stored in a set.
    return (ends[0], middle, ends[1])


def _format_subgraph_dihedral(m):
    """Format the dihedral subgraph.

    Since we are matching compound line graphs,
    back out the actual nodes, not just the edges

    Parameters
    ----------
    m : dict
        keys are the compound line graph nodes
        Values are the sub-graph matches (to the angle, dihedral, or improper)

    Returns
    -------
    connection : tuple of nodes, in order of bonding
        (start, mid1, mid2, end)
    """
    (sort_by_n_connections, small) = _get_sorted_by_n_connections(m)
    start = sort_by_n_connections[0]
    end = sort_by_n_connections[1]
    # Of the two inner particles, mid1 is the one bonded to start.
    if sort_by_n_connections[2] in small[start]:
        mid1 = sort_by_n_connections[2]
        mid2 = sort_by_n_connections[3]
    else:
        mid1 = sort_by_n_connections[3]
        mid2 = sort_by_n_connections[2]
    # A tuple (not a list) so the result can be stored in a set.
    return (start, mid1, mid2, end)


def _trim_duplicates(all_matches):
    """Remove redundant sub-graph matches.

    A match is dropped when it is empty, or when it or its reverse has
    already been kept. The order of first appearance is preserved.

    Is there a better way to do this? Like when we format the subgraphs,
    can we impose an ordering so it's easier to eliminate redundant matches?
    """
    trimmed_list = IndexedSet()
    for match in all_matches:
        if not match:
            continue
        if match in trimmed_list or match[::-1] in trimmed_list:
            continue
        trimmed_list.add(match)
    return trimmed_list