diff --git a/graphein/protein/graphs.py b/graphein/protein/graphs.py index dfe237f7..cc4d0df8 100644 --- a/graphein/protein/graphs.py +++ b/graphein/protein/graphs.py @@ -96,18 +96,12 @@ def read_pdb_to_dataframe( :rtype: pd.DataFrame """ if pdb_code is None and path is None and uniprot_id is None: - raise NameError( - "One of pdb_code, path or uniprot_id must be specified!" - ) + raise NameError("One of pdb_code, path or uniprot_id must be specified!") if path is not None: if isinstance(path, Path): path = os.fsdecode(path) - if ( - path.endswith(".pdb") - or path.endswith(".pdb.gz") - or path.endswith(".ent") - ): + if path.endswith(".pdb") or path.endswith(".pdb.gz") or path.endswith(".ent"): atomic_df = PandasPdb().read_pdb(path) elif path.endswith(".mmtf") or path.endswith(".mmtf.gz"): atomic_df = PandasMmtf().read_mmtf(path) @@ -116,9 +110,7 @@ def read_pdb_to_dataframe( f"File {path} must be either .pdb(.gz), .mmtf(.gz) or .ent, not {path.split('.')[-1]}" ) elif uniprot_id is not None: - atomic_df = PandasPdb().fetch_pdb( - uniprot_id=uniprot_id, source="alphafold2-v3" - ) + atomic_df = PandasPdb().fetch_pdb(uniprot_id=uniprot_id, source="alphafold2-v3") else: atomic_df = PandasPdb().fetch_pdb(pdb_code) @@ -172,11 +164,7 @@ def label_node_id( df["node_id"] = df["node_id"] + ":" + df["atom_name"] elif granularity in {"rna_atom", "rna_centroid"}: df["node_id"] = ( - df["node_id"] - + ":" - + df["atom_number"].apply(str) - + ":" - + df["atom_name"] + df["node_id"] + ":" + df["atom_number"].apply(str) + ":" + df["atom_name"] ) return df @@ -189,9 +177,7 @@ def deprotonate_structure(df: pd.DataFrame) -> pd.DataFrame: :returns: Atomic dataframe with all ``element_symbol == "H" or "D" or "T"`` removed. :rtype: pd.DataFrame """ - log.debug( - "Deprotonating protein. This removes H atoms from the pdb_df dataframe" - ) + log.debug("Deprotonating protein. This removes H atoms from the pdb_df dataframe") return filter_dataframe( df, by_column="element_symbol", @@ -225,9 +211,7 @@ def convert_structure_to_centroids(df: pd.DataFrame) -> pd.DataFrame: return df -def subset_structure_to_atom_type( - df: pd.DataFrame, granularity: str -) -> pd.DataFrame: +def subset_structure_to_atom_type(df: pd.DataFrame, granularity: str) -> pd.DataFrame: """ Return a subset of atomic dataframe that contains only certain atom names. @@ -241,9 +225,7 @@ def subset_structure_to_atom_type( ) -def remove_alt_locs( - df: pd.DataFrame, keep: str = "max_occupancy" -) -> pd.DataFrame: +def remove_alt_locs(df: pd.DataFrame, keep: str = "max_occupancy") -> pd.DataFrame: """ This function removes alternatively located atoms from PDB DataFrames (see https://proteopedia.org/wiki/index.php/Alternate_locations). Among the @@ -277,7 +259,7 @@ def remove_alt_locs( # Unsort if keep in ["max_occupancy", "min_occupancy"]: df = df.sort_index() - + df = df.reset_index(drop=True) return df @@ -307,9 +289,7 @@ def remove_insertions( ) -def filter_hetatms( - df: pd.DataFrame, keep_hets: List[str] -) -> List[pd.DataFrame]: +def filter_hetatms(df: pd.DataFrame, keep_hets: List[str]) -> List[pd.DataFrame]: """Return hetatms of interest. :param df: Protein Structure dataframe to filter hetatoms from. @@ -454,9 +434,7 @@ def sort_dataframe(df: pd.DataFrame) -> pd.DataFrame: :return: Sorted protein dataframe. :rtype: pd.DataFrame """ - return df.sort_values( - by=["chain_id", "residue_number", "atom_number", "insertion"] - ) + return df.sort_values(by=["chain_id", "residue_number", "atom_number", "insertion"]) def select_chains( @@ -558,8 +536,7 @@ def initialise_graph_with_metadata( elif granularity == "atom": sequence = ( protein_df.loc[ - (protein_df["chain_id"] == c) - & (protein_df["atom_name"] == "CA") + (protein_df["chain_id"] == c) & (protein_df["atom_name"] == "CA") ]["residue_name"] .apply(three_to_one_with_mods) .str.cat() @@ -610,13 +587,9 @@ def add_nodes_to_graph( # Set intrinsic node attributes nx.set_node_attributes(G, dict(zip(nodes, chain_id)), "chain_id") nx.set_node_attributes(G, dict(zip(nodes, residue_name)), "residue_name") - nx.set_node_attributes( - G, dict(zip(nodes, residue_number)), "residue_number" - ) + nx.set_node_attributes(G, dict(zip(nodes, residue_number)), "residue_number") nx.set_node_attributes(G, dict(zip(nodes, atom_type)), "atom_type") - nx.set_node_attributes( - G, dict(zip(nodes, element_symbol)), "element_symbol" - ) + nx.set_node_attributes(G, dict(zip(nodes, element_symbol)), "element_symbol") nx.set_node_attributes(G, dict(zip(nodes, coords)), "coords") nx.set_node_attributes(G, dict(zip(nodes, b_factor)), "b_factor") @@ -642,9 +615,7 @@ def calculate_centroid_positions( :rtype: pd.DataFrame """ centroids = ( - atoms.groupby( - ["residue_number", "chain_id", "residue_name", "insertion"] - ) + atoms.groupby(["residue_number", "chain_id", "residue_name", "insertion"]) .mean(numeric_only=True)[["x_coord", "y_coord", "z_coord"]] .reset_index() ) @@ -902,13 +873,9 @@ def _mp_graph_constructor( func = partial(construct_graph, config=config) try: if source == "pdb_code": - return func( - pdb_code=args[0], chain_selection=args[1], model_index=args[2] - ) + return func(pdb_code=args[0], chain_selection=args[1], model_index=args[2]) elif source == "path": - return func( - path=args[0], chain_selection=args[1], model_index=args[2] - ) + return func(path=args[0], chain_selection=args[1], model_index=args[2]) elif source == "uniprot_id": return func( uniprot_id=args[0], @@ -1004,9 +971,7 @@ def construct_graphs_mp( ) if out_path is not None: [ - nx.write_gpickle( - g, str(f"{out_path}/" + f"{g.graph['name']}.pickle") - ) + nx.write_gpickle(g, str(f"{out_path}/" + f"{g.graph['name']}.pickle")) for g in graphs ] @@ -1070,15 +1035,11 @@ def compute_chain_graph( # Add edges for u, v, d in g.edges(data=True): - h.add_edge( - g.nodes[u]["chain_id"], g.nodes[v]["chain_id"], kind=d["kind"] - ) + h.add_edge(g.nodes[u]["chain_id"], g.nodes[v]["chain_id"], kind=d["kind"]) # Remove self-loops if necessary. Checks for equality between nodes in a # given edge. if remove_self_loops: - edges_to_remove: List[Tuple[str]] = [ - (u, v) for u, v in h.edges() if u == v - ] + edges_to_remove: List[Tuple[str]] = [(u, v) for u, v in h.edges() if u == v] h.remove_edges_from(edges_to_remove) # Compute a weighted graph if required. @@ -1181,16 +1142,10 @@ def compute_secondary_structure_graph( ss_list = ss_list[~ss_list.str.contains("-")] # Subset to only allowable SS elements if necessary if allowable_ss_elements: - ss_list = ss_list[ - ss_list.str.contains("|".join(allowable_ss_elements)) - ] + ss_list = ss_list[ss_list.str.contains("|".join(allowable_ss_elements))] - constituent_residues: Dict[str, List[str]] = ss_list.index.groupby( - ss_list.values - ) - constituent_residues = { - k: list(v) for k, v in constituent_residues.items() - } + constituent_residues: Dict[str, List[str]] = ss_list.index.groupby(ss_list.values) + constituent_residues = {k: list(v) for k, v in constituent_residues.items()} residue_counts: Dict[str, int] = ss_list.groupby(ss_list).count().to_dict() # Add Nodes from secondary structure list @@ -1209,9 +1164,7 @@ def compute_secondary_structure_graph( # Iterate over edges in source graph and add SS-SS edges to new graph. for u, v, d in g.edges(data=True): try: - h.add_edge( - ss_list[u], ss_list[v], kind=d["kind"], source=f"{u}_{v}" - ) + h.add_edge(ss_list[u], ss_list[v], kind=d["kind"], source=f"{u}_{v}") except KeyError as e: log.debug( f"Edge {u}-{v} not added to secondary structure graph. \ @@ -1221,9 +1174,7 @@ def compute_secondary_structure_graph( # Remove self-loops if necessary. # Checks for equality between nodes in a given edge. if remove_self_loops: - edges_to_remove: List[Tuple[str]] = [ - (u, v) for u, v in h.edges() if u == v - ] + edges_to_remove: List[Tuple[str]] = [(u, v) for u, v in h.edges() if u == v] h.remove_edges_from(edges_to_remove) # Create weighted graph from h