Skip to content

Commit

Permalink
respect atom map not just atom order in regularization
Browse files Browse the repository at this point in the history
before regularization assumed that parent and child groups had the same atom ordering
forward extension generation respects this however:
manually created groups and reverse extension generation do not

now we use subgraph isomorphism to construct maps
when multiple isomorphisms are available we choose the one that is closest to matching the atom ordering between parent and child
  • Loading branch information
mjohnson541 committed Aug 26, 2024
1 parent c0b87a4 commit c3f0632
Showing 1 changed file with 67 additions and 19 deletions.
86 changes: 67 additions & 19 deletions pysidt/regularization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from molecule.molecule.atomtype import ATOMTYPES
from pysidt.utils import data_matches_node

import itertools

def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):
"""
Expand All @@ -20,9 +20,57 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):
to generate the tree and test=False is ok if the cascade algorithm
wasn't used.
"""
child_map = []
for child in node.children:
#recursively regularize child
simple_regularization(child, Rx, Rbonds, Run, Rsite, Rmorph)


#generate atom map to children
if node.group is not None:
keys = []
atms = []
initial_map = dict()
for atom in child.group.atoms:
if atom.label and atom.label != '':
L = [a for a in node.group.atoms if a.label == atom.label]
if L == []:
return False
elif len(L) == 1:
initial_map[atom] = L[0]
else:
keys.append(atom)
atms.append(L)
if atms:
for atmlist in itertools.product(*atms):
if len(set(atmlist)) != len(atmlist):
# skip entries that map multiple graph atoms to the same subgraph atom
continue
for i, key in enumerate(keys):
initial_map[key] = atmlist[i]
if child.group.is_mapping_valid(node.group, initial_map, equivalent=False):
isos = child.group.find_subgraph_isomorphism(node.group, initial_map, save_order=True)
def isomorph_sort_key(d):
v = 0
for k,v in d.items():
if node.group.index(v) == child.group.index(k):
v += 1
return -v
child_map.append({v:k for k,v in sorted(isos,key=isomorph_sort_key)[0].items()})
else:
raise ValueError(f"Could not find valid mapping between parent {node.name} and child {child.name}")
else:
isos = child.group.find_subgraph_isomorphism(node.group, initial_map, save_order=True)
def isomorph_sort_key(d):
v = 0
for k,v in d.items():
if node.group.index(v) == child.group.index(k):
v += 1
return -v
child_map.append({v:k for k,v in sorted(isos,key=isomorph_sort_key)[0].items()})

if node.group is None: #skip None nodes
return

grp = node.group
data = node.items

Expand Down Expand Up @@ -77,8 +125,8 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):
assert vals != [], "cannot regularize to empty"
if all(
[
set(child.group.atoms[i].atomtype) <= set(vals)
for child in node.children
set(child_map[q][node.group.atoms[i]].atomtype) <= set(vals)
for q,child in enumerate(node.children)
]
):
if not test:
Expand All @@ -105,10 +153,10 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):

if all(
[
set(child.group.atoms[i].radical_electrons) <= set(vals)
if child.group.atoms[i].radical_electrons != []
set(child_map[q][node.group.atoms[i]].radical_electrons) <= set(vals)
if child_map[q][node.group.atoms[i]].radical_electrons != []
else False
for child in node.children
for q,child in enumerate(node.children)
]
):
if not test:
Expand All @@ -135,10 +183,10 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):

if all(
[
set(child.group.atoms[i].site) <= set(vals)
if child.group.atoms[i].site != []
set(child_map[q][node.group.atoms[i]].site) <= set(vals)
if child_map[q][node.group.atoms[i]].site != []
else False
for child in node.children
for q,child in enumerate(node.children)
]
):
if not test:
Expand All @@ -165,10 +213,10 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):

if all(
[
set(child.group.atoms[i].morphology) <= set(vals)
if child.group.atoms[i].morphology != []
set(child_map[q][node.group.atoms[i]].morphology) <= set(vals)
if child_map[q][node.group.atoms[i]].morphology != []
else False
for child in node.children
for q,child in enumerate(node.children)
]
):
if not test:
Expand All @@ -190,13 +238,13 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):
if "inRing" not in atm1.props.keys():
if all(
[
"inRing" in child.group.atoms[i].props.keys()
for child in node.children
"inRing" in child_map[q][node.group.atoms[i]].props.keys()
for q,child in enumerate(node.children)
]
) and all(
[
child.group.atoms[i].props["inRing"] == atm1.reg_dim_r[1]
for child in node.children
child_map[q][node.group.atoms[i]].props["inRing"] == atm1.reg_dim_r[1]
for q,child in enumerate(node.children)
]
):
if not test:
Expand Down Expand Up @@ -226,11 +274,11 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):
[
set(
child.group.get_bond(
child.group.atoms[i], child.group.atoms[j]
child_map[q][atm1], child_map[q][atm2]
).order
)
<= set(vals)
for child in node.children
for q,child in enumerate(node.children)
]
):
if not test:
Expand Down

0 comments on commit c3f0632

Please sign in to comment.