Skip to content

Commit

Permalink
add some utility functions
Browse files Browse the repository at this point in the history
  • Loading branch information
klausweinbauer committed Aug 1, 2024
1 parent c23e985 commit 8a772a2
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions aamutils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import networkx as nx
import rdkit.Chem as Chem
import rdkit.Chem.rdmolfiles as rdmolfiles
import rdkit.Chem.rdDepictor as rdDepictor


def mol_to_graph(mol: Chem.rdchem.Mol) -> nx.Graph:
Expand All @@ -26,6 +27,16 @@ def mol_to_graph(mol: Chem.rdchem.Mol) -> nx.Graph:
return g


def smiles_to_graph(smiles: str) -> nx.Graph | tuple[nx.Graph, nx.Graph]:
if ">>" in smiles:
smiles_token = smiles.split(">>")
g = mol_to_graph(rdmolfiles.MolFromSmiles(smiles_token[0]))
h = mol_to_graph(rdmolfiles.MolFromSmiles(smiles_token[1]))
return g, h
else:
return mol_to_graph(rdmolfiles.MolFromSmiles(smiles))


def graph_to_mol(
G: nx.Graph, symbol_key="symbol", aam_key="aam", bond_type_key="bond"
) -> Chem.rdchem.Mol:
Expand Down Expand Up @@ -126,3 +137,46 @@ def print_graph(graph):
)
)
)


def its2mol(its: nx.Graph, aam_key="aam", bond_key="bond") -> Chem.rdchem.Mol:
_its = its.copy()
for n in _its.nodes:
_its.nodes[n][aam_key] = n
for u, v in _its.edges():
_its[u][v][bond_key] = 1
return graph_to_mol(_its)


def plot_its(its, ax, bond_key="bond", aam_key="aam", symbol_key="symbol"):
bond_char = {None: "∅", 1: "—", 2: "=", 3: "≡"}
mol = its2mol(its, aam_key=aam_key, bond_key=bond_key)

positions = {}
conformer = rdDepictor.Compute2DCoords(mol)
for i, atom in enumerate(mol.GetAtoms()):
aam = atom.GetAtomMapNum()
apos = mol.GetConformer(conformer).GetAtomPosition(i)
positions[aam] = [apos.x, apos.y]

ax.axis("equal")
ax.axis("off")

nx.draw_networkx_edges(its, positions, edge_color="#000000", ax=ax)
nx.draw_networkx_nodes(its, positions, node_color="#FFFFFF", node_size=500, ax=ax)

labels = {n: "{}:{}".format(d[symbol_key], n) for n, d in its.nodes(data=True)}
edge_labels = {}
for u, v, d in its.edges(data=True):
bc1 = d[bond_key][0]
bc2 = d[bond_key][1]
if bc1 == bc2:
continue
if bc1 in bond_char.keys():
bc1 = bond_char[bc1]
if bc2 in bond_char.keys():
bc2 = bond_char[bc2]
edge_labels[(u, v)] = "({},{})".format(bc1, bc2)

nx.draw_networkx_labels(its, positions, labels=labels, ax=ax)
nx.draw_networkx_edge_labels(its, positions, edge_labels=edge_labels, ax=ax)

0 comments on commit 8a772a2

Please sign in to comment.