Skip to content

Commit

Permalink
made small corrections.
Browse files Browse the repository at this point in the history
- nodes -> res_idx
- the name of the contact map file to be written can now be specified via the changed -go-write-file option
  • Loading branch information
csbrasnett committed Dec 20, 2024
1 parent 6c6f935 commit 7880986
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 56 deletions.
41 changes: 15 additions & 26 deletions bin/martinize2
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def entry():
"-go",
dest="go",
nargs='?',
const=None,
const=True,
type=Path,
help="Use Martini Go model. Accepts either an input file from the server, "
"or just provide the flag to calculate as part of Martinize."
Expand Down Expand Up @@ -584,8 +584,9 @@ def entry():
go_group.add_argument(
"-go-write-file",
dest="go_write_file",
action="store_true",
default=False,
nargs='?',
const=True,
type=Path,
help=("Write out contact map to file if calculating as part of Martinize2.")
)

Expand Down Expand Up @@ -833,21 +834,6 @@ def entry():
"be used together."
)

"""
Sort out the use of the go model:
go = True: apply go model
go_file = str: parse contact map
bool(go_file) = False: no go
"""
go = False
go_file = None
if args.go is None:
go = True
else:
go_file = args.go
go = True

if args.to_ff.startswith("elnedyn"):
# FIXME: This type of thing should be added to the FF itself.
LOGGER.info(
Expand Down Expand Up @@ -989,15 +975,19 @@ def entry():
elif args.cystein_bridge != "auto":
vermouth.AddCysteinBridgesThreshold(args.cystein_bridge).run_system(system)

if go:
if args.go:
system = vermouth.MergeAllMolecules().run_system(system)
# need this here because have to get contact map at atomistic resolution
if go_file is None:
LOGGER.info("Generating Go model contact map.", type="step")
GenerateContactMap(write_file=args.go_write_file).run_system(system)
else:
if isinstance(args.go, str):
LOGGER.info("Reading Go model contact map.", type="step")
read_go_map(system=system, path=go_file)
read_go_map(system=system, path=args.go)
else:
LOGGER.info("Generating Go model contact map.", type="step")
if isinstance(args.go_write_file, Path):
go_file_path = str(args.go_write_file)
else:
go_file_path = "contact_map_martinize.out"
GenerateContactMap(write_file=go_file_path).run_system(system)


# Run martinize on the system.
Expand All @@ -1021,8 +1011,7 @@ def entry():
vermouth.ApplyPosres(node_selector, args.posres_fc).run_system(system)

# Generate the Go model if required

if go:
if system.go_params["go_map"]:
go_name_prefix = args.molname
LOGGER.info("Generating the Go model.", type="step")
GoPipeline.run_system(system,
Expand Down
61 changes: 31 additions & 30 deletions vermouth/rcsu/contact_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..graph_utils import make_residue_graph
from itertools import product
from vermouth.file_writer import deferred_open
from pathlib import Path
from collections import defaultdict

# BOND TYPE
# Types of contacts:
Expand Down Expand Up @@ -355,11 +355,10 @@ def atom2res(arrin, nresidues, atom_map, norm=False):
'''

out = np.zeros((nresidues, nresidues))
for res_idx, res_jdx in product(np.arange(nresidues), np.arange(nresidues)):
atom_idxs = np.array(atom_map[res_idx])
atom_jdxs = np.array(atom_map[res_jdx])
value = arrin[atom_idxs,
atom_jdxs[:, np.newaxis]].sum()
for res_idx, res_jdx in product(atom_map.keys(), atom_map.keys()):
atom_idxs = atom_map[res_idx]
atom_jdxs = atom_map[res_jdx][:, np.newaxis]
value = arrin[atom_idxs, atom_jdxs].sum()
out[res_idx, res_jdx] = value

if norm:
Expand All @@ -383,16 +382,16 @@ def _contact_info(molecule):
vdw_list = []
atypes = []
res_serial = []
nodes = []
res_idx = []
for residue in G.nodes:
# we only need these for writing at the end
resnames.append(G.nodes[residue]['resname'])
resids.append(G.nodes[residue]['resid'])
chains.append(G.nodes[residue]['chain'])
nodes.append(G.nodes[residue]['_res_serial'])
res_idx.append(G.nodes[residue]['_res_serial'])
subgraph = G.nodes[residue]['graph']

for atom in sorted(G.nodes[residue]['graph'].nodes):
for atom in sorted(subgraph.nodes):
position = subgraph.nodes[atom].get('position', [np.nan]*3)
if np.isfinite(position).all():
res_serial.append(subgraph.nodes[atom]['_res_serial'])
Expand All @@ -404,8 +403,8 @@ def _contact_info(molecule):
atypes.append(_get_atype(subgraph.nodes[atom]['resname'],
subgraph.nodes[atom]['atomname']))

if subgraph.nodes[atom]['atomname'] == 'CA':
ca_pos.append(subgraph.nodes[atom]['position'])
if subgraph.nodes[atom]['atomname'] == 'CA':
ca_pos.append(subgraph.nodes[atom]['position'])


vdw_list = np.array(vdw_list)
Expand All @@ -416,14 +415,14 @@ def _contact_info(molecule):
resids = np.array(resids)
chains = np.array(chains)
resnames = np.array(resnames)
nodes = np.array(nodes)
res_idx = np.array(res_idx)

# 2) find the number of residues that we have
nresidues = len(G)

return vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, G
return vdw_list, atypes, coords, res_serial, resids, chains, resnames, res_idx, ca_pos, nresidues, G

def _calculate_overlap(coords_tree, vdw_list, natoms, vdw_max, alpha):
def _calculate_overlap(coords_tree, vdw_list, natoms, vdw_max, alpha=1.24):
"""
Find enlarged (OV) overlap contacts
Expand All @@ -446,7 +445,7 @@ def _calculate_overlap(coords_tree, vdw_list, natoms, vdw_max, alpha):
over[idx, jdx] = 1
return over

def _calculate_csu(coords, vdw_list, fiba, fibb, natoms, coords_tree, vdw_max, water_radius):
def _calculate_csu(coords, vdw_list, fiba, fibb, natoms, coords_tree, vdw_max, water_radius=2.80):
"""
Calculate contacts of structural units (CSU)
Expand Down Expand Up @@ -579,9 +578,11 @@ def _calculate_contacts(vdw_list, atypes, coords, res_serial, nresidues):
# find the types of contacts we have
contactcounter_1, stabilisercounter_1, destabilisercounter_1 = _contact_types(hit_results, natoms, atypes)

atom_map = {}
for i in range(nresidues):
atom_map[i] = np.where(res_serial == i)[0]
atom_map = defaultdict(list)
for atom_idx, res_idx in enumerate(res_serial):
atom_map[res_idx].append(atom_idx)
for key, value in atom_map.items():
atom_map[key] = np.array(value)

# transform the resolution between atoms and residues
overlapcounter_2 = atom2res(over, nresidues, atom_map, norm=True)
Expand All @@ -592,7 +593,7 @@ def _calculate_contacts(vdw_list, atypes, coords, res_serial, nresidues):
return overlapcounter_2, contactcounter_2, stabilisercounter_2, destabilisercounter_2


def _get_contacts(nresidues, overlaps, contacts, stabilisers, destabilisers, nodes, G):
def _get_contacts(nresidues, overlaps, contacts, stabilisers, destabilisers, res_idx, G):
'''
Generate contacts list from the contact arrays calculated
Expand All @@ -606,7 +607,7 @@ def _get_contacts(nresidues, overlaps, contacts, stabilisers, destabilisers, nod
nresidues x nresidues array of CSU stabilising contacts in the molecule
destabilisers: ndarray
nresidues x nresidues array of CSU destabilising contacts in the molecule
nodes: list
res_idx: list
list of serial residue ids for each of the residues
G: nx.Graph
residue based graph of the molecule
Expand All @@ -623,8 +624,8 @@ def _get_contacts(nresidues, overlaps, contacts, stabilisers, destabilisers, nod
rcsu = (stab - dest) > 0

if (over > 0 or cont > 0):
a = np.where(nodes == i1)[0][0]
b = np.where(nodes == i2)[0][0]
a = np.where(res_idx == i1)[0][0]
b = np.where(res_idx == i2)[0][0]
all_contacts.append([i1+1, i2+1, a, b, over, cont, stab, rcsu])
if over == 1 or (over == 0 and rcsu):
# this is a OV or rCSU contact we take it
Expand All @@ -634,10 +635,11 @@ def _get_contacts(nresidues, overlaps, contacts, stabilisers, destabilisers, nod
return contacts_list, all_contacts


def _write_contacts(all_contacts, ca_pos, G):
def _write_contacts(fout, all_contacts, ca_pos, G):
'''
write the contacts calculated to file
fout: str
path to write file to
all_contacts: list
list of lists of every contact found
ca_pos: list
Expand All @@ -661,7 +663,7 @@ def _write_contacts(all_contacts, ca_pos, G):
f"{int(contact[7]): 6d} {int(contact[5]): 6d}\n")
msgs.append(msg)
message_out = ''.join(msgs)
with deferred_open('contact_map_vermouth.out', "w") as f:
with deferred_open(fout, "w") as f:
f.write(message_out)


Expand Down Expand Up @@ -718,7 +720,7 @@ def do_contacts(molecule, write_file):
write_file: bool
write the file of the contacts out
'''
vdw_list, atypes, coords, res_serial, resids, chains, resnames, nodes, ca_pos, nresidues, mol_graph = _contact_info(
vdw_list, atypes, coords, res_serial, resids, chains, resnames, res_idx, ca_pos, nresidues, mol_graph = _contact_info(
molecule)

overlaps, contacts, stabilisers, destabilisers = _calculate_contacts(vdw_list,
Expand All @@ -731,11 +733,11 @@ def do_contacts(molecule, write_file):
overlaps, contacts,
stabilisers,
destabilisers,
nodes,
res_idx,
mol_graph)

if write_file:
_write_contacts(all_contacts, ca_pos, mol_graph)
if isinstance(write_file, str):
_write_contacts(write_file, all_contacts, ca_pos, mol_graph)

return contacts

Expand All @@ -762,4 +764,3 @@ def run_system(self, system):
for molecule in system.molecules:
contacts = self.run_molecule(molecule)
system.go_params["go_map"].append(contacts)

0 comments on commit 7880986

Please sign in to comment.