From c3f063246c08de7c3e1bcd206c85a60e3610e7d9 Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Mon, 26 Aug 2024 14:03:28 -0700 Subject: [PATCH] respect atom map not just atom order in regularization before regularization assumed that parent and child groups had the same atom ordering forward extension generation respects this however: manually created groups and reverse extension generation do not now we use subgraph isomorphism to construct maps when multiple isomorphisms are available we choose the one that is closest to matching the atom ordering between parent and child --- pysidt/regularization.py | 86 +++++++++++++++++++++++++++++++--------- 1 file changed, 67 insertions(+), 19 deletions(-) diff --git a/pysidt/regularization.py b/pysidt/regularization.py index fc3dfb3..acd6cba 100644 --- a/pysidt/regularization.py +++ b/pysidt/regularization.py @@ -1,6 +1,6 @@ from molecule.molecule.atomtype import ATOMTYPES from pysidt.utils import data_matches_node - +import itertools def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True): """ @@ -20,9 +20,57 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True): to generate the tree and test=False is ok if the cascade algorithm wasn't used. """ + child_map = [] for child in node.children: + #recursively regularize child simple_regularization(child, Rx, Rbonds, Run, Rsite, Rmorph) - + + #generate atom map to children + if node.group is not None: + keys = [] + atms = [] + initial_map = dict() + for atom in child.group.atoms: + if atom.label and atom.label != '': + L = [a for a in node.group.atoms if a.label == atom.label] + if L == []: + return False + elif len(L) == 1: + initial_map[atom] = L[0] + else: + keys.append(atom) + atms.append(L) + if atms: + for atmlist in itertools.product(*atms): + if len(set(atmlist)) != len(atmlist): + # skip entries that map multiple graph atoms to the same subgraph atom + continue + for i, key in enumerate(keys): + initial_map[key] = atmlist[i] + if child.group.is_mapping_valid(node.group, initial_map, equivalent=False): + isos = child.group.find_subgraph_isomorphism(node.group, initial_map, save_order=True) + def isomorph_sort_key(d): + v = 0 + for k,v in d.items(): + if node.group.index(v) == child.group.index(k): + v += 1 + return -v + child_map.append({v:k for k,v in sorted(isos,key=isomorph_sort_key)[0].items()}) + else: + raise ValueError(f"Could not find valid mapping between parent {node.name} and child {child.name}") + else: + isos = child.group.find_subgraph_isomorphism(node.group, initial_map, save_order=True) + def isomorph_sort_key(d): + v = 0 + for k,v in d.items(): + if node.group.index(v) == child.group.index(k): + v += 1 + return -v + child_map.append({v:k for k,v in sorted(isos,key=isomorph_sort_key)[0].items()}) + + if node.group is None: #skip None nodes + return + grp = node.group data = node.items @@ -77,8 +125,8 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True): assert vals != [], "cannot regularize to empty" if all( [ - set(child.group.atoms[i].atomtype) <= set(vals) - for child in node.children + set(child_map[q][node.group.atoms[i]].atomtype) <= set(vals) + for q,child in enumerate(node.children) ] ): if not test: @@ -105,10 +153,10 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True): if all( [ - set(child.group.atoms[i].radical_electrons) <= set(vals) - if child.group.atoms[i].radical_electrons != [] + set(child_map[q][node.group.atoms[i]].radical_electrons) <= set(vals) + if child_map[q][node.group.atoms[i]].radical_electrons != [] else False - for child in node.children + for q,child in enumerate(node.children) ] ): if not test: @@ -135,10 +183,10 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True): if all( [ - set(child.group.atoms[i].site) <= set(vals) - if child.group.atoms[i].site != [] + set(child_map[q][node.group.atoms[i]].site) <= set(vals) + if child_map[q][node.group.atoms[i]].site != [] else False - for child in node.children + for q,child in enumerate(node.children) ] ): if not test: @@ -165,10 +213,10 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True): if all( [ - set(child.group.atoms[i].morphology) <= set(vals) - if child.group.atoms[i].morphology != [] + set(child_map[q][node.group.atoms[i]].morphology) <= set(vals) + if child_map[q][node.group.atoms[i]].morphology != [] else False - for child in node.children + for q,child in enumerate(node.children) ] ): if not test: @@ -190,13 +238,13 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True): if "inRing" not in atm1.props.keys(): if all( [ - "inRing" in child.group.atoms[i].props.keys() - for child in node.children + "inRing" in child_map[q][node.group.atoms[i]].props.keys() + for q,child in enumerate(node.children) ] ) and all( [ - child.group.atoms[i].props["inRing"] == atm1.reg_dim_r[1] - for child in node.children + child_map[q][node.group.atoms[i]].props["inRing"] == atm1.reg_dim_r[1] + for q,child in enumerate(node.children) ] ): if not test: @@ -226,11 +274,11 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True): [ set( child.group.get_bond( - child.group.atoms[i], child.group.atoms[j] + child_map[q][atm1], child_map[q][atm2] ).order ) <= set(vals) - for child in node.children + for q,child in enumerate(node.children) ] ): if not test: