From cdc4e3c5f0c4924fe5704f17aed969f644e2d8a3 Mon Sep 17 00:00:00 2001 From: Hao-Wei Pang Date: Fri, 12 Jul 2024 14:14:06 -0400 Subject: [PATCH 1/6] Add parallelization --- pysidt/sidt.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/pysidt/sidt.py b/pysidt/sidt.py index 1a5c039..bf674c1 100644 --- a/pysidt/sidt.py +++ b/pysidt/sidt.py @@ -11,6 +11,7 @@ from sklearn import linear_model import scipy.sparse as sp import scipy +from joblib import Parallel, delayed logging.basicConfig(level=logging.INFO) @@ -84,6 +85,7 @@ def __init__( r_site=None, r_morph=None, uncertainty_prepruning=False, + n_jobs=1, ): if nodes is None: nodes = {} @@ -128,10 +130,12 @@ def load(self, nodes): else: self.root = None - def select_node(self): + def select_nodes(self): """ - Picks a node to expand + Picks nodes to expand """ + nodes = [] + for name, node in self.nodes.items(): if len(node.items) <= 1 or node.name in self.skip_nodes: continue @@ -141,9 +145,9 @@ def select_node(self): logging.info("Selected node {}".format(node.name)) logging.info("Node has {} items".format(len(node.items))) - return node - else: - return None + nodes.append(node) + + return nodes def generate_extensions(self, node, recursing=False): """ @@ -204,6 +208,9 @@ def extend_tree_from_node(self, parent): new, comp = split_mols(parent.items, ext) ind = extlist.index(ext) grp, grpc, name, typ, indc = exts[ind] + return grp, grpc, name, new, comp + + def add_extension(self, parent, grp, grpc, name, new, comp): logging.info("Choose extension {}".format(name)) node = Node( @@ -290,11 +297,17 @@ def generate_tree(self, data=None, check_data=True): self.clear_data() self.root.items = data[:] - node = self.select_node() + nodes = self.select_nodes() + + while nodes: + outs = Parallel(n_jobs=self.n_jobs)( + delayed(self.extend_tree_from_node)(node) for node in nodes + ) + + for out, node in zip(outs, nodes): + self.add_extension(node, *out) - while node is not None: - self.extend_tree_from_node(node) - node = self.select_node() + nodes = self.select_nodes() def fit_tree(self, data=None, confidence_level=0.95): """ From a65072cbccb121acef0f45cac248ae12de12e9d0 Mon Sep 17 00:00:00 2001 From: Hao-Wei Pang Date: Fri, 12 Jul 2024 14:26:26 -0400 Subject: [PATCH 2/6] update --- pysidt/sidt.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pysidt/sidt.py b/pysidt/sidt.py index bf674c1..e744e08 100644 --- a/pysidt/sidt.py +++ b/pysidt/sidt.py @@ -109,6 +109,7 @@ def __init__( self.r_morph = r_morph self.skip_nodes = [] self.uncertainty_prepruning = uncertainty_prepruning + self.n_jobs = n_jobs if len(nodes) > 0: node = nodes[list(nodes.keys())[0]] @@ -245,8 +246,10 @@ def add_extension(self, parent, grp, grpc, name, new, comp): parent.children.append(nodec) parent.items = [] else: - for mol in new: - parent.items.remove(mol) + new_smis = set([mol.to_smiles() for mol in new]) + for mol in list(parent.items): + if mol.to_smiles() in new_smis: + parent.items.remove(mol) def descend_training_from_top(self, only_specific_match=True): """ From 98c359fa44ea7ed6984339033f6c11c264ef3e8a Mon Sep 17 00:00:00 2001 From: Hao-Wei Pang Date: Fri, 12 Jul 2024 14:29:47 -0400 Subject: [PATCH 3/6] Better way to remove mols --- pysidt/sidt.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pysidt/sidt.py b/pysidt/sidt.py index e744e08..5778474 100644 --- a/pysidt/sidt.py +++ b/pysidt/sidt.py @@ -246,10 +246,9 @@ def add_extension(self, parent, grp, grpc, name, new, comp): parent.children.append(nodec) parent.items = [] else: - new_smis = set([mol.to_smiles() for mol in new]) - for mol in list(parent.items): - if mol.to_smiles() in new_smis: - parent.items.remove(mol) + + new_smis = {mol.to_smiles() for mol in new} + parent.items = [mol for mol in parent.items if mol.to_smiles() not in new_smis] def descend_training_from_top(self, only_specific_match=True): """ From bf9b69ce96e1e9f3c86d476267deaa3283ee455c Mon Sep 17 00:00:00 2001 From: Hao-Wei Pang Date: Fri, 12 Jul 2024 14:33:34 -0400 Subject: [PATCH 4/6] update --- pysidt/sidt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pysidt/sidt.py b/pysidt/sidt.py index 5778474..2dfe023 100644 --- a/pysidt/sidt.py +++ b/pysidt/sidt.py @@ -247,8 +247,8 @@ def add_extension(self, parent, grp, grpc, name, new, comp): parent.items = [] else: - new_smis = {mol.to_smiles() for mol in new} - parent.items = [mol for mol in parent.items if mol.to_smiles() not in new_smis] + new_smis = {datum.mol.smiles for datum in new} + parent.items = [datum for datum in parent.items if datum.mol.smiles not in new_smis] def descend_training_from_top(self, only_specific_match=True): """ From ab51bc160ad8732b1e7a4cfbe3e6b46dd4a0fcd1 Mon Sep 17 00:00:00 2001 From: Hao-Wei Pang Date: Fri, 12 Jul 2024 14:38:36 -0400 Subject: [PATCH 5/6] Catch None --- pysidt/sidt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pysidt/sidt.py b/pysidt/sidt.py index 2dfe023..95814ab 100644 --- a/pysidt/sidt.py +++ b/pysidt/sidt.py @@ -204,7 +204,7 @@ def extend_tree_from_node(self, parent): extlist = [ext[0] for ext in exts] if not extlist: self.skip_nodes.append(parent.name) - return + return None ext = self.choose_extension(parent, extlist) new, comp = split_mols(parent.items, ext) ind = extlist.index(ext) @@ -307,6 +307,8 @@ def generate_tree(self, data=None, check_data=True): ) for out, node in zip(outs, nodes): + if out is None: + continue self.add_extension(node, *out) nodes = self.select_nodes() From c6facbe4e5dedb06183b5137f86f7e3fb772b0cf Mon Sep 17 00:00:00 2001 From: Hao-Wei Pang Date: Fri, 12 Jul 2024 14:53:11 -0400 Subject: [PATCH 6/6] update --- pysidt/sidt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysidt/sidt.py b/pysidt/sidt.py index 95814ab..c690f89 100644 --- a/pysidt/sidt.py +++ b/pysidt/sidt.py @@ -203,7 +203,6 @@ def extend_tree_from_node(self, parent): exts = self.generate_extensions(parent) extlist = [ext[0] for ext in exts] if not extlist: - self.skip_nodes.append(parent.name) return None ext = self.choose_extension(parent, extlist) new, comp = split_mols(parent.items, ext) @@ -308,6 +307,7 @@ def generate_tree(self, data=None, check_data=True): for out, node in zip(outs, nodes): if out is None: + self.skip_nodes.append(node.name) continue self.add_extension(node, *out)