Skip to content

Commit

Permalink
reset index after handling altlocs #384
Browse files (browse the repository at this point in the history)
  • Loading branch information
Arian Jamasb committed Apr 23, 2024
1 parent f8b5ef5 commit 1575ecf
Showing 1 changed file with 24 additions and 73 deletions.
97 changes: 24 additions & 73 deletions graphein/protein/graphs.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -96,18 +96,12 @@ def read_pdb_to_dataframe(
:rtype: pd.DataFrame
"""
if pdb_code is None and path is None and uniprot_id is None:
raise NameError(
"One of pdb_code, path or uniprot_id must be specified!"
)
raise NameError("One of pdb_code, path or uniprot_id must be specified!")

if path is not None:
if isinstance(path, Path):
path = os.fsdecode(path)
if (
path.endswith(".pdb")
or path.endswith(".pdb.gz")
or path.endswith(".ent")
):
if path.endswith(".pdb") or path.endswith(".pdb.gz") or path.endswith(".ent"):
atomic_df = PandasPdb().read_pdb(path)
elif path.endswith(".mmtf") or path.endswith(".mmtf.gz"):
atomic_df = PandasMmtf().read_mmtf(path)
Expand All @@ -116,9 +110,7 @@ def read_pdb_to_dataframe(
f"File {path} must be either .pdb(.gz), .mmtf(.gz) or .ent, not {path.split('.')[-1]}"
)
elif uniprot_id is not None:
atomic_df = PandasPdb().fetch_pdb(
uniprot_id=uniprot_id, source="alphafold2-v3"
)
atomic_df = PandasPdb().fetch_pdb(uniprot_id=uniprot_id, source="alphafold2-v3")
else:
atomic_df = PandasPdb().fetch_pdb(pdb_code)

Expand Down Expand Up @@ -172,11 +164,7 @@ def label_node_id(
df["node_id"] = df["node_id"] + ":" + df["atom_name"]
elif granularity in {"rna_atom", "rna_centroid"}:
df["node_id"] = (
df["node_id"]
+ ":"
+ df["atom_number"].apply(str)
+ ":"
+ df["atom_name"]
df["node_id"] + ":" + df["atom_number"].apply(str) + ":" + df["atom_name"]
)
return df

Expand All @@ -189,9 +177,7 @@ def deprotonate_structure(df: pd.DataFrame) -> pd.DataFrame:
:returns: Atomic dataframe with all ``element_symbol == "H" or "D" or "T"`` removed.
:rtype: pd.DataFrame
"""
log.debug(
"Deprotonating protein. This removes H atoms from the pdb_df dataframe"
)
log.debug("Deprotonating protein. This removes H atoms from the pdb_df dataframe")
return filter_dataframe(
df,
by_column="element_symbol",
Expand Down Expand Up @@ -225,9 +211,7 @@ def convert_structure_to_centroids(df: pd.DataFrame) -> pd.DataFrame:
return df


def subset_structure_to_atom_type(
df: pd.DataFrame, granularity: str
) -> pd.DataFrame:
def subset_structure_to_atom_type(df: pd.DataFrame, granularity: str) -> pd.DataFrame:
"""
Return a subset of atomic dataframe that contains only certain atom names.
Expand All @@ -241,9 +225,7 @@ def subset_structure_to_atom_type(
)


def remove_alt_locs(
df: pd.DataFrame, keep: str = "max_occupancy"
) -> pd.DataFrame:
def remove_alt_locs(df: pd.DataFrame, keep: str = "max_occupancy") -> pd.DataFrame:
"""
This function removes alternatively located atoms from PDB DataFrames
(see https://proteopedia.org/wiki/index.php/Alternate_locations). Among the
Expand Down Expand Up @@ -277,7 +259,7 @@ def remove_alt_locs(
# Unsort
if keep in ["max_occupancy", "min_occupancy"]:
df = df.sort_index()

df = df.reset_index(drop=True)
return df


Expand Down Expand Up @@ -307,9 +289,7 @@ def remove_insertions(
)


def filter_hetatms(
df: pd.DataFrame, keep_hets: List[str]
) -> List[pd.DataFrame]:
def filter_hetatms(df: pd.DataFrame, keep_hets: List[str]) -> List[pd.DataFrame]:
"""Return hetatms of interest.
:param df: Protein Structure dataframe to filter hetatoms from.
Expand Down Expand Up @@ -454,9 +434,7 @@ def sort_dataframe(df: pd.DataFrame) -> pd.DataFrame:
:return: Sorted protein dataframe.
:rtype: pd.DataFrame
"""
return df.sort_values(
by=["chain_id", "residue_number", "atom_number", "insertion"]
)
return df.sort_values(by=["chain_id", "residue_number", "atom_number", "insertion"])


def select_chains(
Expand Down Expand Up @@ -558,8 +536,7 @@ def initialise_graph_with_metadata(
elif granularity == "atom":
sequence = (
protein_df.loc[
(protein_df["chain_id"] == c)
& (protein_df["atom_name"] == "CA")
(protein_df["chain_id"] == c) & (protein_df["atom_name"] == "CA")
]["residue_name"]
.apply(three_to_one_with_mods)
.str.cat()
Expand Down Expand Up @@ -610,13 +587,9 @@ def add_nodes_to_graph(
# Set intrinsic node attributes
nx.set_node_attributes(G, dict(zip(nodes, chain_id)), "chain_id")
nx.set_node_attributes(G, dict(zip(nodes, residue_name)), "residue_name")
nx.set_node_attributes(
G, dict(zip(nodes, residue_number)), "residue_number"
)
nx.set_node_attributes(G, dict(zip(nodes, residue_number)), "residue_number")
nx.set_node_attributes(G, dict(zip(nodes, atom_type)), "atom_type")
nx.set_node_attributes(
G, dict(zip(nodes, element_symbol)), "element_symbol"
)
nx.set_node_attributes(G, dict(zip(nodes, element_symbol)), "element_symbol")
nx.set_node_attributes(G, dict(zip(nodes, coords)), "coords")
nx.set_node_attributes(G, dict(zip(nodes, b_factor)), "b_factor")

Expand All @@ -642,9 +615,7 @@ def calculate_centroid_positions(
:rtype: pd.DataFrame
"""
centroids = (
atoms.groupby(
["residue_number", "chain_id", "residue_name", "insertion"]
)
atoms.groupby(["residue_number", "chain_id", "residue_name", "insertion"])
.mean(numeric_only=True)[["x_coord", "y_coord", "z_coord"]]
.reset_index()
)
Expand Down Expand Up @@ -902,13 +873,9 @@ def _mp_graph_constructor(
func = partial(construct_graph, config=config)
try:
if source == "pdb_code":
return func(
pdb_code=args[0], chain_selection=args[1], model_index=args[2]
)
return func(pdb_code=args[0], chain_selection=args[1], model_index=args[2])
elif source == "path":
return func(
path=args[0], chain_selection=args[1], model_index=args[2]
)
return func(path=args[0], chain_selection=args[1], model_index=args[2])
elif source == "uniprot_id":
return func(
uniprot_id=args[0],
Expand Down Expand Up @@ -1004,9 +971,7 @@ def construct_graphs_mp(
)
if out_path is not None:
[
nx.write_gpickle(
g, str(f"{out_path}/" + f"{g.graph['name']}.pickle")
)
nx.write_gpickle(g, str(f"{out_path}/" + f"{g.graph['name']}.pickle"))
for g in graphs
]

Expand Down Expand Up @@ -1070,15 +1035,11 @@ def compute_chain_graph(

# Add edges
for u, v, d in g.edges(data=True):
h.add_edge(
g.nodes[u]["chain_id"], g.nodes[v]["chain_id"], kind=d["kind"]
)
h.add_edge(g.nodes[u]["chain_id"], g.nodes[v]["chain_id"], kind=d["kind"])
# Remove self-loops if necessary. Checks for equality between nodes in a
# given edge.
if remove_self_loops:
edges_to_remove: List[Tuple[str]] = [
(u, v) for u, v in h.edges() if u == v
]
edges_to_remove: List[Tuple[str]] = [(u, v) for u, v in h.edges() if u == v]
h.remove_edges_from(edges_to_remove)

# Compute a weighted graph if required.
Expand Down Expand Up @@ -1181,16 +1142,10 @@ def compute_secondary_structure_graph(
ss_list = ss_list[~ss_list.str.contains("-")]
# Subset to only allowable SS elements if necessary
if allowable_ss_elements:
ss_list = ss_list[
ss_list.str.contains("|".join(allowable_ss_elements))
]
ss_list = ss_list[ss_list.str.contains("|".join(allowable_ss_elements))]

constituent_residues: Dict[str, List[str]] = ss_list.index.groupby(
ss_list.values
)
constituent_residues = {
k: list(v) for k, v in constituent_residues.items()
}
constituent_residues: Dict[str, List[str]] = ss_list.index.groupby(ss_list.values)
constituent_residues = {k: list(v) for k, v in constituent_residues.items()}
residue_counts: Dict[str, int] = ss_list.groupby(ss_list).count().to_dict()

# Add Nodes from secondary structure list
Expand All @@ -1209,9 +1164,7 @@ def compute_secondary_structure_graph(
# Iterate over edges in source graph and add SS-SS edges to new graph.
for u, v, d in g.edges(data=True):
try:
h.add_edge(
ss_list[u], ss_list[v], kind=d["kind"], source=f"{u}_{v}"
)
h.add_edge(ss_list[u], ss_list[v], kind=d["kind"], source=f"{u}_{v}")
except KeyError as e:
log.debug(
f"Edge {u}-{v} not added to secondary structure graph. \
Expand All @@ -1221,9 +1174,7 @@ def compute_secondary_structure_graph(
# Remove self-loops if necessary.
# Checks for equality between nodes in a given edge.
if remove_self_loops:
edges_to_remove: List[Tuple[str]] = [
(u, v) for u, v in h.edges() if u == v
]
edges_to_remove: List[Tuple[str]] = [(u, v) for u, v in h.edges() if u == v]
h.remove_edges_from(edges_to_remove)

# Create weighted graph from h
Expand Down

0 comments on commit 1575ecf

Please sign in to comment.