Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run black and isort during CI to ensure code formatting quality #19

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@ jobs:

- name: Install packages for testing
run: |
python -m pip install pytest nbmake
python -m pip install -e .
python -m pip install -e ".[dev]"

- name: Check formatting
run: |
black --check .
isort --check .

- name: Test PySIDT notebooks
run: |
pytest --nbmake IPython/multi_eval_SIDT_example.ipynb
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.egg-info/
__pycache__/
IPython/
IPython/
build/
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,18 @@ requires-python = ">=3.7"
[tool.setuptools]
packages = ["pysidt"]

[project.optional-dependencies]
dev = ["black", "isort", "pytest", "nbmake"]

[project.urls]
Homepage = "https://github.com/zadorlab/PySIDT"
Issues = "https://github.com/zadorlab/PySIDT/issues"

[tool.black]
line-length = 100
skip-magic-trailing-comma = true

[tool.isort]
profile = "black"
line_length = 100
force_sort_within_sections = true
4 changes: 2 additions & 2 deletions pysidt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pysidt.sidt import (
SubgraphIsomorphicDecisionTree,
Datum,
MultiEvalSubgraphIsomorphicDecisionTree,
Node,
Datum,
SubgraphIsomorphicDecisionTree,
read_nodes,
write_nodes,
)
204 changes: 48 additions & 156 deletions pysidt/extensions.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pysidt/plotting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

from IPython.display import Image, display
import pydot
import os


def plot_tree(sidt, images=True):
Expand Down
26 changes: 6 additions & 20 deletions pysidt/regularization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from molecule.molecule.atomtype import ATOMTYPES

from pysidt.utils import data_matches_node


Expand Down Expand Up @@ -51,9 +52,7 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):
and atm1.atomtype == atm2.atomtype
and len(atm1.bonds) == len(atm2.bonds)
):
bdpairs2 = {
(atm, tuple(bd.order)) for atm, bd in atm2.bonds.items()
}
bdpairs2 = {(atm, tuple(bd.order)) for atm, bd in atm2.bonds.items()}
if bdpairs == bdpairs2:
skip = True
indistinguishable.append(i)
Expand All @@ -73,10 +72,7 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):
vals = list(set(atyp) & set(atm1.reg_dim_atm[1]))
assert vals != [], "cannot regularize to empty"
if all(
[
set(child.group.atoms[i].atomtype) <= set(vals)
for child in node.children
]
[set(child.group.atoms[i].atomtype) <= set(vals) for child in node.children]
):
if not test:
atm1.atomtype = vals
Expand Down Expand Up @@ -116,11 +112,7 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):
if not data_matches_node(node, data):
atm1.radical_electrons = oldvals

if (
not skip
and atm1.reg_dim_site[1] != []
and set(atm1.reg_dim_site[1]) != set(atm1.site)
):
if not skip and atm1.reg_dim_site[1] != [] and set(atm1.reg_dim_site[1]) != set(atm1.site):
if len(atm1.site) == 1:
pass
else:
Expand Down Expand Up @@ -179,17 +171,11 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):
if (
not skip
and atm1.reg_dim_r[1] != []
and (
"inRing" not in atm1.props.keys()
or atm1.reg_dim_r[1][0] != atm1.props["inRing"]
)
and ("inRing" not in atm1.props.keys() or atm1.reg_dim_r[1][0] != atm1.props["inRing"])
):
if "inRing" not in atm1.props.keys():
if all(
[
"inRing" in child.group.atoms[i].props.keys()
for child in node.children
]
["inRing" in child.group.atoms[i].props.keys() for child in node.children]
) and all(
[
child.group.atoms[i].props["inRing"] == atm1.reg_dim_r[1]
Expand Down
75 changes: 24 additions & 51 deletions pysidt/sidt.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,23 @@
import json
import logging

from molecule.molecule import Group
from molecule.quantity import ScalarQuantity
from pysidt.extensions import split_mols, get_extension_edge
from pysidt.regularization import simple_regularization
from pysidt.decomposition import *
from pysidt.utils import *
import numpy as np
import logging
import json
from sklearn import linear_model
import scipy.sparse as sp
from sklearn import linear_model

from pysidt.decomposition import *
from pysidt.extensions import get_extension_edge, split_mols
from pysidt.regularization import simple_regularization
from pysidt.utils import *

logging.basicConfig(level=logging.INFO)


class Node:
def __init__(
self,
group=None,
items=None,
rule=None,
parent=None,
children=None,
name=None,
depth=None,
self, group=None, items=None, rule=None, parent=None, children=None, name=None, depth=None
):
if items is None:
items = []
Expand Down Expand Up @@ -161,9 +156,9 @@ def choose_extension(self, node, exts):
minext = None
for ext in exts:
new, comp = split_mols(node.items, ext)
val = np.std([x.value for x in new]) * len(new) + np.std(
[x.value for x in comp]
) * len(comp)
val = np.std([x.value for x in new]) * len(new) + np.std([x.value for x in comp]) * len(
comp
)
if val < minval:
minval = val
minext = ext
Expand Down Expand Up @@ -324,7 +319,7 @@ def to_dict(obj):
try:
json.dumps(val)
out_dict[attr] = val
except:
except BaseException:
if isinstance(val, ScalarQuantity):
out_dict[attr] = {
"class": val.__class__.__name__,
Expand All @@ -341,7 +336,7 @@ def to_dict(obj):

def from_dict(d, class_dict=None):
"""construct objects from dictionary

Args:
d (dict): dictionary describing object, particularly containing a value
associated with "class" identifying a string of the class of the object
Expand Down Expand Up @@ -382,10 +377,8 @@ def write_nodes(tree, file):
) # will work on all rmgmolecule objects, new objects need this method implemented
try:
json.dumps(rule)
except:
raise ValueError(
f"Could not serialize object {node.rule.__class__.__name__}"
)
except BaseException:
raise ValueError(f"Could not serialize object {node.rule.__class__.__name__}")

nodesdict[node.name] = {
"group": node.group.to_adjacency_list(),
Expand Down Expand Up @@ -620,13 +613,9 @@ def choose_extension(self, node, exts):
for i, datum in enumerate(self.datums):
dy = self.data_delta[i] / len(self.mol_node_maps[datum]["mols"])
for j, d in enumerate(self.mol_node_maps[datum]["mols"]):
v = self.node_uncertainties[
self.mol_node_maps[datum]["nodes"][j].name
]
v = self.node_uncertainties[self.mol_node_maps[datum]["nodes"][j].name]
s = sum(
self.node_uncertainties[
self.mol_node_maps[datum]["nodes"][k].name
]
self.node_uncertainties[self.mol_node_maps[datum]["nodes"][k].name]
for k in range(len(self.mol_node_maps[datum]["nodes"]))
)
if any(d is x for x in new):
Expand Down Expand Up @@ -666,10 +655,7 @@ def setup_data(self, data, check_data=False):
self.mol_node_maps = dict()
for datum in self.datums:
decomp = self.decomposition(datum.mol)
self.mol_node_maps[datum] = {
"mols": decomp,
"nodes": [self.root for d in decomp],
}
self.mol_node_maps[datum] = {"mols": decomp, "nodes": [self.root for d in decomp]}

if check_data:
for i, datum in enumerate(self.datums):
Expand Down Expand Up @@ -721,9 +707,7 @@ def generate_tree(
if len(self.nodes) > max_nodes:
break
self.new_nodes = []
num = int(
max(1, np.round(self.fract_nodes_expand_per_iter * len(self.nodes)))
)
num = int(max(1, np.round(self.fract_nodes_expand_per_iter * len(self.nodes))))
nodes = self.select_nodes(num=num)
if not nodes:
break
Expand Down Expand Up @@ -796,11 +780,7 @@ def fit_rule(self, alpha=0.1):
node = node.parent

clf = linear_model.Lasso(
alpha=alpha,
fit_intercept=False,
tol=1e-4,
max_iter=1000000000,
selection="random",
alpha=alpha, fit_intercept=False, tol=1e-4, max_iter=1000000000, selection="random"
)

lasso = clf.fit(A, y)
Expand Down Expand Up @@ -863,9 +843,7 @@ def estimate_uncertainty(self):
{node.name: node_uncertainties[i] for i, node in enumerate(nodes)}
)
else:
self.node_uncertainties.update(
{node.name: 1.0 for i, node in enumerate(nodes)}
)
self.node_uncertainties.update({node.name: 1.0 for i, node in enumerate(nodes)})

def assign_depths(self):
root = self.root
Expand Down Expand Up @@ -924,12 +902,7 @@ def regularize(self, data=None, check_data=True):
self.descend_training_from_top(only_specific_match=False)

simple_regularization(
self.nodes["Root"],
self.r,
self.r_bonds,
self.r_un,
self.r_site,
self.r_morph,
self.nodes["Root"], self.r, self.r_bonds, self.r_un, self.r_site, self.r_morph
)


Expand Down
4 changes: 1 addition & 3 deletions pysidt/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
def data_matches_node(node, data):
for m in data:
if not m.is_subgraph_isomorphic(
node.group, generate_initial_map=True, save_order=True
):
if not m.is_subgraph_isomorphic(node.group, generate_initial_map=True, save_order=True):
return False
else:
return True